Phase 7: audit memory

This commit is contained in:
Roberto Musso
2026-04-17 22:43:55 +02:00
parent ca8721e1ac
commit 0b5ef48463
6 changed files with 708 additions and 1 deletions

View File

@@ -61,6 +61,10 @@ LLM_MODEL_MEMORY_EXTRACTOR=
# Defaults to gpt-4o-mini when empty.
LLM_MODEL_MEMORY_MINER=
# Memory-auditor — weekly contradiction scan + relation label canonicalization (Phase 7).
# Defaults to LLM_MODEL when empty (a reasoning-capable model is recommended).
LLM_MODEL_MEMORY_AUDITOR=
# Scheduler — set to false to disable memory cron jobs (automatically false in tests).
SCHEDULER_ENABLED=true

View File

@@ -29,6 +29,7 @@ class Settings(BaseSettings):
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
LLM_MODEL_MEMORY_AUDITOR: str = "" # memory-auditor (Phase 7 weekly audit)
# GitHub Copilot OAuth token storage directory.
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).

View File

@@ -105,6 +105,7 @@ _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
}

View File

@@ -11,6 +11,7 @@ All are safe to call manually or from tests; they never raise.
from __future__ import annotations
import json
import logging
import uuid
from datetime import datetime, timedelta, timezone
@@ -19,7 +20,8 @@ from cryptography.fernet import Fernet
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import MemoryEpisodic, MemoryProactive, MemoryRelation, User
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
from app.models import MemoryAssociative, MemoryEpisodic, MemoryProactive, MemoryRelation, User
logger = logging.getLogger(__name__)
@@ -37,6 +39,10 @@ _PROACTIVE_PRUNE_THRESHOLD = 0.2
_MIN_EPISODES_FOR_MINING = 3
_MINING_LOOKBACK_DAYS = 30
# Audit: caps to control token cost
_AUDIT_MAX_FACTS = 50
_AUDIT_MAX_LABELS = 100
async def decay_relations(db: AsyncSession, user_id: str) -> None:
"""Apply confidence decay to all relation rows for a user.
@@ -311,3 +317,265 @@ async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fern
except Exception as exc:
logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, exc)
await db.rollback()
# ── Phase 7: weekly memory audit ──────────────────────────────────────────────
_AUDIT_CONTRADICTIONS_FALLBACK = (
"You are auditing a personal AI assistant's memory bank. "
"Each fact has an ID in brackets. "
"Find pairs that directly contradict each other "
"(e.g. 'prefers morning meetings' vs 'never schedules before noon'). "
"For each contradiction, pick the ID to DELETE (the older or less specific one). "
'Return ONLY a valid JSON array, no markdown fences: '
'[{{"delete": "<id>", "reason": "<one line>"}}]. '
"If no contradictions, return [].\n\n"
"Facts:\n{facts}"
)
_AUDIT_CANONICALIZE_FALLBACK = (
"You are auditing entity labels in a personal AI assistant's relational memory. "
"These are names of people, companies, projects, or topics. "
"Group labels that clearly refer to the same real-world entity "
"(e.g. 'giulia', 'Giulia', 'Giulia R.' → canonical 'Giulia'). "
"Return ONLY a valid JSON array, no markdown fences: "
'[{{"canonical": "<best label>", "variants": ["<v1>", "<v2>"]}}]. '
"Only include groups with at least one variant. Singletons: omit.\n\n"
"Labels:\n{labels}"
)
async def audit_memory(db: AsyncSession, user_id: str) -> None:
"""Weekly audit: contradiction scan on associative facts + label canonicalization on relations.
Steps:
1. Decrypt up to _AUDIT_MAX_FACTS associative rows; send list to memory-auditor LLM.
2. LLM flags rows to delete (direct contradictions); hard-delete them.
3. Collect unique subject/object labels from memory_relations; ask LLM to group duplicates.
4. Rewrite variant labels to their canonical form in-place.
Never raises — wraps in try/except.
"""
try:
await _audit_memory_inner(db, user_id)
except Exception as exc:
logger.warning("memory_maintenance: audit_memory failed user=%s: %s", user_id, exc)
async def _audit_memory_inner(db: AsyncSession, user_id: str) -> None:
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None or not user.encryption_key:
logger.warning("memory_maintenance: audit_memory no encryption_key user=%s", user_id)
return
fernet = Fernet(user.encryption_key.encode())
await _scan_associative_contradictions(db, user_id, fernet)
await _canonicalize_relation_labels(db, user_id)
async def _scan_associative_contradictions(
db: AsyncSession,
user_id: str,
fernet: Fernet,
) -> None:
"""Decrypt associative facts, ask LLM to flag contradictions, delete superseded rows."""
result = await db.execute(
select(MemoryAssociative)
.where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc())
.limit(_AUDIT_MAX_FACTS)
)
rows = result.scalars().all()
if len(rows) < 2:
return
id_to_text: dict[str, str] = {}
for row in rows:
try:
plaintext = fernet.decrypt(row.content_encrypted.encode()).decode()
id_to_text[row.id] = plaintext
except Exception:
pass
if len(id_to_text) < 2:
return
id_list = list(id_to_text.keys())
numbered = "\n".join(
f"{i + 1}. [{rid}] {id_to_text[rid]}" for i, rid in enumerate(id_list)
)
template, prompt_obj = get_prompt_or_fallback(
"memory_audit_contradictions", _AUDIT_CONTRADICTIONS_FALLBACK
)
system_text = compile_prompt(template, prompt_obj, facts=numbered)
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
llm = get_agent_llm("memory-auditor", temperature=0)
lf = get_langfuse()
messages = [
SystemMessage(content=system_text),
HumanMessage(content="Audit facts for contradictions."),
]
try:
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="memory-audit-contradictions",
model=model_for_agent("memory-auditor"),
prompt=prompt_obj,
input=messages,
) as gen:
response = await llm.ainvoke(messages)
gen.update(output=response.content, usage=extract_usage(response))
else:
response = await llm.ainvoke(messages)
text = response.content if hasattr(response, "content") else str(response)
deletions = json.loads(text.strip())
if not isinstance(deletions, list):
return
except Exception as exc:
logger.warning(
"memory_maintenance: _scan_associative_contradictions LLM/parse failed user=%s: %s",
user_id, exc,
)
return
deleted = 0
for item in deletions:
if not isinstance(item, dict):
continue
rid = item.get("delete")
if not rid or rid not in id_to_text:
continue
result2 = await db.execute(
select(MemoryAssociative).where(
MemoryAssociative.id == rid,
MemoryAssociative.user_id == user_id,
)
)
target = result2.scalar_one_or_none()
if target:
await db.delete(target)
deleted += 1
logger.info(
"memory_maintenance: audit deleted contradiction id=%s user=%s reason=%s",
rid, user_id, item.get("reason", ""),
)
if deleted:
try:
await db.commit()
except Exception as exc:
logger.warning(
"memory_maintenance: audit contradiction commit failed user=%s: %s", user_id, exc
)
await db.rollback()
logger.info(
"memory_maintenance: _scan_associative_contradictions user=%s deleted=%d", user_id, deleted
)
async def _canonicalize_relation_labels(db: AsyncSession, user_id: str) -> None:
"""Group near-duplicate entity labels in memory_relations and unify to canonical form."""
result = await db.execute(
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
)
rows = result.scalars().all()
if not rows:
return
all_labels: set[str] = set()
for row in rows:
all_labels.add(row.subject_label)
all_labels.add(row.object_label)
labels_list = sorted(all_labels)[:_AUDIT_MAX_LABELS]
if len(labels_list) < 2:
return
labels_block = "\n".join(f"- {lbl}" for lbl in labels_list)
template, prompt_obj = get_prompt_or_fallback(
"memory_audit_canonicalize", _AUDIT_CANONICALIZE_FALLBACK
)
system_text = compile_prompt(template, prompt_obj, labels=labels_block)
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
llm = get_agent_llm("memory-auditor", temperature=0)
lf = get_langfuse()
messages = [
SystemMessage(content=system_text),
HumanMessage(content="Canonicalize entity labels."),
]
try:
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="memory-audit-canonicalize",
model=model_for_agent("memory-auditor"),
prompt=prompt_obj,
input=messages,
) as gen:
response = await llm.ainvoke(messages)
gen.update(output=response.content, usage=extract_usage(response))
else:
response = await llm.ainvoke(messages)
text = response.content if hasattr(response, "content") else str(response)
groups = json.loads(text.strip())
if not isinstance(groups, list):
return
except Exception as exc:
logger.warning(
"memory_maintenance: _canonicalize_relation_labels LLM/parse failed user=%s: %s",
user_id, exc,
)
return
# Build variant → canonical map
remap: dict[str, str] = {}
for group in groups:
if not isinstance(group, dict):
continue
canonical = group.get("canonical", "")
variants = group.get("variants") or []
if not canonical:
continue
for v in variants:
if isinstance(v, str) and v != canonical:
remap[v] = canonical
if not remap:
return
updated = 0
for row in rows:
changed = False
if row.subject_label in remap:
row.subject_label = remap[row.subject_label]
changed = True
if row.object_label in remap:
row.object_label = remap[row.object_label]
changed = True
if changed:
updated += 1
if updated:
try:
await db.commit()
logger.info(
"memory_maintenance: _canonicalize_relation_labels user=%s updated=%d",
user_id, updated,
)
except Exception as exc:
logger.warning(
"memory_maintenance: canonicalize commit failed user=%s: %s", user_id, exc
)
await db.rollback()

View File

@@ -16,6 +16,33 @@ from app.api.middleware.sanitizer import SanitizerMiddleware
from app.config.settings import settings
async def _memory_audit_cron_tick() -> None:
"""Weekly cron: contradiction scan + label canonicalization for all users (Phase 7)."""
import logging # noqa: PLC0415
_log = logging.getLogger(__name__)
_log.info("memory audit cron tick: starting")
try:
from app.db import async_session # noqa: PLC0415
from app.core.memory_maintenance import audit_memory # noqa: PLC0415
from app.models import User # noqa: PLC0415
from sqlalchemy import select # noqa: PLC0415
async with async_session() as db:
result = await db.execute(select(User.id))
user_ids: list[str] = list(result.scalars().all())
for uid in user_ids:
try:
async with async_session() as db:
await audit_memory(db, uid)
except Exception as exc:
_log.warning("memory audit cron tick: audit_memory failed user=%s: %s", uid, exc)
_log.info("memory audit cron tick: done users=%d", len(user_ids))
except Exception as exc:
_log.warning("memory audit cron tick: failed: %s", exc)
async def _memory_cron_tick() -> None:
"""Hourly cron: drain Free-tier extraction queue + mine proactive patterns for Power+ users."""
import logging # noqa: PLC0415
@@ -61,6 +88,7 @@ async def lifespan(app: FastAPI):
scheduler = AsyncIOScheduler()
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
scheduler.start()
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")

405
tests/test_memory_audit.py Normal file
View File

@@ -0,0 +1,405 @@
"""Tests for Phase 7 — weekly audit_memory job.
Coverage:
1. audit_memory never raises even if inner work fails.
2. _scan_associative_contradictions skips when < 2 decryptable facts.
3. _scan_associative_contradictions calls LLM and deletes flagged rows.
4. _scan_associative_contradictions is a no-op when LLM fails.
5. _scan_associative_contradictions is a no-op when LLM returns non-list.
6. _canonicalize_relation_labels skips when no relation rows.
7. _canonicalize_relation_labels rewrites variant labels to canonical form.
8. _canonicalize_relation_labels is a no-op when LLM fails.
9. _canonicalize_relation_labels is a no-op when remap is empty.
10. Both helpers work correctly when Langfuse is unavailable (lf=None).
11. get_prompt_or_fallback called with correct Langfuse prompt names.
"""
from __future__ import annotations
import json
import uuid
from contextlib import contextmanager, ExitStack
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select
from app.core.memory_maintenance import (
_canonicalize_relation_labels,
_scan_associative_contradictions,
audit_memory,
)
from app.db import get_session
from app.main import app
from app.models import MemoryAssociative, MemoryRelation, User
from tests.conftest import TEST_USER_IDS
PRO_USER_ID = TEST_USER_IDS["pro"]
_FERNET_KEY = Fernet.generate_key().decode()
_FERNET = Fernet(_FERNET_KEY.encode())
# ── DB override ───────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _override_db(db_session):
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
# ── Helpers ───────────────────────────────────────────────────────────────────
@pytest_asyncio.fixture
async def pro_user(db_session):
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
user = result.scalar_one()
user.encryption_key = _FERNET_KEY
await db_session.commit()
return user
def _enc(text: str) -> str:
return _FERNET.encrypt(text.encode()).decode()
def _assoc_row(user_id: str, text: str) -> MemoryAssociative:
return MemoryAssociative(
id=str(uuid.uuid4()),
user_id=user_id,
content_encrypted=_enc(text),
updated_at=datetime.now(timezone.utc),
)
def _relation_row(user_id: str, subject: str, predicate: str, obj: str) -> MemoryRelation:
return MemoryRelation(
id=str(uuid.uuid4()),
user_id=user_id,
subject_label=subject,
subject_type="person",
predicate=predicate,
object_label=obj,
object_type="company",
confidence=0.8,
)
def _llm_response(content: str) -> MagicMock:
msg = MagicMock()
msg.content = content
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
return msg
def _mock_llm(content: str) -> MagicMock:
llm = MagicMock()
llm.ainvoke = AsyncMock(return_value=_llm_response(content))
return llm
@contextmanager
def _patch_audit(llm_mock, lf=None, prompt_text: str = "fallback {facts}"):
"""Context manager that patches all external deps for audit helpers."""
with ExitStack() as stack:
stack.enter_context(
patch("app.core.llm.get_agent_llm", return_value=llm_mock)
)
stack.enter_context(
patch("app.core.llm.model_for_agent", return_value="memory-auditor")
)
stack.enter_context(
patch("app.core.memory_maintenance.get_langfuse", return_value=lf)
)
stack.enter_context(
patch(
"app.core.memory_maintenance.get_prompt_or_fallback",
return_value=(prompt_text, None),
)
)
stack.enter_context(
patch(
"app.core.memory_maintenance.compile_prompt",
side_effect=lambda tmpl, obj, **kw: tmpl.format(**kw) if "{" in tmpl else tmpl,
)
)
yield
# ── Test 1: audit_memory never raises ────────────────────────────────────────
@pytest.mark.asyncio
async def test_audit_memory_never_raises_on_missing_user(db_session):
"""audit_memory with a non-existent user_id must not raise."""
await audit_memory(db_session, str(uuid.uuid4()))
@pytest.mark.asyncio
async def test_audit_memory_never_raises_on_llm_failure(db_session, pro_user):
"""audit_memory must swallow inner exceptions."""
llm = MagicMock()
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
with (
patch("app.core.llm.get_agent_llm", return_value=llm),
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
patch(
"app.core.memory_maintenance.get_prompt_or_fallback",
return_value=("p {facts}", None),
),
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
):
await audit_memory(db_session, PRO_USER_ID)
# ── Test 2: _scan skips when < 2 facts ───────────────────────────────────────
@pytest.mark.asyncio
async def test_scan_contradictions_skips_with_one_fact(db_session, pro_user):
row = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
db_session.add(row)
await db_session.commit()
llm = MagicMock()
llm.ainvoke = AsyncMock(return_value=_llm_response("[]"))
with _patch_audit(llm):
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
llm.ainvoke.assert_not_called()
# ── Test 3: _scan deletes flagged contradiction ───────────────────────────────
@pytest.mark.asyncio
async def test_scan_contradictions_deletes_flagged_row(db_session, pro_user):
keep = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
drop = _assoc_row(PRO_USER_ID, "Never schedules before noon")
db_session.add(keep)
db_session.add(drop)
await db_session.commit()
deletion_payload = json.dumps([{"delete": drop.id, "reason": "contradicts morning pref"}])
llm = _mock_llm(deletion_payload)
with _patch_audit(llm, prompt_text="p {facts}"):
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
result = await db_session.execute(
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
)
remaining = result.scalars().all()
remaining_ids = {r.id for r in remaining}
assert keep.id in remaining_ids
assert drop.id not in remaining_ids
# ── Test 4: _scan is no-op on LLM failure ────────────────────────────────────
@pytest.mark.asyncio
async def test_scan_contradictions_noop_on_llm_failure(db_session, pro_user):
for text in ("Fact A", "Fact B"):
db_session.add(_assoc_row(PRO_USER_ID, text))
await db_session.commit()
llm = MagicMock()
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
with _patch_audit(llm, prompt_text="p {facts}"):
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
result = await db_session.execute(
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
)
assert len(result.scalars().all()) == 2
# ── Test 5: _scan is no-op when LLM returns non-list ─────────────────────────
@pytest.mark.asyncio
async def test_scan_contradictions_noop_on_non_list_response(db_session, pro_user):
for text in ("Fact A", "Fact B"):
db_session.add(_assoc_row(PRO_USER_ID, text))
await db_session.commit()
llm = _mock_llm('"unexpected string"')
with _patch_audit(llm, prompt_text="p {facts}"):
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
result = await db_session.execute(
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
)
assert len(result.scalars().all()) == 2
# ── Test 6: _canonicalize skips when no relations ────────────────────────────
@pytest.mark.asyncio
async def test_canonicalize_skips_when_no_relations(db_session, pro_user):
llm = MagicMock()
llm.ainvoke = AsyncMock(return_value=_llm_response("[]"))
with _patch_audit(llm, prompt_text="p {labels}"):
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
llm.ainvoke.assert_not_called()
# ── Test 7: _canonicalize rewrites variant labels ────────────────────────────
@pytest.mark.asyncio
async def test_canonicalize_rewrites_variant_labels(db_session, pro_user):
row_a = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
row_b = _relation_row(PRO_USER_ID, "Giulia R.", "reports_to", "Marco")
row_c = _relation_row(PRO_USER_ID, "Marco", "manages", "Giulia")
db_session.add(row_a)
db_session.add(row_b)
db_session.add(row_c)
await db_session.commit()
groups = json.dumps([
{"canonical": "Giulia", "variants": ["giulia", "Giulia R."]}
])
llm = _mock_llm(groups)
with _patch_audit(llm, prompt_text="p {labels}"):
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
await db_session.refresh(row_a)
await db_session.refresh(row_b)
await db_session.refresh(row_c)
assert row_a.subject_label == "Giulia"
assert row_b.subject_label == "Giulia"
assert row_c.object_label == "Giulia"
assert row_c.subject_label == "Marco"
# ── Test 8: _canonicalize is no-op on LLM failure ────────────────────────────
@pytest.mark.asyncio
async def test_canonicalize_noop_on_llm_failure(db_session, pro_user):
row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
db_session.add(row)
await db_session.commit()
llm = MagicMock()
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
with _patch_audit(llm, prompt_text="p {labels}"):
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
await db_session.refresh(row)
assert row.subject_label == "giulia"
# ── Test 9: _canonicalize is no-op when remap is empty ───────────────────────
@pytest.mark.asyncio
async def test_canonicalize_noop_when_remap_empty(db_session, pro_user):
row = _relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme")
db_session.add(row)
await db_session.commit()
llm = _mock_llm("[]")
with _patch_audit(llm, prompt_text="p {labels}"):
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
await db_session.refresh(row)
assert row.subject_label == "Giulia"
# ── Test 10: both helpers work without Langfuse ───────────────────────────────
@pytest.mark.asyncio
async def test_scan_works_without_langfuse(db_session, pro_user):
keep = _assoc_row(PRO_USER_ID, "Prefers dark mode")
drop = _assoc_row(PRO_USER_ID, "Prefers light mode")
db_session.add(keep)
db_session.add(drop)
await db_session.commit()
deletion_payload = json.dumps([{"delete": drop.id, "reason": "contradicts dark mode"}])
llm = _mock_llm(deletion_payload)
with _patch_audit(llm, lf=None, prompt_text="p {facts}"):
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
result = await db_session.execute(
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
)
remaining_ids = {r.id for r in result.scalars().all()}
assert keep.id in remaining_ids
assert drop.id not in remaining_ids
@pytest.mark.asyncio
async def test_canonicalize_works_without_langfuse(db_session, pro_user):
row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
db_session.add(row)
db_session.add(_relation_row(PRO_USER_ID, "Marco", "manages", "Giulia"))
await db_session.commit()
groups = json.dumps([{"canonical": "Giulia", "variants": ["giulia"]}])
llm = _mock_llm(groups)
with _patch_audit(llm, lf=None, prompt_text="p {labels}"):
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
await db_session.refresh(row)
assert row.subject_label == "Giulia"
# ── Test 11: correct Langfuse prompt names used ───────────────────────────────
@pytest.mark.asyncio
async def test_scan_uses_correct_langfuse_prompt_name(db_session, pro_user):
for text in ("Fact A", "Fact B"):
db_session.add(_assoc_row(PRO_USER_ID, text))
await db_session.commit()
llm = _mock_llm("[]")
mock_get_prompt = MagicMock(return_value=("p {facts}", None))
with (
patch("app.core.llm.get_agent_llm", return_value=llm),
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
patch("app.core.memory_maintenance.get_prompt_or_fallback", mock_get_prompt),
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
):
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
mock_get_prompt.assert_called_once()
assert mock_get_prompt.call_args[0][0] == "memory_audit_contradictions"
@pytest.mark.asyncio
async def test_canonicalize_uses_correct_langfuse_prompt_name(db_session, pro_user):
db_session.add(_relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme"))
db_session.add(_relation_row(PRO_USER_ID, "Marco", "manages", "Acme"))
await db_session.commit()
llm = _mock_llm("[]")
mock_get_prompt = MagicMock(return_value=("p {labels}", None))
with (
patch("app.core.llm.get_agent_llm", return_value=llm),
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
patch("app.core.memory_maintenance.get_prompt_or_fallback", mock_get_prompt),
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
):
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
mock_get_prompt.assert_called_once()
assert mock_get_prompt.call_args[0][0] == "memory_audit_canonicalize"