Phase 7: audit memory
This commit is contained in:
@@ -105,6 +105,7 @@ _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
||||
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
|
||||
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ All are safe to call manually or from tests; they never raise.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
@@ -19,7 +20,8 @@ from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import MemoryEpisodic, MemoryProactive, MemoryRelation, User
|
||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
|
||||
from app.models import MemoryAssociative, MemoryEpisodic, MemoryProactive, MemoryRelation, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,6 +39,10 @@ _PROACTIVE_PRUNE_THRESHOLD = 0.2
|
||||
_MIN_EPISODES_FOR_MINING = 3
|
||||
_MINING_LOOKBACK_DAYS = 30
|
||||
|
||||
# Audit: caps to control token cost
|
||||
_AUDIT_MAX_FACTS = 50
|
||||
_AUDIT_MAX_LABELS = 100
|
||||
|
||||
|
||||
async def decay_relations(db: AsyncSession, user_id: str) -> None:
|
||||
"""Apply confidence decay to all relation rows for a user.
|
||||
@@ -311,3 +317,265 @@ async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fern
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, exc)
|
||||
await db.rollback()
|
||||
|
||||
|
||||
# ── Phase 7: weekly memory audit ──────────────────────────────────────────────
|
||||
|
||||
_AUDIT_CONTRADICTIONS_FALLBACK = (
|
||||
"You are auditing a personal AI assistant's memory bank. "
|
||||
"Each fact has an ID in brackets. "
|
||||
"Find pairs that directly contradict each other "
|
||||
"(e.g. 'prefers morning meetings' vs 'never schedules before noon'). "
|
||||
"For each contradiction, pick the ID to DELETE (the older or less specific one). "
|
||||
'Return ONLY a valid JSON array, no markdown fences: '
|
||||
'[{{"delete": "<id>", "reason": "<one line>"}}]. '
|
||||
"If no contradictions, return [].\n\n"
|
||||
"Facts:\n{facts}"
|
||||
)
|
||||
|
||||
_AUDIT_CANONICALIZE_FALLBACK = (
|
||||
"You are auditing entity labels in a personal AI assistant's relational memory. "
|
||||
"These are names of people, companies, projects, or topics. "
|
||||
"Group labels that clearly refer to the same real-world entity "
|
||||
"(e.g. 'giulia', 'Giulia', 'Giulia R.' → canonical 'Giulia'). "
|
||||
"Return ONLY a valid JSON array, no markdown fences: "
|
||||
'[{{"canonical": "<best label>", "variants": ["<v1>", "<v2>"]}}]. '
|
||||
"Only include groups with at least one variant. Singletons: omit.\n\n"
|
||||
"Labels:\n{labels}"
|
||||
)
|
||||
|
||||
|
||||
async def audit_memory(db: AsyncSession, user_id: str) -> None:
|
||||
"""Weekly audit: contradiction scan on associative facts + label canonicalization on relations.
|
||||
|
||||
Steps:
|
||||
1. Decrypt up to _AUDIT_MAX_FACTS associative rows; send list to memory-auditor LLM.
|
||||
2. LLM flags rows to delete (direct contradictions); hard-delete them.
|
||||
3. Collect unique subject/object labels from memory_relations; ask LLM to group duplicates.
|
||||
4. Rewrite variant labels to their canonical form in-place.
|
||||
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
await _audit_memory_inner(db, user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: audit_memory failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _audit_memory_inner(db: AsyncSession, user_id: str) -> None:
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not user.encryption_key:
|
||||
logger.warning("memory_maintenance: audit_memory no encryption_key user=%s", user_id)
|
||||
return
|
||||
|
||||
fernet = Fernet(user.encryption_key.encode())
|
||||
await _scan_associative_contradictions(db, user_id, fernet)
|
||||
await _canonicalize_relation_labels(db, user_id)
|
||||
|
||||
|
||||
async def _scan_associative_contradictions(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
fernet: Fernet,
|
||||
) -> None:
|
||||
"""Decrypt associative facts, ask LLM to flag contradictions, delete superseded rows."""
|
||||
result = await db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(MemoryAssociative.user_id == user_id)
|
||||
.order_by(MemoryAssociative.updated_at.desc())
|
||||
.limit(_AUDIT_MAX_FACTS)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
if len(rows) < 2:
|
||||
return
|
||||
|
||||
id_to_text: dict[str, str] = {}
|
||||
for row in rows:
|
||||
try:
|
||||
plaintext = fernet.decrypt(row.content_encrypted.encode()).decode()
|
||||
id_to_text[row.id] = plaintext
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if len(id_to_text) < 2:
|
||||
return
|
||||
|
||||
id_list = list(id_to_text.keys())
|
||||
numbered = "\n".join(
|
||||
f"{i + 1}. [{rid}] {id_to_text[rid]}" for i, rid in enumerate(id_list)
|
||||
)
|
||||
|
||||
template, prompt_obj = get_prompt_or_fallback(
|
||||
"memory_audit_contradictions", _AUDIT_CONTRADICTIONS_FALLBACK
|
||||
)
|
||||
system_text = compile_prompt(template, prompt_obj, facts=numbered)
|
||||
|
||||
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
|
||||
llm = get_agent_llm("memory-auditor", temperature=0)
|
||||
lf = get_langfuse()
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Audit facts for contradictions."),
|
||||
]
|
||||
try:
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-audit-contradictions",
|
||||
model=model_for_agent("memory-auditor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
deletions = json.loads(text.strip())
|
||||
if not isinstance(deletions, list):
|
||||
return
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: _scan_associative_contradictions LLM/parse failed user=%s: %s",
|
||||
user_id, exc,
|
||||
)
|
||||
return
|
||||
|
||||
deleted = 0
|
||||
for item in deletions:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
rid = item.get("delete")
|
||||
if not rid or rid not in id_to_text:
|
||||
continue
|
||||
result2 = await db.execute(
|
||||
select(MemoryAssociative).where(
|
||||
MemoryAssociative.id == rid,
|
||||
MemoryAssociative.user_id == user_id,
|
||||
)
|
||||
)
|
||||
target = result2.scalar_one_or_none()
|
||||
if target:
|
||||
await db.delete(target)
|
||||
deleted += 1
|
||||
logger.info(
|
||||
"memory_maintenance: audit deleted contradiction id=%s user=%s reason=%s",
|
||||
rid, user_id, item.get("reason", ""),
|
||||
)
|
||||
|
||||
if deleted:
|
||||
try:
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: audit contradiction commit failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await db.rollback()
|
||||
|
||||
logger.info(
|
||||
"memory_maintenance: _scan_associative_contradictions user=%s deleted=%d", user_id, deleted
|
||||
)
|
||||
|
||||
|
||||
async def _canonicalize_relation_labels(db: AsyncSession, user_id: str) -> None:
|
||||
"""Group near-duplicate entity labels in memory_relations and unify to canonical form."""
|
||||
result = await db.execute(
|
||||
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
if not rows:
|
||||
return
|
||||
|
||||
all_labels: set[str] = set()
|
||||
for row in rows:
|
||||
all_labels.add(row.subject_label)
|
||||
all_labels.add(row.object_label)
|
||||
|
||||
labels_list = sorted(all_labels)[:_AUDIT_MAX_LABELS]
|
||||
if len(labels_list) < 2:
|
||||
return
|
||||
|
||||
labels_block = "\n".join(f"- {lbl}" for lbl in labels_list)
|
||||
template, prompt_obj = get_prompt_or_fallback(
|
||||
"memory_audit_canonicalize", _AUDIT_CANONICALIZE_FALLBACK
|
||||
)
|
||||
system_text = compile_prompt(template, prompt_obj, labels=labels_block)
|
||||
|
||||
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
|
||||
llm = get_agent_llm("memory-auditor", temperature=0)
|
||||
lf = get_langfuse()
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Canonicalize entity labels."),
|
||||
]
|
||||
try:
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-audit-canonicalize",
|
||||
model=model_for_agent("memory-auditor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
groups = json.loads(text.strip())
|
||||
if not isinstance(groups, list):
|
||||
return
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: _canonicalize_relation_labels LLM/parse failed user=%s: %s",
|
||||
user_id, exc,
|
||||
)
|
||||
return
|
||||
|
||||
# Build variant → canonical map
|
||||
remap: dict[str, str] = {}
|
||||
for group in groups:
|
||||
if not isinstance(group, dict):
|
||||
continue
|
||||
canonical = group.get("canonical", "")
|
||||
variants = group.get("variants") or []
|
||||
if not canonical:
|
||||
continue
|
||||
for v in variants:
|
||||
if isinstance(v, str) and v != canonical:
|
||||
remap[v] = canonical
|
||||
|
||||
if not remap:
|
||||
return
|
||||
|
||||
updated = 0
|
||||
for row in rows:
|
||||
changed = False
|
||||
if row.subject_label in remap:
|
||||
row.subject_label = remap[row.subject_label]
|
||||
changed = True
|
||||
if row.object_label in remap:
|
||||
row.object_label = remap[row.object_label]
|
||||
changed = True
|
||||
if changed:
|
||||
updated += 1
|
||||
|
||||
if updated:
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"memory_maintenance: _canonicalize_relation_labels user=%s updated=%d",
|
||||
user_id, updated,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: canonicalize commit failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await db.rollback()
|
||||
|
||||
Reference in New Issue
Block a user