406 lines
15 KiB
Python
406 lines
15 KiB
Python
"""Tests for Phase 7 — weekly audit_memory job.
|
|
|
|
Coverage:
  1. audit_memory never raises even if inner work fails.
  2. _scan_associative_contradictions skips when < 2 decryptable facts.
  3. _scan_associative_contradictions calls LLM and deletes flagged rows.
  4. _scan_associative_contradictions is a no-op when LLM fails.
  5. _scan_associative_contradictions is a no-op when LLM returns non-list.
  6. _canonicalize_relation_labels skips when no relation rows.
  7. _canonicalize_relation_labels rewrites variant labels to canonical form.
  8. _canonicalize_relation_labels is a no-op when LLM fails.
  9. _canonicalize_relation_labels is a no-op when remap is empty.
  10. Both helpers work correctly when Langfuse is unavailable (lf=None).
  11. get_prompt_or_fallback called with correct Langfuse prompt names.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from contextlib import contextmanager, ExitStack
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from cryptography.fernet import Fernet
|
|
from sqlalchemy import select
|
|
|
|
from app.core.memory_maintenance import (
|
|
_canonicalize_relation_labels,
|
|
_scan_associative_contradictions,
|
|
audit_memory,
|
|
)
|
|
from app.db import get_session
|
|
from app.main import app
|
|
from app.models import MemoryAssociative, MemoryRelation, User
|
|
from tests.conftest import TEST_USER_IDS
|
|
|
|
# ID of the "pro" entry in the conftest-provided TEST_USER_IDS mapping.
PRO_USER_ID = TEST_USER_IDS["pro"]

# Fernet key generated fresh at import time and kept as text.
_FERNET_KEY: str = Fernet.generate_key().decode()

# Cipher built from that same key; used to encrypt the rows these tests seed.
_FERNET: Fernet = Fernet(_FERNET_KEY.encode())
|
|
|
|
|
|
# ── DB override ───────────────────────────────────────────────────────────────
|
|
|
|
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Point the app's ``get_session`` dependency at the test session.

    Applied automatically to every test in this module; the override is
    removed again after the test body has run.
    """

    async def _yield_test_session():
        # Async-generator stand-in that hands out the fixture's session.
        yield db_session

    app.dependency_overrides[get_session] = _yield_test_session
    yield
    # pop() with a default tolerates the key having been removed already.
    app.dependency_overrides.pop(get_session, None)
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
@pytest_asyncio.fixture
async def pro_user(db_session):
    """Return the pro test user with the module's Fernet key stored on it."""
    lookup = select(User).where(User.id == PRO_USER_ID)
    # scalar_one() raises unless exactly one matching row is found.
    account = (await db_session.execute(lookup)).scalar_one()
    account.encryption_key = _FERNET_KEY
    await db_session.commit()
    return account
|
|
|
|
|
|
def _enc(text: str) -> str:
    """Encrypt *text* with the module-level cipher and return the token as ``str``."""
    token = _FERNET.encrypt(text.encode())
    return token.decode()
|
|
|
|
|
|
def _assoc_row(user_id: str, text: str) -> MemoryAssociative:
    """Build an unsaved ``MemoryAssociative`` row whose content is *text*, encrypted.

    The row gets a fresh UUID string as its id and a timezone-aware UTC
    ``updated_at`` taken at call time.
    """
    fields = {
        "id": str(uuid.uuid4()),
        "user_id": user_id,
        "content_encrypted": _enc(text),
        "updated_at": datetime.now(timezone.utc),
    }
    return MemoryAssociative(**fields)
|
|
|
|
|
|
def _relation_row(user_id: str, subject: str, predicate: str, obj: str) -> MemoryRelation:
    """Build an unsaved ``MemoryRelation`` linking *subject* to *obj* via *predicate*.

    Subject and object types are fixed to ``"person"`` and ``"company"``, and
    the confidence to ``0.8``; only the labels and predicate vary per call.
    """
    fields = {
        "id": str(uuid.uuid4()),
        "user_id": user_id,
        "subject_label": subject,
        "subject_type": "person",
        "predicate": predicate,
        "object_label": obj,
        "object_type": "company",
        "confidence": 0.8,
    }
    return MemoryRelation(**fields)
|
|
|
|
|
|
def _llm_response(content: str) -> MagicMock:
|
|
msg = MagicMock()
|
|
msg.content = content
|
|
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
|
return msg
|
|
|
|
|
|
def _mock_llm(content: str) -> MagicMock:
    """Return a mock LLM whose async ``ainvoke`` yields a message with *content*."""
    canned_reply = _llm_response(content)
    model = MagicMock()
    model.ainvoke = AsyncMock(return_value=canned_reply)
    return model
|
|
|
|
|
|
@contextmanager
def _patch_audit(llm_mock, lf=None, prompt_text: str = "fallback {facts}"):
    """Patch every external dependency of the audit helpers for a ``with`` block.

    Args:
        llm_mock: Object returned by the patched ``get_agent_llm``.
        lf: Value returned by the patched ``get_langfuse``.
        prompt_text: Template returned, paired with ``None``, by the patched
            ``get_prompt_or_fallback``.
    """

    def _fake_compile(tmpl, obj, **kw):
        # Fill in keyword values only when the template contains a brace.
        if "{" in tmpl:
            return tmpl.format(**kw)
        return tmpl

    # Building a patcher does not apply it; each takes effect when entered.
    patchers = (
        patch("app.core.llm.get_agent_llm", return_value=llm_mock),
        patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
        patch("app.core.memory_maintenance.get_langfuse", return_value=lf),
        patch(
            "app.core.memory_maintenance.get_prompt_or_fallback",
            return_value=(prompt_text, None),
        ),
        patch(
            "app.core.memory_maintenance.compile_prompt",
            side_effect=_fake_compile,
        ),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        yield
|
|
|
|
|
|
# ── Test 1: audit_memory never raises ────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_audit_memory_never_raises_on_missing_user(db_session):
    """audit_memory must return without raising for a freshly generated user_id."""
    unknown_user_id = str(uuid.uuid4())
    await audit_memory(db_session, unknown_user_id)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_audit_memory_never_raises_on_llm_failure(db_session, pro_user):
    """audit_memory must swallow inner exceptions.

    Two associative facts and one relation are seeded first. The other tests
    in this module show that the scan helper skips the LLM with fewer than
    two facts and the canonicalize helper skips it with no relations, so
    without seed data the failing ``ainvoke`` would never be reached.
    """
    for text in ("Fact A", "Fact B"):
        db_session.add(_assoc_row(PRO_USER_ID, text))
    db_session.add(_relation_row(PRO_USER_ID, "giulia", "works_at", "Acme"))
    await db_session.commit()

    llm = MagicMock()
    llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))

    with (
        patch("app.core.llm.get_agent_llm", return_value=llm),
        patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
        patch("app.core.memory_maintenance.get_langfuse", return_value=None),
        patch(
            "app.core.memory_maintenance.get_prompt_or_fallback",
            return_value=("p {facts}", None),
        ),
        patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
    ):
        # The test passes as long as this await does not raise.
        await audit_memory(db_session, PRO_USER_ID)
|
|
|
|
|
|
# ── Test 2: _scan skips when < 2 facts ───────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_scan_contradictions_skips_with_one_fact(db_session, pro_user):
    """With a single stored fact the scan must not call the LLM at all."""
    db_session.add(_assoc_row(PRO_USER_ID, "Prefers morning meetings"))
    await db_session.commit()

    model = _mock_llm("[]")

    with _patch_audit(model):
        await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)

    model.ainvoke.assert_not_called()
|
|
|
|
|
|
# ── Test 3: _scan deletes flagged contradiction ───────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_scan_contradictions_deletes_flagged_row(db_session, pro_user):
    """The row the LLM flags for deletion is removed; the other row survives."""
    keep = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
    drop = _assoc_row(PRO_USER_ID, "Never schedules before noon")
    for seeded in (keep, drop):
        db_session.add(seeded)
    await db_session.commit()

    # The canned LLM reply names only the second row for deletion.
    verdict = [{"delete": drop.id, "reason": "contradicts morning pref"}]
    model = _mock_llm(json.dumps(verdict))

    with _patch_audit(model, prompt_text="p {facts}"):
        await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)

    query = select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
    fetched = await db_session.execute(query)
    surviving_ids = {entry.id for entry in fetched.scalars().all()}
    assert keep.id in surviving_ids
    assert drop.id not in surviving_ids
|
|
|
|
|
|
# ── Test 4: _scan is no-op on LLM failure ────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_scan_contradictions_noop_on_llm_failure(db_session, pro_user):
    """A failing LLM call must leave every stored fact in place.

    Also asserts that the LLM was awaited, so the test cannot pass merely
    because the helper returned before reaching the model.
    """
    for text in ("Fact A", "Fact B"):
        db_session.add(_assoc_row(PRO_USER_ID, text))
    await db_session.commit()

    llm = MagicMock()
    llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))

    with _patch_audit(llm, prompt_text="p {facts}"):
        await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)

    # Guard against a vacuous pass: the failure path must have been exercised.
    llm.ainvoke.assert_awaited()

    result = await db_session.execute(
        select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
    )
    assert len(result.scalars().all()) == 2
|
|
|
|
|
|
# ── Test 5: _scan is no-op when LLM returns non-list ─────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_scan_contradictions_noop_on_non_list_response(db_session, pro_user):
    """An LLM reply that is valid JSON but not a list must delete nothing.

    Also asserts that the LLM was awaited, so the test cannot pass merely
    because the helper returned before reaching the model.
    """
    for text in ("Fact A", "Fact B"):
        db_session.add(_assoc_row(PRO_USER_ID, text))
    await db_session.commit()

    # The reply content is a JSON string literal rather than a JSON array.
    llm = _mock_llm('"unexpected string"')

    with _patch_audit(llm, prompt_text="p {facts}"):
        await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)

    # Guard against a vacuous pass: the reply must actually have been requested.
    llm.ainvoke.assert_awaited()

    result = await db_session.execute(
        select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
    )
    assert len(result.scalars().all()) == 2
|
|
|
|
|
|
# ── Test 6: _canonicalize skips when no relations ────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_canonicalize_skips_when_no_relations(db_session, pro_user):
    """With no relation rows seeded, canonicalization must not call the LLM."""
    model = _mock_llm("[]")

    with _patch_audit(model, prompt_text="p {labels}"):
        await _canonicalize_relation_labels(db_session, PRO_USER_ID)

    model.ainvoke.assert_not_called()
|
|
|
|
|
|
# ── Test 7: _canonicalize rewrites variant labels ────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_canonicalize_rewrites_variant_labels(db_session, pro_user):
    """Labels the LLM groups under one canonical form are rewritten to it."""
    lower_case = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
    with_initial = _relation_row(PRO_USER_ID, "Giulia R.", "reports_to", "Marco")
    as_object = _relation_row(PRO_USER_ID, "Marco", "manages", "Giulia")
    seeded = (lower_case, with_initial, as_object)
    for relation in seeded:
        db_session.add(relation)
    await db_session.commit()

    grouping = [{"canonical": "Giulia", "variants": ["giulia", "Giulia R."]}]
    model = _mock_llm(json.dumps(grouping))

    with _patch_audit(model, prompt_text="p {labels}"):
        await _canonicalize_relation_labels(db_session, PRO_USER_ID)

    # Reload each row so the assertions see what is stored now.
    for relation in seeded:
        await db_session.refresh(relation)

    assert lower_case.subject_label == "Giulia"
    assert with_initial.subject_label == "Giulia"
    assert as_object.object_label == "Giulia"
    assert as_object.subject_label == "Marco"
|
|
|
|
|
|
# ── Test 8: _canonicalize is no-op on LLM failure ────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_canonicalize_noop_on_llm_failure(db_session, pro_user):
    """When the LLM call raises, the stored subject label stays as written."""
    relation = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
    db_session.add(relation)
    await db_session.commit()

    failing_model = MagicMock()
    failing_model.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))

    with _patch_audit(failing_model, prompt_text="p {labels}"):
        await _canonicalize_relation_labels(db_session, PRO_USER_ID)

    await db_session.refresh(relation)
    assert relation.subject_label == "giulia"
|
|
|
|
|
|
# ── Test 9: _canonicalize is no-op when remap is empty ───────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_canonicalize_noop_when_remap_empty(db_session, pro_user):
    """An empty grouping list from the LLM leaves the stored label untouched."""
    relation = _relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme")
    db_session.add(relation)
    await db_session.commit()

    model = _mock_llm("[]")

    with _patch_audit(model, prompt_text="p {labels}"):
        await _canonicalize_relation_labels(db_session, PRO_USER_ID)

    await db_session.refresh(relation)
    assert relation.subject_label == "Giulia"
|
|
|
|
|
|
# ── Test 10: both helpers work without Langfuse ───────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_scan_works_without_langfuse(db_session, pro_user):
    """The scan still deletes the flagged row when ``get_langfuse`` returns None."""
    keep = _assoc_row(PRO_USER_ID, "Prefers dark mode")
    drop = _assoc_row(PRO_USER_ID, "Prefers light mode")
    for seeded in (keep, drop):
        db_session.add(seeded)
    await db_session.commit()

    verdict = [{"delete": drop.id, "reason": "contradicts dark mode"}]
    model = _mock_llm(json.dumps(verdict))

    with _patch_audit(model, lf=None, prompt_text="p {facts}"):
        await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)

    query = select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
    fetched = await db_session.execute(query)
    surviving_ids = {entry.id for entry in fetched.scalars().all()}
    assert keep.id in surviving_ids
    assert drop.id not in surviving_ids
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_canonicalize_works_without_langfuse(db_session, pro_user):
    """Canonicalization still rewrites labels when ``get_langfuse`` returns None."""
    variant_row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
    canonical_row = _relation_row(PRO_USER_ID, "Marco", "manages", "Giulia")
    db_session.add(variant_row)
    db_session.add(canonical_row)
    await db_session.commit()

    grouping = [{"canonical": "Giulia", "variants": ["giulia"]}]
    model = _mock_llm(json.dumps(grouping))

    with _patch_audit(model, lf=None, prompt_text="p {labels}"):
        await _canonicalize_relation_labels(db_session, PRO_USER_ID)

    await db_session.refresh(variant_row)
    assert variant_row.subject_label == "Giulia"
|
|
|
|
|
|
# ── Test 11: correct Langfuse prompt names used ───────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_scan_uses_correct_langfuse_prompt_name(db_session, pro_user):
    """The scan asks ``get_prompt_or_fallback`` for ``memory_audit_contradictions``."""
    for fact in ("Fact A", "Fact B"):
        db_session.add(_assoc_row(PRO_USER_ID, fact))
    await db_session.commit()

    model = _mock_llm("[]")
    # Stand-in that records how the prompt lookup is called.
    prompt_lookup = MagicMock(return_value=("p {facts}", None))

    with (
        patch("app.core.llm.get_agent_llm", return_value=model),
        patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
        patch("app.core.memory_maintenance.get_langfuse", return_value=None),
        patch("app.core.memory_maintenance.get_prompt_or_fallback", prompt_lookup),
        patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
    ):
        await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)

    prompt_lookup.assert_called_once()
    # The prompt name is the first positional argument of the lookup.
    assert prompt_lookup.call_args.args[0] == "memory_audit_contradictions"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_canonicalize_uses_correct_langfuse_prompt_name(db_session, pro_user):
    """Canonicalization asks ``get_prompt_or_fallback`` for ``memory_audit_canonicalize``."""
    seeded = (
        _relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme"),
        _relation_row(PRO_USER_ID, "Marco", "manages", "Acme"),
    )
    for relation in seeded:
        db_session.add(relation)
    await db_session.commit()

    model = _mock_llm("[]")
    # Stand-in that records how the prompt lookup is called.
    prompt_lookup = MagicMock(return_value=("p {labels}", None))

    with (
        patch("app.core.llm.get_agent_llm", return_value=model),
        patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
        patch("app.core.memory_maintenance.get_langfuse", return_value=None),
        patch("app.core.memory_maintenance.get_prompt_or_fallback", prompt_lookup),
        patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
    ):
        await _canonicalize_relation_labels(db_session, PRO_USER_ID)

    prompt_lookup.assert_called_once()
    # The prompt name is the first positional argument of the lookup.
    assert prompt_lookup.call_args.args[0] == "memory_audit_canonicalize"
|