From 0b5ef484630d5836e36520abb2b9518dbf59a195 Mon Sep 17 00:00:00 2001 From: Roberto Musso Date: Fri, 17 Apr 2026 22:43:55 +0200 Subject: [PATCH] Phase 7: audit memory --- .env.example | 4 + app/config/settings.py | 1 + app/core/llm.py | 1 + app/core/memory_maintenance.py | 270 +++++++++++++++++++++- app/main.py | 28 +++ tests/test_memory_audit.py | 405 +++++++++++++++++++++++++++++++++ 6 files changed, 708 insertions(+), 1 deletion(-) create mode 100644 tests/test_memory_audit.py diff --git a/.env.example b/.env.example index 3149a72..3c9e0f3 100644 --- a/.env.example +++ b/.env.example @@ -61,6 +61,10 @@ LLM_MODEL_MEMORY_EXTRACTOR= # Defaults to gpt-4o-mini when empty. LLM_MODEL_MEMORY_MINER= +# Memory-auditor — weekly contradiction scan + relation label canonicalization (Phase 7). +# Defaults to LLM_MODEL when empty (a reasoning-capable model is recommended). +LLM_MODEL_MEMORY_AUDITOR= + # Scheduler — set to false to disable memory cron jobs (automatically false in tests). SCHEDULER_ENABLED=true diff --git a/app/config/settings.py b/app/config/settings.py index ba684ca..ebba918 100644 --- a/app/config/settings.py +++ b/app/config/settings.py @@ -29,6 +29,7 @@ class Settings(BaseSettings): LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide) LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining) + LLM_MODEL_MEMORY_AUDITOR: str = "" # memory-auditor (Phase 7 weekly audit) # GitHub Copilot OAuth token storage directory. # Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot). 
diff --git a/app/core/llm.py b/app/core/llm.py index 7bd566b..5ccbf9a 100644 --- a/app/core/llm.py +++ b/app/core/llm.py @@ -105,6 +105,7 @@ _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = { "setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL, "memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini", "memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini", + "memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL, } diff --git a/app/core/memory_maintenance.py b/app/core/memory_maintenance.py index 9e1db7d..2269478 100644 --- a/app/core/memory_maintenance.py +++ b/app/core/memory_maintenance.py @@ -11,6 +11,7 @@ All are safe to call manually or from tests; they never raise. from __future__ import annotations +import json import logging import uuid from datetime import datetime, timedelta, timezone @@ -19,7 +20,8 @@ from cryptography.fernet import Fernet from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.models import MemoryEpisodic, MemoryProactive, MemoryRelation, User +from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback +from app.models import MemoryAssociative, MemoryEpisodic, MemoryProactive, MemoryRelation, User logger = logging.getLogger(__name__) @@ -37,6 +39,10 @@ _PROACTIVE_PRUNE_THRESHOLD = 0.2 _MIN_EPISODES_FOR_MINING = 3 _MINING_LOOKBACK_DAYS = 30 +# Audit: caps to control token cost +_AUDIT_MAX_FACTS = 50 +_AUDIT_MAX_LABELS = 100 + async def decay_relations(db: AsyncSession, user_id: str) -> None: """Apply confidence decay to all relation rows for a user. 
@@ -311,3 +317,331 @@ async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fern
     except Exception as exc:
         logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, exc)
         await db.rollback()
+
+
+# ── Phase 7: weekly memory audit ──────────────────────────────────────────────
+
+# NOTE(review): the doubled braces are meant as literal JSON braces; this assumes
+# compile_prompt applies str.format-style substitution to fallback templates.
+_AUDIT_CONTRADICTIONS_FALLBACK = (
+    "You are auditing a personal AI assistant's memory bank. "
+    "Each fact has an ID in brackets. "
+    "Find pairs that directly contradict each other "
+    "(e.g. 'prefers morning meetings' vs 'never schedules before noon'). "
+    "For each contradiction, pick the ID to DELETE (the older or less specific one). "
+    'Return ONLY a valid JSON array, no markdown fences: '
+    '[{{"delete": "<id>", "reason": "<short reason>"}}]. '
+    "If no contradictions, return [].\n\n"
+    "Facts:\n{facts}"
+)
+
+_AUDIT_CANONICALIZE_FALLBACK = (
+    "You are auditing entity labels in a personal AI assistant's relational memory. "
+    "These are names of people, companies, projects, or topics. "
+    "Group labels that clearly refer to the same real-world entity "
+    "(e.g. 'giulia', 'Giulia', 'Giulia R.' → canonical 'Giulia'). "
+    "Return ONLY a valid JSON array, no markdown fences: "
+    '[{{"canonical": "<label>", "variants": ["<variant>", "<variant>"]}}]. '
+    "Only include groups with at least one variant. Singletons: omit.\n\n"
+    "Labels:\n{labels}"
+)
+
+
+async def audit_memory(db: AsyncSession, user_id: str) -> None:
+    """Weekly audit: contradiction scan on associative facts + label canonicalization on relations.
+
+    Steps:
+      1. Decrypt up to _AUDIT_MAX_FACTS associative rows; send list to memory-auditor LLM.
+      2. LLM flags rows to delete (direct contradictions); hard-delete them.
+      3. Collect unique subject/object labels from memory_relations; ask LLM to group duplicates.
+      4. Rewrite variant labels to their canonical form in-place.
+
+    Never raises: any failure is logged and the session is rolled back so the
+    caller's session stays usable.
+    """
+    try:
+        await _audit_memory_inner(db, user_id)
+    except Exception as exc:
+        logger.warning("memory_maintenance: audit_memory failed user=%s: %s", user_id, exc)
+        try:
+            # A failed statement leaves the session unusable until it is rolled back.
+            await db.rollback()
+        except Exception as rollback_exc:
+            logger.warning(
+                "memory_maintenance: audit_memory rollback failed user=%s: %s",
+                user_id, rollback_exc,
+            )
+
+
+async def _audit_memory_inner(db: AsyncSession, user_id: str) -> None:
+    """Load the user's Fernet key, then run both audit steps (may raise)."""
+    result = await db.execute(select(User).where(User.id == user_id))
+    user = result.scalar_one_or_none()
+    if user is None or not user.encryption_key:
+        logger.warning("memory_maintenance: audit_memory no encryption_key user=%s", user_id)
+        return
+
+    fernet = Fernet(user.encryption_key.encode())
+    await _scan_associative_contradictions(db, user_id, fernet)
+    await _canonicalize_relation_labels(db, user_id)
+
+
+def _parse_json_array(content: object) -> list | None:
+    """Parse an LLM reply that is expected to be a JSON array.
+
+    A surrounding markdown code fence is tolerated because models often add one
+    even when told not to. Returns None when the reply is not a string, is not
+    valid JSON, or is valid JSON that is not an array.
+    """
+    if not isinstance(content, str):
+        return None
+    text = content.strip()
+    if text.startswith("```"):
+        text = text[3:]
+        if text.endswith("```"):
+            text = text[:-3]
+        # The opening fence may carry a language tag such as "json".
+        if text[:4].lower() == "json":
+            text = text[4:]
+    try:
+        parsed = json.loads(text)
+    except ValueError:
+        return None
+    return parsed if isinstance(parsed, list) else None
+
+
+async def _ask_auditor(
+    user_id: str,
+    step: str,
+    system_text: str,
+    human_text: str,
+    prompt_obj: object,
+) -> list | None:
+    """Run one memory-auditor LLM call and return its reply parsed as a JSON array.
+
+    ``step`` names the Langfuse observation and is used in log messages. Returns
+    None, after logging a warning, when the call fails or the reply is not a
+    JSON array.
+    """
+    # Imported lazily; NOTE(review): presumably to avoid an import cycle with app.core.llm.
+    from app.core.llm import get_agent_llm, model_for_agent  # noqa: PLC0415
+    from langchain_core.messages import HumanMessage, SystemMessage  # noqa: PLC0415
+
+    messages = [SystemMessage(content=system_text), HumanMessage(content=human_text)]
+    try:
+        llm = get_agent_llm("memory-auditor", temperature=0)
+        lf = get_langfuse()
+        if lf:
+            with lf.start_as_current_observation(
+                as_type="generation",
+                name=step,
+                model=model_for_agent("memory-auditor"),
+                prompt=prompt_obj,
+                input=messages,
+            ) as gen:
+                response = await llm.ainvoke(messages)
+                gen.update(output=response.content, usage=extract_usage(response))
+        else:
+            response = await llm.ainvoke(messages)
+    except Exception as exc:
+        logger.warning(
+            "memory_maintenance: %s LLM call failed user=%s: %s", step, user_id, exc
+        )
+        return None
+
+    content = response.content if hasattr(response, "content") else str(response)
+    parsed = _parse_json_array(content)
+    if parsed is None:
+        logger.warning(
+            "memory_maintenance: %s reply was not a JSON array user=%s", step, user_id
+        )
+    return parsed
+
+
+async def _scan_associative_contradictions(
+    db: AsyncSession,
+    user_id: str,
+    fernet: Fernet,
+) -> None:
+    """Decrypt associative facts, ask LLM to flag contradictions, delete superseded rows.
+
+    Only the _AUDIT_MAX_FACTS most recently updated rows are considered, and only
+    IDs that were actually sent to the LLM can be deleted. LLM or parse failures
+    make this a no-op.
+    """
+    result = await db.execute(
+        select(MemoryAssociative)
+        .where(MemoryAssociative.user_id == user_id)
+        .order_by(MemoryAssociative.updated_at.desc())
+        .limit(_AUDIT_MAX_FACTS)
+    )
+    rows = result.scalars().all()
+    if len(rows) < 2:
+        return
+
+    row_by_id: dict[str, MemoryAssociative] = {}
+    id_to_text: dict[str, str] = {}
+    for row in rows:
+        try:
+            id_to_text[row.id] = fernet.decrypt(row.content_encrypted.encode()).decode()
+        except Exception:
+            continue
+        row_by_id[row.id] = row
+
+    skipped = len(rows) - len(id_to_text)
+    if skipped:
+        # Report undecryptable rows instead of dropping them silently.
+        logger.warning(
+            "memory_maintenance: audit skipped %d undecryptable facts user=%s",
+            skipped, user_id,
+        )
+    if len(id_to_text) < 2:
+        return
+
+    numbered = "\n".join(
+        f"{i}. [{rid}] {text}" for i, (rid, text) in enumerate(id_to_text.items(), start=1)
+    )
+
+    template, prompt_obj = get_prompt_or_fallback(
+        "memory_audit_contradictions", _AUDIT_CONTRADICTIONS_FALLBACK
+    )
+    system_text = compile_prompt(template, prompt_obj, facts=numbered)
+
+    deletions = await _ask_auditor(
+        user_id,
+        "memory-audit-contradictions",
+        system_text,
+        "Audit facts for contradictions.",
+        prompt_obj,
+    )
+    if deletions is None:
+        return
+
+    deleted_ids: list[str] = []
+    for item in deletions:
+        if not isinstance(item, dict):
+            continue
+        rid = item.get("delete")
+        # Non-string IDs (possibly unhashable) and IDs the LLM was never shown are ignored.
+        if not isinstance(rid, str):
+            continue
+        target = row_by_id.pop(rid, None)
+        if target is None:
+            continue
+        await db.delete(target)
+        deleted_ids.append(rid)
+
+    if deleted_ids:
+        try:
+            await db.commit()
+        except Exception as exc:
+            logger.warning(
+                "memory_maintenance: audit contradiction commit failed user=%s: %s", user_id, exc
+            )
+            await db.rollback()
+            return
+        # The LLM's "reason" is deliberately not logged: it may quote the decrypted fact text.
+        for rid in deleted_ids:
+            logger.info(
+                "memory_maintenance: audit deleted contradiction id=%s user=%s", rid, user_id
+            )
+
+    logger.info(
+        "memory_maintenance: _scan_associative_contradictions user=%s deleted=%d",
+        user_id, len(deleted_ids),
+    )
+
+
+def _build_label_remap(groups: list, known_labels: set[str]) -> dict[str, str]:
+    """Turn the LLM's canonicalization groups into a variant -> canonical map.
+
+    Malformed groups are skipped, variants must be labels that were actually sent
+    to the LLM, and a label used as a canonical is never itself rewritten, so
+    conflicting groups cannot chain or swap labels.
+    """
+    remap: dict[str, str] = {}
+    for group in groups:
+        if not isinstance(group, dict):
+            continue
+        canonical = group.get("canonical")
+        variants = group.get("variants")
+        if not isinstance(canonical, str) or not canonical.strip():
+            continue
+        if not isinstance(variants, list):
+            continue
+        for variant in variants:
+            if isinstance(variant, str) and variant != canonical and variant in known_labels:
+                remap[variant] = canonical
+    for canonical in set(remap.values()):
+        remap.pop(canonical, None)
+    return remap
+
+
+async def _canonicalize_relation_labels(db: AsyncSession, user_id: str) -> None:
+    """Group near-duplicate entity labels in memory_relations and unify to canonical form.
+
+    At most _AUDIT_MAX_LABELS labels (in sorted order) are sent to the LLM. LLM or
+    parse failures, or an empty mapping, make this a no-op.
+    """
+    result = await db.execute(
+        select(MemoryRelation).where(MemoryRelation.user_id == user_id)
+    )
+    rows = result.scalars().all()
+    if not rows:
+        return
+
+    # Skip empty/None labels: they carry no entity and None would break sorted().
+    all_labels = {
+        label
+        for row in rows
+        for label in (row.subject_label, row.object_label)
+        if label
+    }
+    labels_list = sorted(all_labels)[:_AUDIT_MAX_LABELS]
+    if len(labels_list) < 2:
+        return
+
+    labels_block = "\n".join(f"- {lbl}" for lbl in labels_list)
+    template, prompt_obj = get_prompt_or_fallback(
+        "memory_audit_canonicalize", _AUDIT_CANONICALIZE_FALLBACK
+    )
+    system_text = compile_prompt(template, prompt_obj, labels=labels_block)
+
+    groups = await _ask_auditor(
+        user_id,
+        "memory-audit-canonicalize",
+        system_text,
+        "Canonicalize entity labels.",
+        prompt_obj,
+    )
+    if groups is None:
+        return
+
+    remap = _build_label_remap(groups, set(labels_list))
+    if not remap:
+        return
+
+    updated = 0
+    for row in rows:
+        changed = False
+        if row.subject_label in remap:
+            row.subject_label = remap[row.subject_label]
+            changed = True
+        if row.object_label in remap:
+            row.object_label = remap[row.object_label]
+            changed = True
+        if changed:
+            updated += 1
+
+    if updated:
+        try:
+            await db.commit()
+            logger.info(
+                "memory_maintenance: _canonicalize_relation_labels user=%s updated=%d",
+                user_id, updated,
+            )
+        except Exception as exc:
+            logger.warning(
+                "memory_maintenance: canonicalize commit failed user=%s: %s", user_id, exc
+            )
+            await db.rollback()
diff --git a/app/main.py b/app/main.py
index 56d5815..b3c9b8e 100644
--- a/app/main.py
+++ b/app/main.py
@@ -16,6 +16,33 @@ from app.api.middleware.sanitizer import SanitizerMiddleware
 from app.config.settings import settings
 
 
+async def _memory_audit_cron_tick() -> None:
+    """Weekly cron: contradiction scan + label canonicalization for all users (Phase 7)."""
+    import logging  # noqa: PLC0415
+    _log = logging.getLogger(__name__)
+    _log.info("memory audit cron tick: starting")
+    try:
+        from app.db import async_session  # noqa: PLC0415
+        from app.core.memory_maintenance import audit_memory  # noqa: PLC0415
+        from app.models import User  # noqa: PLC0415
+        from sqlalchemy import select  # noqa: PLC0415
+
+        async with async_session() as db:
+            result = await db.execute(select(User.id))
+            user_ids: list[str] = list(result.scalars().all())
+
+        for uid in user_ids:
+            try:
+                async with async_session() as db:
+                    await audit_memory(db, uid)
+            except Exception as exc:
+                _log.warning("memory audit cron tick: audit_memory failed user=%s: %s", uid, exc)
+
+        _log.info("memory audit cron tick: done users=%d", len(user_ids))
+    except Exception as exc:
+        _log.warning("memory audit cron tick: failed: %s", exc)
+
+
 async def _memory_cron_tick() -> None:
     """Hourly cron: drain Free-tier extraction queue + mine proactive patterns for Power+ users."""
     import logging  # noqa: PLC0415
@@ -61,6 +88,11 @@ async def lifespan(app: FastAPI):
 
         scheduler = AsyncIOScheduler()
         scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
+        # Fixed weekly slot rather than "interval": an interval job's first run is one full
+        # interval after process start, so regular restarts would keep postponing the audit.
+        scheduler.add_job(
+            _memory_audit_cron_tick, "cron", day_of_week="sun", hour=3, id="memory_audit_cron"
+        )
         scheduler.start()
         logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
 
diff --git a/tests/test_memory_audit.py b/tests/test_memory_audit.py
new file mode 100644
index 0000000..ab5c50b
--- /dev/null
+++ b/tests/test_memory_audit.py
@@ -0,0 +1,405 @@
+"""Tests for Phase 7 — weekly audit_memory job.
+
+Coverage:
+  1. audit_memory never raises even if inner work fails.
+  2. _scan_associative_contradictions skips when < 2 decryptable facts.
+  3. _scan_associative_contradictions calls LLM and deletes flagged rows.
+  4. _scan_associative_contradictions is a no-op when LLM fails.
+  5. _scan_associative_contradictions is a no-op when LLM returns non-list.
+  6. _canonicalize_relation_labels skips when no relation rows.
+  7. _canonicalize_relation_labels rewrites variant labels to canonical form.
+  8. _canonicalize_relation_labels is a no-op when LLM fails.
+  9. _canonicalize_relation_labels is a no-op when remap is empty.
+  10. Both helpers work correctly when Langfuse is unavailable (lf=None).
+  11. get_prompt_or_fallback called with correct Langfuse prompt names.
+"""
+
+from __future__ import annotations
+
+import json
+import uuid
+from contextlib import contextmanager, ExitStack
+from datetime import datetime, timezone
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+import pytest_asyncio
+from cryptography.fernet import Fernet
+from sqlalchemy import select
+
+from app.core.memory_maintenance import (
+    _canonicalize_relation_labels,
+    _scan_associative_contradictions,
+    audit_memory,
+)
+from app.db import get_session
+from app.main import app
+from app.models import MemoryAssociative, MemoryRelation, User
+from tests.conftest import TEST_USER_IDS
+
+PRO_USER_ID = TEST_USER_IDS["pro"]
+_FERNET_KEY = Fernet.generate_key().decode()
+_FERNET = Fernet(_FERNET_KEY.encode())
+
+
+# ── DB override ───────────────────────────────────────────────────────────────
+
+@pytest.fixture(autouse=True)
+def _override_db(db_session):  # Point the app's get_session dependency at the test session.
+    async def _gen():
+        yield db_session
+
+    app.dependency_overrides[get_session] = _gen
+    yield
+    app.dependency_overrides.pop(get_session, None)
+
+
+# ── Helpers ───────────────────────────────────────────────────────────────────
+
+@pytest_asyncio.fixture
+async def pro_user(db_session):  # Pro test user with encryption_key set to _FERNET_KEY.
+    result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
+    user = result.scalar_one()
+    user.encryption_key = _FERNET_KEY
+    await db_session.commit()
+    return user
+
+
+def _enc(text: str) -> str:  # Encrypt *text* with the module-level test key.
+    return _FERNET.encrypt(text.encode()).decode()
+
+
+def _assoc_row(user_id: str, text: str) -> MemoryAssociative:  # Unsaved row holding encrypted *text*.
+    return MemoryAssociative(
+        id=str(uuid.uuid4()),
+        user_id=user_id,
+        content_encrypted=_enc(text),
+        updated_at=datetime.now(timezone.utc),
+    )
+
+
+def _relation_row(user_id: str, subject: str, predicate: str, obj: str) -> MemoryRelation:
+    return MemoryRelation(
+        id=str(uuid.uuid4()),
+        user_id=user_id,
+        subject_label=subject,
+        subject_type="person",
+        predicate=predicate,
+        object_label=obj,
+        object_type="company",
+        confidence=0.8,
+    )
+
+
+def _llm_response(content: str) -> MagicMock:  # Fake LLM message with .content and usage_metadata.
+    msg = MagicMock()
+    msg.content = content
+    msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
+    return msg
+
+
+def _mock_llm(content: str) -> MagicMock:  # LLM double whose ainvoke returns *content*.
+    llm = MagicMock()
+    llm.ainvoke = AsyncMock(return_value=_llm_response(content))
+    return llm
+
+
+@contextmanager
+def _patch_audit(llm_mock, lf=None, prompt_text: str = "fallback {facts}"):
+    """Context manager that patches all external deps for audit helpers."""
+    with ExitStack() as stack:
+        stack.enter_context(
+            patch("app.core.llm.get_agent_llm", return_value=llm_mock)
+        )
+        stack.enter_context(
+            patch("app.core.llm.model_for_agent", return_value="memory-auditor")
+        )
+        stack.enter_context(
+            patch("app.core.memory_maintenance.get_langfuse", return_value=lf)
+        )
+        stack.enter_context(
+            patch(
+                "app.core.memory_maintenance.get_prompt_or_fallback",
+                return_value=(prompt_text, None),
+            )
+        )
+        stack.enter_context(
+            patch(
+                "app.core.memory_maintenance.compile_prompt",
+                side_effect=lambda tmpl, obj, **kw: tmpl.format(**kw) if "{" in tmpl else tmpl,
+            )
+        )
+        yield
+
+
+# ── Test 1: audit_memory never raises ────────────────────────────────────────
+
+@pytest.mark.asyncio
+async def test_audit_memory_never_raises_on_missing_user(db_session):
+    """audit_memory with a non-existent user_id must not raise."""
+    await audit_memory(db_session, str(uuid.uuid4()))
+
+
+@pytest.mark.asyncio
+async def test_audit_memory_never_raises_on_llm_failure(db_session, pro_user):
+    """audit_memory must swallow an LLM failure in both audit steps."""
+    # Seed both tables: with no rows each step returns before it ever calls the LLM.
+    db_session.add(_assoc_row(PRO_USER_ID, "Fact A"))
+    db_session.add(_assoc_row(PRO_USER_ID, "Fact B"))
+    db_session.add(_relation_row(PRO_USER_ID, "giulia", "works_at", "Acme"))
+    await db_session.commit()
+
+    llm = MagicMock()
+    llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
+
+    # A brace-free template compiles for both steps (one passes facts=, the other labels=).
+    with _patch_audit(llm, prompt_text="p"):
+        await audit_memory(db_session, PRO_USER_ID)
+
+    assert llm.ainvoke.await_count == 2
+
+
+# ── Test 2: _scan skips when < 2 facts ───────────────────────────────────────
+
+@pytest.mark.asyncio
+async def test_scan_contradictions_skips_with_one_fact(db_session, pro_user):
+    row = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
+    db_session.add(row)
+    await db_session.commit()
+
+    llm = MagicMock()
+    llm.ainvoke = AsyncMock(return_value=_llm_response("[]"))
+
+    with _patch_audit(llm):
+        await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
+
+    llm.ainvoke.assert_not_called()
+
+
+# ── Test 3: _scan deletes flagged contradiction ───────────────────────────────
+
+@pytest.mark.asyncio
+async def test_scan_contradictions_deletes_flagged_row(db_session, pro_user):
+    keep = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
+    drop = _assoc_row(PRO_USER_ID, "Never schedules before noon")
+    db_session.add(keep)
+    db_session.add(drop)
+    await db_session.commit()
+
+    deletion_payload = json.dumps([{"delete": drop.id, "reason": "contradicts morning pref"}])
+    llm = _mock_llm(deletion_payload)
+
+    with _patch_audit(llm, prompt_text="p {facts}"):
+        await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
+
+    result = await db_session.execute(
+        select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
+    )
+    remaining = result.scalars().all()
+    remaining_ids = {r.id for r in remaining}
+    assert keep.id in remaining_ids
+    assert drop.id not in remaining_ids
+
+
+# ── Test 4: _scan is no-op on LLM failure ────────────────────────────────────
+
+@pytest.mark.asyncio
+async def test_scan_contradictions_noop_on_llm_failure(db_session, pro_user):
+    for text in ("Fact A", "Fact B"):
+        db_session.add(_assoc_row(PRO_USER_ID, text))
+    await db_session.commit()
+
+    llm = MagicMock()
+    llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
+
+    with _patch_audit(llm, prompt_text="p {facts}"):
+        await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
+
+    result = await db_session.execute(
+        select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
+    )
+    assert len(result.scalars().all()) == 2
+
+
+# ── Test 5: _scan is no-op when LLM returns non-list ─────────────────────────
+
+@pytest.mark.asyncio
+async def test_scan_contradictions_noop_on_non_list_response(db_session, pro_user):
+    for text in ("Fact A", "Fact B"):
+        db_session.add(_assoc_row(PRO_USER_ID, text))
+    await db_session.commit()
+
+    llm = _mock_llm('"unexpected string"')
+
+    with _patch_audit(llm, prompt_text="p {facts}"):
+        await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
+
+    result = await db_session.execute(
+        select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
+    )
+    assert len(result.scalars().all()) == 2
+
+
+# ── Test 6: _canonicalize skips when no relations ────────────────────────────
+
+@pytest.mark.asyncio
+async def test_canonicalize_skips_when_no_relations(db_session, pro_user):
+    llm = MagicMock()
+    llm.ainvoke = AsyncMock(return_value=_llm_response("[]"))
+
+    with _patch_audit(llm, prompt_text="p {labels}"):
+        await _canonicalize_relation_labels(db_session, PRO_USER_ID)
+
+    llm.ainvoke.assert_not_called()
+
+
+# ── Test 7: _canonicalize rewrites variant labels ────────────────────────────
+
+@pytest.mark.asyncio
+async def test_canonicalize_rewrites_variant_labels(db_session, pro_user):
+    row_a = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
+    row_b = _relation_row(PRO_USER_ID, "Giulia R.", "reports_to", "Marco")
+    row_c = _relation_row(PRO_USER_ID, "Marco", "manages", "Giulia")
+    db_session.add(row_a)
+    db_session.add(row_b)
+    db_session.add(row_c)
+    await db_session.commit()
+
+    groups = json.dumps([
+        {"canonical": "Giulia", "variants": ["giulia", "Giulia R."]}
+    ])
+    llm = _mock_llm(groups)
+
+    with _patch_audit(llm, prompt_text="p {labels}"):
+        await _canonicalize_relation_labels(db_session, PRO_USER_ID)
+
+    await db_session.refresh(row_a)
+    await db_session.refresh(row_b)
+    await db_session.refresh(row_c)
+
+    assert row_a.subject_label == "Giulia"
+    assert row_b.subject_label == "Giulia"
+    assert row_c.object_label == "Giulia"
+    assert row_c.subject_label == "Marco"
+
+
+# ── Test 8: _canonicalize is no-op on LLM failure ────────────────────────────
+
+@pytest.mark.asyncio
+async def test_canonicalize_noop_on_llm_failure(db_session, pro_user):
+    row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
+    db_session.add(row)
+    await db_session.commit()
+
+    llm = MagicMock()
+    llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
+
+    with _patch_audit(llm, prompt_text="p {labels}"):
+        await _canonicalize_relation_labels(db_session, PRO_USER_ID)
+
+    await db_session.refresh(row)
+    assert row.subject_label == "giulia"
+
+
+# ── Test 9: _canonicalize is no-op when remap is empty ───────────────────────
+
+@pytest.mark.asyncio
+async def test_canonicalize_noop_when_remap_empty(db_session, pro_user):
+    row = _relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme")
+    db_session.add(row)
+    await db_session.commit()
+
+    llm = _mock_llm("[]")
+
+    with _patch_audit(llm, prompt_text="p {labels}"):
+        await _canonicalize_relation_labels(db_session, PRO_USER_ID)
+
+    await db_session.refresh(row)
+    assert row.subject_label == "Giulia"
+
+
+# ── Test 10: both helpers work without Langfuse ───────────────────────────────
+
+@pytest.mark.asyncio
+async def test_scan_works_without_langfuse(db_session, pro_user):
+    keep = _assoc_row(PRO_USER_ID, "Prefers dark mode")
+    drop = _assoc_row(PRO_USER_ID, "Prefers light mode")
+    db_session.add(keep)
+    db_session.add(drop)
+    await db_session.commit()
+
+    deletion_payload = json.dumps([{"delete": drop.id, "reason": "contradicts dark mode"}])
+    llm = _mock_llm(deletion_payload)
+
+    with _patch_audit(llm, lf=None, prompt_text="p {facts}"):
+        await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
+
+    result = await db_session.execute(
+        select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
+    )
+    remaining_ids = {r.id for r in result.scalars().all()}
+    assert keep.id in remaining_ids
+    assert drop.id not in remaining_ids
+
+
+@pytest.mark.asyncio
+async def test_canonicalize_works_without_langfuse(db_session, pro_user):
+    row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
+    db_session.add(row)
+    db_session.add(_relation_row(PRO_USER_ID, "Marco", "manages", "Giulia"))
+    await db_session.commit()
+
+    groups = json.dumps([{"canonical": "Giulia", "variants": ["giulia"]}])
+    llm = _mock_llm(groups)
+
+    with _patch_audit(llm, lf=None, prompt_text="p {labels}"):
+        await _canonicalize_relation_labels(db_session, PRO_USER_ID)
+
+    await db_session.refresh(row)
+    assert row.subject_label == "Giulia"
+
+
+# ── Test 11: correct Langfuse prompt names used ───────────────────────────────
+
+@pytest.mark.asyncio
+async def test_scan_uses_correct_langfuse_prompt_name(db_session, pro_user):
+    for text in ("Fact A", "Fact B"):
+        db_session.add(_assoc_row(PRO_USER_ID, text))
+    await db_session.commit()
+
+    llm = _mock_llm("[]")
+    mock_get_prompt = MagicMock(return_value=("p {facts}", None))
+
+    with (
+        patch("app.core.llm.get_agent_llm", return_value=llm),
+        patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
+        patch("app.core.memory_maintenance.get_langfuse", return_value=None),
+        patch("app.core.memory_maintenance.get_prompt_or_fallback", mock_get_prompt),
+        patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
+    ):
+        await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
+
+    mock_get_prompt.assert_called_once()
+    assert mock_get_prompt.call_args[0][0] == "memory_audit_contradictions"
+
+
+@pytest.mark.asyncio
+async def test_canonicalize_uses_correct_langfuse_prompt_name(db_session, pro_user):
+    db_session.add(_relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme"))
+    db_session.add(_relation_row(PRO_USER_ID, "Marco", "manages", "Acme"))
+    await db_session.commit()
+
+    llm = _mock_llm("[]")
+    mock_get_prompt = MagicMock(return_value=("p {labels}", None))
+
+    with (
+        patch("app.core.llm.get_agent_llm", return_value=llm),
+        patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
+        patch("app.core.memory_maintenance.get_langfuse", return_value=None),
+        patch("app.core.memory_maintenance.get_prompt_or_fallback", mock_get_prompt),
+        patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
+    ):
+        await _canonicalize_relation_labels(db_session, PRO_USER_ID)
+
+    mock_get_prompt.assert_called_once()
+    assert mock_get_prompt.call_args[0][0] == "memory_audit_canonicalize"