Phase 7: audit memory
This commit is contained in:
405
tests/test_memory_audit.py
Normal file
405
tests/test_memory_audit.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""Tests for Phase 7 — weekly audit_memory job.
|
||||
|
||||
Coverage:
|
||||
1. audit_memory never raises even if inner work fails.
|
||||
2. _scan_associative_contradictions skips when < 2 decryptable facts.
|
||||
3. _scan_associative_contradictions calls LLM and deletes flagged rows.
|
||||
4. _scan_associative_contradictions is a no-op when LLM fails.
|
||||
5. _scan_associative_contradictions is a no-op when LLM returns non-list.
|
||||
6. _canonicalize_relation_labels skips when no relation rows.
|
||||
7. _canonicalize_relation_labels rewrites variant labels to canonical form.
|
||||
8. _canonicalize_relation_labels is a no-op when LLM fails.
|
||||
9. _canonicalize_relation_labels is a no-op when remap is empty.
|
||||
10. Both helpers work correctly when Langfuse is unavailable (lf=None).
|
||||
11. get_prompt_or_fallback called with correct Langfuse prompt names.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from contextlib import contextmanager, ExitStack
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.memory_maintenance import (
|
||||
_canonicalize_relation_labels,
|
||||
_scan_associative_contradictions,
|
||||
audit_memory,
|
||||
)
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import MemoryAssociative, MemoryRelation, User
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
PRO_USER_ID = TEST_USER_IDS["pro"]
|
||||
_FERNET_KEY = Fernet.generate_key().decode()
|
||||
_FERNET = Fernet(_FERNET_KEY.encode())
|
||||
|
||||
|
||||
# ── DB override ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _override_db(db_session):
|
||||
async def _gen():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = _gen
|
||||
yield
|
||||
app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def pro_user(db_session):
|
||||
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
|
||||
user = result.scalar_one()
|
||||
user.encryption_key = _FERNET_KEY
|
||||
await db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def _enc(text: str) -> str:
|
||||
return _FERNET.encrypt(text.encode()).decode()
|
||||
|
||||
|
||||
def _assoc_row(user_id: str, text: str) -> MemoryAssociative:
|
||||
return MemoryAssociative(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
content_encrypted=_enc(text),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _relation_row(user_id: str, subject: str, predicate: str, obj: str) -> MemoryRelation:
|
||||
return MemoryRelation(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
subject_label=subject,
|
||||
subject_type="person",
|
||||
predicate=predicate,
|
||||
object_label=obj,
|
||||
object_type="company",
|
||||
confidence=0.8,
|
||||
)
|
||||
|
||||
|
||||
def _llm_response(content: str) -> MagicMock:
|
||||
msg = MagicMock()
|
||||
msg.content = content
|
||||
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
return msg
|
||||
|
||||
|
||||
def _mock_llm(content: str) -> MagicMock:
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(return_value=_llm_response(content))
|
||||
return llm
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_audit(llm_mock, lf=None, prompt_text: str = "fallback {facts}"):
|
||||
"""Context manager that patches all external deps for audit helpers."""
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch("app.core.llm.get_agent_llm", return_value=llm_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("app.core.llm.model_for_agent", return_value="memory-auditor")
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("app.core.memory_maintenance.get_langfuse", return_value=lf)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"app.core.memory_maintenance.get_prompt_or_fallback",
|
||||
return_value=(prompt_text, None),
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"app.core.memory_maintenance.compile_prompt",
|
||||
side_effect=lambda tmpl, obj, **kw: tmpl.format(**kw) if "{" in tmpl else tmpl,
|
||||
)
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
# ── Test 1: audit_memory never raises ────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_memory_never_raises_on_missing_user(db_session):
|
||||
"""audit_memory with a non-existent user_id must not raise."""
|
||||
await audit_memory(db_session, str(uuid.uuid4()))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_memory_never_raises_on_llm_failure(db_session, pro_user):
|
||||
"""audit_memory must swallow inner exceptions."""
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
|
||||
with (
|
||||
patch("app.core.llm.get_agent_llm", return_value=llm),
|
||||
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
|
||||
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
|
||||
patch(
|
||||
"app.core.memory_maintenance.get_prompt_or_fallback",
|
||||
return_value=("p {facts}", None),
|
||||
),
|
||||
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
|
||||
):
|
||||
await audit_memory(db_session, PRO_USER_ID)
|
||||
|
||||
|
||||
# ── Test 2: _scan skips when < 2 facts ───────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_contradictions_skips_with_one_fact(db_session, pro_user):
|
||||
row = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(return_value=_llm_response("[]"))
|
||||
|
||||
with _patch_audit(llm):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
llm.ainvoke.assert_not_called()
|
||||
|
||||
|
||||
# ── Test 3: _scan deletes flagged contradiction ───────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_contradictions_deletes_flagged_row(db_session, pro_user):
|
||||
keep = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
|
||||
drop = _assoc_row(PRO_USER_ID, "Never schedules before noon")
|
||||
db_session.add(keep)
|
||||
db_session.add(drop)
|
||||
await db_session.commit()
|
||||
|
||||
deletion_payload = json.dumps([{"delete": drop.id, "reason": "contradicts morning pref"}])
|
||||
llm = _mock_llm(deletion_payload)
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {facts}"):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||
)
|
||||
remaining = result.scalars().all()
|
||||
remaining_ids = {r.id for r in remaining}
|
||||
assert keep.id in remaining_ids
|
||||
assert drop.id not in remaining_ids
|
||||
|
||||
|
||||
# ── Test 4: _scan is no-op on LLM failure ────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_contradictions_noop_on_llm_failure(db_session, pro_user):
|
||||
for text in ("Fact A", "Fact B"):
|
||||
db_session.add(_assoc_row(PRO_USER_ID, text))
|
||||
await db_session.commit()
|
||||
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {facts}"):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||
)
|
||||
assert len(result.scalars().all()) == 2
|
||||
|
||||
|
||||
# ── Test 5: _scan is no-op when LLM returns non-list ─────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_contradictions_noop_on_non_list_response(db_session, pro_user):
|
||||
for text in ("Fact A", "Fact B"):
|
||||
db_session.add(_assoc_row(PRO_USER_ID, text))
|
||||
await db_session.commit()
|
||||
|
||||
llm = _mock_llm('"unexpected string"')
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {facts}"):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||
)
|
||||
assert len(result.scalars().all()) == 2
|
||||
|
||||
|
||||
# ── Test 6: _canonicalize skips when no relations ────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_skips_when_no_relations(db_session, pro_user):
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(return_value=_llm_response("[]"))
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
llm.ainvoke.assert_not_called()
|
||||
|
||||
|
||||
# ── Test 7: _canonicalize rewrites variant labels ────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_rewrites_variant_labels(db_session, pro_user):
|
||||
row_a = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
|
||||
row_b = _relation_row(PRO_USER_ID, "Giulia R.", "reports_to", "Marco")
|
||||
row_c = _relation_row(PRO_USER_ID, "Marco", "manages", "Giulia")
|
||||
db_session.add(row_a)
|
||||
db_session.add(row_b)
|
||||
db_session.add(row_c)
|
||||
await db_session.commit()
|
||||
|
||||
groups = json.dumps([
|
||||
{"canonical": "Giulia", "variants": ["giulia", "Giulia R."]}
|
||||
])
|
||||
llm = _mock_llm(groups)
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
await db_session.refresh(row_a)
|
||||
await db_session.refresh(row_b)
|
||||
await db_session.refresh(row_c)
|
||||
|
||||
assert row_a.subject_label == "Giulia"
|
||||
assert row_b.subject_label == "Giulia"
|
||||
assert row_c.object_label == "Giulia"
|
||||
assert row_c.subject_label == "Marco"
|
||||
|
||||
|
||||
# ── Test 8: _canonicalize is no-op on LLM failure ────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_noop_on_llm_failure(db_session, pro_user):
|
||||
row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
await db_session.refresh(row)
|
||||
assert row.subject_label == "giulia"
|
||||
|
||||
|
||||
# ── Test 9: _canonicalize is no-op when remap is empty ───────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_noop_when_remap_empty(db_session, pro_user):
|
||||
row = _relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme")
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
|
||||
llm = _mock_llm("[]")
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
await db_session.refresh(row)
|
||||
assert row.subject_label == "Giulia"
|
||||
|
||||
|
||||
# ── Test 10: both helpers work without Langfuse ───────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_works_without_langfuse(db_session, pro_user):
|
||||
keep = _assoc_row(PRO_USER_ID, "Prefers dark mode")
|
||||
drop = _assoc_row(PRO_USER_ID, "Prefers light mode")
|
||||
db_session.add(keep)
|
||||
db_session.add(drop)
|
||||
await db_session.commit()
|
||||
|
||||
deletion_payload = json.dumps([{"delete": drop.id, "reason": "contradicts dark mode"}])
|
||||
llm = _mock_llm(deletion_payload)
|
||||
|
||||
with _patch_audit(llm, lf=None, prompt_text="p {facts}"):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||
)
|
||||
remaining_ids = {r.id for r in result.scalars().all()}
|
||||
assert keep.id in remaining_ids
|
||||
assert drop.id not in remaining_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_works_without_langfuse(db_session, pro_user):
|
||||
row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
|
||||
db_session.add(row)
|
||||
db_session.add(_relation_row(PRO_USER_ID, "Marco", "manages", "Giulia"))
|
||||
await db_session.commit()
|
||||
|
||||
groups = json.dumps([{"canonical": "Giulia", "variants": ["giulia"]}])
|
||||
llm = _mock_llm(groups)
|
||||
|
||||
with _patch_audit(llm, lf=None, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
await db_session.refresh(row)
|
||||
assert row.subject_label == "Giulia"
|
||||
|
||||
|
||||
# ── Test 11: correct Langfuse prompt names used ───────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_uses_correct_langfuse_prompt_name(db_session, pro_user):
|
||||
for text in ("Fact A", "Fact B"):
|
||||
db_session.add(_assoc_row(PRO_USER_ID, text))
|
||||
await db_session.commit()
|
||||
|
||||
llm = _mock_llm("[]")
|
||||
mock_get_prompt = MagicMock(return_value=("p {facts}", None))
|
||||
|
||||
with (
|
||||
patch("app.core.llm.get_agent_llm", return_value=llm),
|
||||
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
|
||||
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_maintenance.get_prompt_or_fallback", mock_get_prompt),
|
||||
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
|
||||
):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
mock_get_prompt.assert_called_once()
|
||||
assert mock_get_prompt.call_args[0][0] == "memory_audit_contradictions"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_uses_correct_langfuse_prompt_name(db_session, pro_user):
|
||||
db_session.add(_relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme"))
|
||||
db_session.add(_relation_row(PRO_USER_ID, "Marco", "manages", "Acme"))
|
||||
await db_session.commit()
|
||||
|
||||
llm = _mock_llm("[]")
|
||||
mock_get_prompt = MagicMock(return_value=("p {labels}", None))
|
||||
|
||||
with (
|
||||
patch("app.core.llm.get_agent_llm", return_value=llm),
|
||||
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
|
||||
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_maintenance.get_prompt_or_fallback", mock_get_prompt),
|
||||
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
|
||||
):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
mock_get_prompt.assert_called_once()
|
||||
assert mock_get_prompt.call_args[0][0] == "memory_audit_canonicalize"
|
||||
Reference in New Issue
Block a user