582 lines
20 KiB
Python
582 lines
20 KiB
Python
"""Memory maintenance jobs — Phase 3/5.

Three entrypoints called by the scheduler (APScheduler) registered in app/main.py:

  drain_extraction_queue(db)            — Free-tier batch extraction (Phase 2/5).
  mine_proactive_patterns(db, user_id)  — Power+ pattern mining (Phase 5).
  decay_relations(db, user_id)          — confidence decay + pruning for memory_relations (Phase 3).

A fourth public entrypoint, audit_memory(db, user_id), runs the weekly memory audit (Phase 7).

All are safe to call manually or from tests; they never raise.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from cryptography.fernet import Fernet
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
|
|
from app.models import MemoryAssociative, MemoryEpisodic, MemoryProactive, MemoryRelation, User
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# ── Relation decay (used by _decay_relations_inner) ──────────────────────────
_DECAY_FACTOR = 0.95  # confidence multiplier applied once per elapsed period
_DECAY_PERIOD_DAYS = 30  # length of one decay period, in days
_PRUNE_THRESHOLD = 0.2  # relations decayed below this confidence are deleted

# ── Proactive pattern decay: 10 % per 7-day period ───────────────────────────
_PROACTIVE_DECAY_FACTOR = 0.9  # confidence multiplier applied once per elapsed period
_PROACTIVE_DECAY_PERIOD_DAYS = 7  # length of one decay period, in days
_PROACTIVE_PRUNE_THRESHOLD = 0.2  # patterns decayed below this confidence are deleted

# ── Pattern mining ───────────────────────────────────────────────────────────
_MIN_EPISODES_FOR_MINING = 3  # fewer episodes than this: skip pattern extraction
_MINING_LOOKBACK_DAYS = 30  # only episodes created within this window are mined

# ── Audit: caps to control token cost ────────────────────────────────────────
_AUDIT_MAX_FACTS = 50  # max associative facts sent to the auditor per run
_AUDIT_MAX_LABELS = 100  # max relation labels sent to the auditor per run
|
|
|
|
|
|
async def decay_relations(db: AsyncSession, user_id: str) -> None:
    """Run the relation confidence-decay pass for one user without ever raising.

    The actual work is done by ``_decay_relations_inner`` (multiply confidence by
    0.95 per elapsed 30-day period, delete rows that end up below 0.2). Any
    exception escaping that call is logged at WARNING level and suppressed.

    Args:
        db: Async SQLAlchemy session the pass runs on.
        user_id: Owner of the relation rows to process.
    """
    try:
        await _decay_relations_inner(db, user_id)
    except Exception as err:  # catch-all on purpose: this entrypoint must not raise
        logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, err)
|
|
|
|
|
|
async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
    """Decay or prune every ``MemoryRelation`` row owned by ``user_id``.

    For each row the reference timestamp is ``last_confirmed_at`` (falling back to
    ``created_at``); naive timestamps are treated as UTC. For every full
    ``_DECAY_PERIOD_DAYS`` elapsed since the reference, confidence is multiplied by
    ``_DECAY_FACTOR``. Rows whose result is below ``_PRUNE_THRESHOLD`` are deleted,
    the others are updated in place. All changes are committed once at the end; a
    failed commit is logged and rolled back.

    NOTE(review): the decayed value is written back to ``confidence`` while the
    reference timestamp is left untouched, so each later run multiplies the already
    decayed value again by ``factor ** periods``. Whether that compounding is
    intended depends on how often the scheduler runs this job — confirm.
    """
    query = select(MemoryRelation).where(MemoryRelation.user_id == user_id)
    relations = (await db.execute(query)).scalars().all()
    current_time = datetime.now(timezone.utc)
    pruned_count = 0
    decayed_count = 0

    for relation in relations:
        anchor = relation.last_confirmed_at or relation.created_at
        if anchor is None:
            continue
        if anchor.tzinfo is None:
            # Naive timestamps are assumed to be UTC so they can be subtracted from `current_time`.
            anchor = anchor.replace(tzinfo=timezone.utc)

        # Whole periods elapsed; less than one full period means nothing to decay yet.
        elapsed_periods = (current_time - anchor).days // _DECAY_PERIOD_DAYS
        if elapsed_periods < 1:
            continue

        decayed_value = relation.confidence * (_DECAY_FACTOR ** elapsed_periods)

        if decayed_value < _PRUNE_THRESHOLD:
            await db.delete(relation)
            pruned_count += 1
            logger.info(
                "memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s "
                "confidence=%.3f (below threshold)",
                relation.id,
                user_id,
                relation.subject_label,
                relation.predicate,
                decayed_value,
            )
        else:
            relation.confidence = decayed_value
            decayed_count += 1

    try:
        await db.commit()
        logger.info(
            "memory_maintenance: decay_relations user=%s decayed=%d deleted=%d",
            user_id,
            decayed_count,
            pruned_count,
        )
    except Exception as err:
        logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, err)
        await db.rollback()
|
|
|
|
|
|
async def drain_extraction_queue(db: AsyncSession) -> None:
    """Drain the pending ``ExtractionQueue`` rows without ever raising.

    Delegates to ``_drain_extraction_queue_inner``, which runs the extraction
    pipeline once per queued row and deletes each row that was processed
    successfully. Any exception escaping that call is logged at WARNING level and
    suppressed.

    Args:
        db: Async SQLAlchemy session the drain runs on.
    """
    try:
        await _drain_extraction_queue_inner(db)
    except Exception as err:  # catch-all on purpose: this entrypoint must not raise
        logger.warning("memory_maintenance: drain_extraction_queue failed: %s", err)
|
|
|
|
|
|
async def _drain_extraction_queue_inner(db: AsyncSession) -> None:
    """Run extraction for every queued row, deleting each row that succeeds.

    Rows are processed one at a time, each in its own commit, so one failing row
    does not undo the rows already drained. A failing row is logged, the session
    is rolled back, and the row stays in the queue for a later run.

    NOTE(review): ``run_extraction`` is called with empty message strings and no
    session id; apart from ``user_id`` nothing stored on the queue row is passed
    along. Confirm that ``run_extraction`` loads the episode content itself.
    """
    from app.models import ExtractionQueue  # noqa: PLC0415

    result = await db.execute(select(ExtractionQueue))
    rows = result.scalars().all()

    if not rows:
        logger.debug("memory_maintenance: drain_extraction_queue nothing to drain")
        return

    logger.info("memory_maintenance: drain_extraction_queue pending=%d", len(rows))

    from app.core.memory_extraction import run_extraction  # noqa: PLC0415

    # Snapshot plain identifiers while the instances are freshly loaded.
    # commit()/rollback() expire ORM attributes (unless expire_on_commit=False), and
    # reading an expired attribute on an AsyncSession-bound object needs implicit IO,
    # which fails outside a greenlet — including inside the except handler below.
    pending = [(row, row.id, row.user_id) for row in rows]

    processed = 0
    for row, row_id, owner_id in pending:
        try:
            await run_extraction(
                db=db,
                user_id=owner_id,
                last_user_msg="",
                last_assistant_msg="",
                session_id=None,
            )
            await db.delete(row)
            await db.commit()
            processed += 1
        except Exception as exc:
            logger.warning(
                "memory_maintenance: drain failed row=%s user=%s: %s",
                row_id,
                owner_id,
                exc,
            )
            await db.rollback()

    logger.info("memory_maintenance: drain_extraction_queue processed=%d/%d", processed, len(pending))
|
|
|
|
|
|
async def mine_proactive_patterns(db: AsyncSession, user_id: str) -> None:
    """Mine recurring behavioural patterns for one user without ever raising.

    Delegates to ``_mine_proactive_patterns_inner``, which:
      1. Skips users whose tier lacks the ``proactive_mining`` feature.
      2. Loads and decrypts the episodic summaries of the last 30 days.
      3. Asks the ``memory-miner`` LLM agent for recurring patterns.
      4. Encrypts each pattern and stores it in memory_proactive.
      5. Applies confidence decay to the user's proactive rows.

    Any exception escaping that call is logged at WARNING level and suppressed.

    Args:
        db: Async SQLAlchemy session the job runs on.
        user_id: User whose episodes are mined.
    """
    try:
        await _mine_proactive_patterns_inner(db, user_id)
    except Exception as err:  # catch-all on purpose: this entrypoint must not raise
        logger.warning("memory_maintenance: mine_proactive_patterns failed user=%s: %s", user_id, err)
|
|
|
|
|
|
async def _mine_proactive_patterns_inner(db: AsyncSession, user_id: str) -> None:
    """Gate on tier and encryption key, mine new patterns, then decay existing ones.

    The function returns early (doing nothing) when the user's tier lacks the
    ``proactive_mining`` feature or when the user has no encryption key. Once both
    gates pass, mining is attempted and decay is ALWAYS applied afterwards — also
    when mining had nothing to do (too few episodes, nothing decryptable, no
    patterns returned) — so patterns of users who became inactive still age out.
    """
    from app.billing.tier_manager import tier_manager  # noqa: PLC0415

    tier = await tier_manager.get_tier(user_id, db)
    if not tier_manager.check_feature(tier, "proactive_mining"):
        logger.debug("memory_maintenance: mine_proactive_patterns skipped (tier=%s)", tier)
        return

    # Load user Fernet key
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None or not user.encryption_key:
        logger.warning("memory_maintenance: mine_proactive_patterns no encryption_key user=%s", user_id)
        return

    fernet = Fernet(user.encryption_key.encode())

    await _mine_and_store_patterns(db, user_id, fernet)
    # Decay is independent of whether new patterns were found in this run.
    await _decay_proactive_patterns(db, user_id, fernet)


async def _mine_and_store_patterns(db: AsyncSession, user_id: str, fernet: Fernet) -> None:
    """Extract patterns from the user's recent episodes and persist them encrypted.

    Returns without storing anything when fewer than ``_MIN_EPISODES_FOR_MINING``
    episodes exist in the lookback window, when no summary can be decrypted, or
    when the LLM returns no patterns. A failed commit is logged and rolled back.

    NOTE(review): every run inserts new rows without looking at the patterns that
    are already stored, so a pattern found in several runs is stored several
    times — confirm whether de-duplication happens elsewhere.
    """
    cutoff = datetime.now(timezone.utc) - timedelta(days=_MINING_LOOKBACK_DAYS)
    episodes_result = await db.execute(
        select(MemoryEpisodic)
        .where(
            MemoryEpisodic.user_id == user_id,
            MemoryEpisodic.created_at >= cutoff,
        )
        .order_by(MemoryEpisodic.created_at.asc())
    )
    episode_rows = episodes_result.scalars().all()

    if len(episode_rows) < _MIN_EPISODES_FOR_MINING:
        logger.info(
            "memory_maintenance: mine_proactive_patterns skipped user=%s episodes=%d (< %d)",
            user_id,
            len(episode_rows),
            _MIN_EPISODES_FOR_MINING,
        )
        return

    summaries: list[str] = []
    undecryptable = 0
    for ep in episode_rows:
        try:
            summaries.append(fernet.decrypt(ep.summary_encrypted.encode()).decode())
        except Exception:
            # Best effort: one unreadable episode must not stop the mining run,
            # but it is counted so the problem is visible in the logs.
            undecryptable += 1

    if undecryptable:
        logger.warning(
            "memory_maintenance: mine_proactive_patterns user=%s skipped %d undecryptable episode(s)",
            user_id,
            undecryptable,
        )

    if not summaries:
        return

    patterns = await _extract_proactive_patterns(summaries)
    if not patterns:
        logger.info("memory_maintenance: mine_proactive_patterns user=%s no patterns extracted", user_id)
        return

    stored = 0
    for pattern_text in patterns:
        try:
            encrypted = fernet.encrypt(pattern_text.encode()).decode()
            db.add(
                MemoryProactive(
                    id=str(uuid.uuid4()),
                    user_id=user_id,
                    pattern_encrypted=encrypted,
                    confidence=0.7,
                    source="inferred",
                )
            )
            stored += 1
        except Exception as exc:
            logger.warning("memory_maintenance: failed to store pattern user=%s: %s", user_id, exc)

    try:
        await db.commit()
        logger.info(
            "memory_maintenance: mine_proactive_patterns user=%s stored=%d",
            user_id,
            stored,
        )
    except Exception as exc:
        logger.warning("memory_maintenance: mine_proactive_patterns commit failed user=%s: %s", user_id, exc)
        await db.rollback()
|
|
|
|
|
|
async def _extract_proactive_patterns(summaries: list[str]) -> list[str]:
    """Ask the ``memory-miner`` LLM agent for recurring patterns in the summaries.

    Only the last 20 summaries are sent. The reply is split into lines, blank
    lines are dropped, and at most the first five remaining lines are returned.
    If the LLM call or the parsing fails, the error is logged and an empty list
    is returned.
    """
    from app.core.llm import get_agent_llm  # noqa: PLC0415

    llm = get_agent_llm("memory-miner", temperature=0)

    # cap at last 20 to control token usage
    recent = summaries[-20:]
    combined = "\n---\n".join(recent)

    instructions = (
        "You are analyzing conversation history for a personal AI secretary. "
        "Identify 3-5 recurring temporal or behavioral patterns (e.g. 'always works late on Thursdays', "
        "'prefers bullet-point summaries', 'frequently asks about Project Acme status'). "
        "Return each pattern as a plain, short English sentence on its own line. "
        "No numbering, no bullet points, no extra text.\n\n"
    )
    prompt = instructions + "Conversation history:\n" + combined

    try:
        response = await llm.ainvoke(prompt)
        if hasattr(response, "content"):
            raw = response.content
        else:
            raw = str(response)

        found: list[str] = []
        for raw_line in str(raw).splitlines():
            cleaned = raw_line.strip()
            if cleaned:
                found.append(cleaned)
        return found[:5]
    except Exception as err:
        logger.warning("memory_maintenance: _extract_proactive_patterns LLM failed: %s", err)
        return []
|
|
|
|
|
|
async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fernet) -> None:
    """Decay or prune every ``MemoryProactive`` row owned by ``user_id``.

    The reference timestamp is ``created_at`` (naive values are treated as UTC).
    For every full ``_PROACTIVE_DECAY_PERIOD_DAYS`` elapsed since then, confidence
    is multiplied by ``_PROACTIVE_DECAY_FACTOR``. Rows whose result is below
    ``_PROACTIVE_PRUNE_THRESHOLD`` are deleted, the others are updated in place.
    Changes are committed once at the end; a failed commit is logged and rolled back.

    ``fernet`` is accepted but not used by this function.

    NOTE(review): the decayed value overwrites ``confidence`` while ``created_at``
    never changes, so each later run multiplies the already decayed value again by
    ``factor ** periods``. Whether that compounding is intended depends on how
    often this job runs — confirm.
    """
    stmt = select(MemoryProactive).where(MemoryProactive.user_id == user_id)
    patterns = (await db.execute(stmt)).scalars().all()
    as_of = datetime.now(timezone.utc)
    removed = 0
    reduced = 0

    for pattern in patterns:
        born = pattern.created_at
        if born is None:
            continue
        if born.tzinfo is None:
            # Naive timestamps are assumed to be UTC so they can be subtracted from `as_of`.
            born = born.replace(tzinfo=timezone.utc)

        # Whole periods elapsed; less than one full period means nothing to decay yet.
        periods = (as_of - born).days // _PROACTIVE_DECAY_PERIOD_DAYS
        if periods < 1:
            continue

        lowered = pattern.confidence * (_PROACTIVE_DECAY_FACTOR ** periods)

        if lowered < _PROACTIVE_PRUNE_THRESHOLD:
            await db.delete(pattern)
            removed += 1
        else:
            pattern.confidence = lowered
            reduced += 1

    try:
        await db.commit()
        logger.info(
            "memory_maintenance: decay_proactive user=%s decayed=%d deleted=%d",
            user_id,
            reduced,
            removed,
        )
    except Exception as err:
        logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, err)
        await db.rollback()
|
|
|
|
|
|
# ── Phase 7: weekly memory audit ──────────────────────────────────────────────
|
|
|
|
# Built-in prompt template for the contradiction scan, passed to
# get_prompt_or_fallback() as the fallback for "memory_audit_contradictions".
# `{facts}` is the substitution slot; the JSON example uses doubled braces,
# presumably because compile_prompt() applies str.format-style substitution — confirm.
_AUDIT_CONTRADICTIONS_FALLBACK = "".join(
    (
        "You are auditing a personal AI assistant's memory bank. Each fact has an ID in brackets. ",
        "Find pairs that directly contradict each other (e.g. 'prefers morning meetings' vs ",
        "'never schedules before noon'). ",
        "For each contradiction, pick the ID to DELETE (the older or less specific one). ",
        "Return ONLY a valid JSON array, no markdown fences: ",
        '[{{"delete": "<id>", "reason": "<one line>"}}]. ',
        "If no contradictions, return [].\n\n",
        "Facts:\n{facts}",
    )
)
|
|
|
|
# Built-in prompt template for label canonicalization, passed to
# get_prompt_or_fallback() as the fallback for "memory_audit_canonicalize".
# `{labels}` is the substitution slot; the JSON example uses doubled braces,
# presumably because compile_prompt() applies str.format-style substitution — confirm.
_AUDIT_CANONICALIZE_FALLBACK = "".join(
    (
        "You are auditing entity labels in a personal AI assistant's relational memory. ",
        "These are names of people, companies, projects, or topics. ",
        "Group labels that clearly refer to the same real-world entity (e.g. 'giulia', 'Giulia', ",
        "'Giulia R.' → canonical 'Giulia'). ",
        "Return ONLY a valid JSON array, no markdown fences: ",
        '[{{"canonical": "<best label>", "variants": ["<v1>", "<v2>"]}}]. ',
        "Only include groups with at least one variant. Singletons: omit.\n\n",
        "Labels:\n{labels}",
    )
)
|
|
|
|
|
|
async def audit_memory(db: AsyncSession, user_id: str) -> None:
    """Run the weekly memory audit for one user without ever raising.

    Delegates to ``_audit_memory_inner``, which performs two passes:
      1. Contradiction scan — decrypt up to ``_AUDIT_MAX_FACTS`` associative rows,
         let the ``memory-auditor`` LLM flag rows to delete, and hard-delete them.
      2. Label canonicalization — collect the subject/object labels found in
         memory_relations, let the LLM group duplicates, and rewrite variant
         labels to their canonical form in place.

    Any exception escaping that call is logged at WARNING level and suppressed.

    Args:
        db: Async SQLAlchemy session the audit runs on.
        user_id: User whose memory is audited.
    """
    try:
        await _audit_memory_inner(db, user_id)
    except Exception as err:  # catch-all on purpose: this entrypoint must not raise
        logger.warning("memory_maintenance: audit_memory failed user=%s: %s", user_id, err)
|
|
|
|
|
|
async def _audit_memory_inner(db: AsyncSession, user_id: str) -> None:
    """Load the user's Fernet key, then run both audit passes in order.

    Logs a warning and returns without auditing when the user does not exist or
    has no encryption key. Otherwise the contradiction scan runs first, followed
    by relation-label canonicalization.
    """
    lookup = await db.execute(select(User).where(User.id == user_id))
    account = lookup.scalar_one_or_none()

    key = account.encryption_key if account is not None else None
    if not key:
        logger.warning("memory_maintenance: audit_memory no encryption_key user=%s", user_id)
        return

    cipher = Fernet(key.encode())
    await _scan_associative_contradictions(db, user_id, cipher)
    await _canonicalize_relation_labels(db, user_id)
|
|
|
|
|
|
async def _scan_associative_contradictions(
    db: AsyncSession,
    user_id: str,
    fernet: Fernet,
) -> None:
    """Decrypt associative facts, ask the LLM to flag contradictions, delete flagged rows.

    Steps:
      1. Load the ``_AUDIT_MAX_FACTS`` most recently updated associative rows of the
         user and decrypt them; rows that cannot be decrypted are skipped (counted
         and logged).
      2. Send the facts, each tagged with its row id, to the ``memory-auditor`` LLM.
      3. Parse the reply as a JSON array of ``{"delete": <id>, "reason": ...}``
         objects. A reply wrapped in a markdown code fence is unwrapped first.
      4. Delete each flagged row, but only if its id was part of the list sent to
         the LLM and the row still belongs to the user.

    Returns silently when fewer than two facts are available. LLM or parse errors
    are logged and end the scan without deleting anything. A failed commit is
    logged and rolled back.
    """
    result = await db.execute(
        select(MemoryAssociative)
        .where(MemoryAssociative.user_id == user_id)
        .order_by(MemoryAssociative.updated_at.desc())
        .limit(_AUDIT_MAX_FACTS)
    )
    rows = result.scalars().all()
    if len(rows) < 2:
        return

    id_to_text: dict[str, str] = {}
    undecryptable = 0
    for row in rows:
        try:
            id_to_text[row.id] = fernet.decrypt(row.content_encrypted.encode()).decode()
        except Exception:
            # Best effort: an unreadable fact is left out of the audit, but counted.
            undecryptable += 1

    if undecryptable:
        logger.warning(
            "memory_maintenance: _scan_associative_contradictions user=%s skipped %d undecryptable fact(s)",
            user_id,
            undecryptable,
        )

    if len(id_to_text) < 2:
        return

    numbered = "\n".join(
        f"{i + 1}. [{rid}] {text}" for i, (rid, text) in enumerate(id_to_text.items())
    )

    template, prompt_obj = get_prompt_or_fallback(
        "memory_audit_contradictions", _AUDIT_CONTRADICTIONS_FALLBACK
    )
    system_text = compile_prompt(template, prompt_obj, facts=numbered)

    from app.core.llm import get_agent_llm, model_for_agent  # noqa: PLC0415
    from langchain_core.messages import HumanMessage, SystemMessage  # noqa: PLC0415

    llm = get_agent_llm("memory-auditor", temperature=0)
    lf = get_langfuse()
    messages = [
        SystemMessage(content=system_text),
        HumanMessage(content="Audit facts for contradictions."),
    ]
    try:
        if lf:
            with lf.start_as_current_observation(
                as_type="generation",
                name="memory-audit-contradictions",
                model=model_for_agent("memory-auditor"),
                prompt=prompt_obj,
                input=messages,
            ) as gen:
                response = await llm.ainvoke(messages)
                gen.update(output=response.content, usage=extract_usage(response))
        else:
            response = await llm.ainvoke(messages)

        text = response.content if hasattr(response, "content") else str(response)
        payload = str(text).strip()
        if payload.startswith("```"):
            # Models sometimes wrap JSON in a ```json ... ``` fence despite the
            # instructions: drop the opening fence line and the closing fence.
            _, _, payload = payload.partition("\n")
            payload = payload.rsplit("```", 1)[0].strip()
        deletions = json.loads(payload)
        if not isinstance(deletions, list):
            return
    except Exception as exc:
        logger.warning(
            "memory_maintenance: _scan_associative_contradictions LLM/parse failed user=%s: %s",
            user_id,
            exc,
        )
        return

    deleted = 0
    handled: set = set()  # ids already deleted in this run (the LLM may repeat an id)
    for item in deletions:
        if not isinstance(item, dict):
            continue
        rid = item.get("delete")
        # Only hashable scalar ids are accepted; anything else (list, dict, ...) would
        # make the membership test below raise and abort the whole audit.
        if not isinstance(rid, (str, int)) or not rid:
            continue
        # Never act on an id that was not part of the list shown to the LLM.
        if rid not in id_to_text or rid in handled:
            continue
        handled.add(rid)

        # Re-select with the user filter so a row removed meanwhile is simply skipped.
        lookup = await db.execute(
            select(MemoryAssociative).where(
                MemoryAssociative.id == rid,
                MemoryAssociative.user_id == user_id,
            )
        )
        target = lookup.scalar_one_or_none()
        if target:
            await db.delete(target)
            deleted += 1
            logger.info(
                "memory_maintenance: audit deleted contradiction id=%s user=%s reason=%s",
                rid,
                user_id,
                item.get("reason", ""),
            )

    if deleted:
        try:
            await db.commit()
        except Exception as exc:
            logger.warning(
                "memory_maintenance: audit contradiction commit failed user=%s: %s", user_id, exc
            )
            await db.rollback()
            deleted = 0  # nothing was persisted, so do not report deletions below

    logger.info(
        "memory_maintenance: _scan_associative_contradictions user=%s deleted=%d", user_id, deleted
    )
|
|
|
|
|
|
async def _canonicalize_relation_labels(db: AsyncSession, user_id: str) -> None:
    """Group near-duplicate entity labels in memory_relations and unify to canonical form.

    Steps:
      1. Collect the distinct, non-empty string subject/object labels of the user's
         relations; the alphabetically first ``_AUDIT_MAX_LABELS`` are sent to the
         ``memory-auditor`` LLM.
      2. Parse the reply as a JSON array of
         ``{"canonical": <label>, "variants": [<label>, ...]}`` objects. A reply
         wrapped in a markdown code fence is unwrapped first. Groups whose
         ``canonical`` is not a non-empty string or whose ``variants`` is not a
         list are ignored.
      3. Rewrite every subject/object label that matches a variant to its canonical
         label and commit.

    Returns silently when there are no relations, fewer than two labels, or nothing
    to remap. LLM or parse errors are logged and end the pass without changes. A
    failed commit is logged and rolled back.

    NOTE(review): if memory_relations has a uniqueness constraint over the labels,
    merging two labels could make the commit fail and discard the whole pass —
    confirm against the schema.
    """
    result = await db.execute(
        select(MemoryRelation).where(MemoryRelation.user_id == user_id)
    )
    rows = result.scalars().all()
    if not rows:
        return

    # Non-string (e.g. NULL) labels are left out: sorted() cannot order None against str.
    all_labels: set[str] = {
        label
        for row in rows
        for label in (row.subject_label, row.object_label)
        if isinstance(label, str) and label
    }

    labels_list = sorted(all_labels)[:_AUDIT_MAX_LABELS]
    if len(labels_list) < 2:
        return

    labels_block = "\n".join(f"- {lbl}" for lbl in labels_list)
    template, prompt_obj = get_prompt_or_fallback(
        "memory_audit_canonicalize", _AUDIT_CANONICALIZE_FALLBACK
    )
    system_text = compile_prompt(template, prompt_obj, labels=labels_block)

    from app.core.llm import get_agent_llm, model_for_agent  # noqa: PLC0415
    from langchain_core.messages import HumanMessage, SystemMessage  # noqa: PLC0415

    llm = get_agent_llm("memory-auditor", temperature=0)
    lf = get_langfuse()
    messages = [
        SystemMessage(content=system_text),
        HumanMessage(content="Canonicalize entity labels."),
    ]
    try:
        if lf:
            with lf.start_as_current_observation(
                as_type="generation",
                name="memory-audit-canonicalize",
                model=model_for_agent("memory-auditor"),
                prompt=prompt_obj,
                input=messages,
            ) as gen:
                response = await llm.ainvoke(messages)
                gen.update(output=response.content, usage=extract_usage(response))
        else:
            response = await llm.ainvoke(messages)

        text = response.content if hasattr(response, "content") else str(response)
        payload = str(text).strip()
        if payload.startswith("```"):
            # Models sometimes wrap JSON in a ```json ... ``` fence despite the
            # instructions: drop the opening fence line and the closing fence.
            _, _, payload = payload.partition("\n")
            payload = payload.rsplit("```", 1)[0].strip()
        groups = json.loads(payload)
        if not isinstance(groups, list):
            return
    except Exception as exc:
        logger.warning(
            "memory_maintenance: _canonicalize_relation_labels LLM/parse failed user=%s: %s",
            user_id,
            exc,
        )
        return

    # Build variant → canonical map
    remap: dict[str, str] = {}
    for group in groups:
        if not isinstance(group, dict):
            continue
        canonical = group.get("canonical")
        variants = group.get("variants")
        # The canonical value is written into rows, so it must be a real label.
        if not isinstance(canonical, str) or not canonical:
            continue
        # A bare string here would be iterated character by character and could
        # remap unrelated single-character labels.
        if not isinstance(variants, list):
            continue
        for v in variants:
            if isinstance(v, str) and v != canonical:
                remap[v] = canonical

    if not remap:
        return

    updated = 0
    for row in rows:
        changed = False
        if row.subject_label in remap:
            row.subject_label = remap[row.subject_label]
            changed = True
        if row.object_label in remap:
            row.object_label = remap[row.object_label]
            changed = True
        if changed:
            updated += 1

    if updated:
        try:
            await db.commit()
            logger.info(
                "memory_maintenance: _canonicalize_relation_labels user=%s updated=%d",
                user_id,
                updated,
            )
        except Exception as exc:
            logger.warning(
                "memory_maintenance: canonicalize commit failed user=%s: %s", user_id, exc
            )
            await db.rollback()
|