Phase 7: audit memory
This commit is contained in:
@@ -61,6 +61,10 @@ LLM_MODEL_MEMORY_EXTRACTOR=
|
||||
# Defaults to gpt-4o-mini when empty.
|
||||
LLM_MODEL_MEMORY_MINER=
|
||||
|
||||
# Memory-auditor — weekly contradiction scan + relation label canonicalization (Phase 7).
|
||||
# Defaults to LLM_MODEL when empty (a reasoning-capable model is recommended).
|
||||
LLM_MODEL_MEMORY_AUDITOR=
|
||||
|
||||
# Scheduler — set to false to disable memory cron jobs (automatically false in tests).
|
||||
SCHEDULER_ENABLED=true
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ class Settings(BaseSettings):
|
||||
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
||||
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
|
||||
LLM_MODEL_MEMORY_AUDITOR: str = "" # memory-auditor (Phase 7 weekly audit)
|
||||
|
||||
# GitHub Copilot OAuth token storage directory.
|
||||
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||
|
||||
@@ -105,6 +105,7 @@ _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
||||
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
|
||||
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ All are safe to call manually or from tests; they never raise.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
@@ -19,7 +20,8 @@ from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import MemoryEpisodic, MemoryProactive, MemoryRelation, User
|
||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
|
||||
from app.models import MemoryAssociative, MemoryEpisodic, MemoryProactive, MemoryRelation, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,6 +39,10 @@ _PROACTIVE_PRUNE_THRESHOLD = 0.2
|
||||
_MIN_EPISODES_FOR_MINING = 3
|
||||
_MINING_LOOKBACK_DAYS = 30
|
||||
|
||||
# Audit: caps to control token cost
|
||||
_AUDIT_MAX_FACTS = 50
|
||||
_AUDIT_MAX_LABELS = 100
|
||||
|
||||
|
||||
async def decay_relations(db: AsyncSession, user_id: str) -> None:
|
||||
"""Apply confidence decay to all relation rows for a user.
|
||||
@@ -311,3 +317,265 @@ async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fern
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, exc)
|
||||
await db.rollback()
|
||||
|
||||
|
||||
# ── Phase 7: weekly memory audit ──────────────────────────────────────────────
|
||||
|
||||
_AUDIT_CONTRADICTIONS_FALLBACK = (
|
||||
"You are auditing a personal AI assistant's memory bank. "
|
||||
"Each fact has an ID in brackets. "
|
||||
"Find pairs that directly contradict each other "
|
||||
"(e.g. 'prefers morning meetings' vs 'never schedules before noon'). "
|
||||
"For each contradiction, pick the ID to DELETE (the older or less specific one). "
|
||||
'Return ONLY a valid JSON array, no markdown fences: '
|
||||
'[{{"delete": "<id>", "reason": "<one line>"}}]. '
|
||||
"If no contradictions, return [].\n\n"
|
||||
"Facts:\n{facts}"
|
||||
)
|
||||
|
||||
_AUDIT_CANONICALIZE_FALLBACK = (
|
||||
"You are auditing entity labels in a personal AI assistant's relational memory. "
|
||||
"These are names of people, companies, projects, or topics. "
|
||||
"Group labels that clearly refer to the same real-world entity "
|
||||
"(e.g. 'giulia', 'Giulia', 'Giulia R.' → canonical 'Giulia'). "
|
||||
"Return ONLY a valid JSON array, no markdown fences: "
|
||||
'[{{"canonical": "<best label>", "variants": ["<v1>", "<v2>"]}}]. '
|
||||
"Only include groups with at least one variant. Singletons: omit.\n\n"
|
||||
"Labels:\n{labels}"
|
||||
)
|
||||
|
||||
|
||||
async def audit_memory(db: AsyncSession, user_id: str) -> None:
|
||||
"""Weekly audit: contradiction scan on associative facts + label canonicalization on relations.
|
||||
|
||||
Steps:
|
||||
1. Decrypt up to _AUDIT_MAX_FACTS associative rows; send list to memory-auditor LLM.
|
||||
2. LLM flags rows to delete (direct contradictions); hard-delete them.
|
||||
3. Collect unique subject/object labels from memory_relations; ask LLM to group duplicates.
|
||||
4. Rewrite variant labels to their canonical form in-place.
|
||||
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
await _audit_memory_inner(db, user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: audit_memory failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _audit_memory_inner(db: AsyncSession, user_id: str) -> None:
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not user.encryption_key:
|
||||
logger.warning("memory_maintenance: audit_memory no encryption_key user=%s", user_id)
|
||||
return
|
||||
|
||||
fernet = Fernet(user.encryption_key.encode())
|
||||
await _scan_associative_contradictions(db, user_id, fernet)
|
||||
await _canonicalize_relation_labels(db, user_id)
|
||||
|
||||
|
||||
async def _scan_associative_contradictions(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
fernet: Fernet,
|
||||
) -> None:
|
||||
"""Decrypt associative facts, ask LLM to flag contradictions, delete superseded rows."""
|
||||
result = await db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(MemoryAssociative.user_id == user_id)
|
||||
.order_by(MemoryAssociative.updated_at.desc())
|
||||
.limit(_AUDIT_MAX_FACTS)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
if len(rows) < 2:
|
||||
return
|
||||
|
||||
id_to_text: dict[str, str] = {}
|
||||
for row in rows:
|
||||
try:
|
||||
plaintext = fernet.decrypt(row.content_encrypted.encode()).decode()
|
||||
id_to_text[row.id] = plaintext
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if len(id_to_text) < 2:
|
||||
return
|
||||
|
||||
id_list = list(id_to_text.keys())
|
||||
numbered = "\n".join(
|
||||
f"{i + 1}. [{rid}] {id_to_text[rid]}" for i, rid in enumerate(id_list)
|
||||
)
|
||||
|
||||
template, prompt_obj = get_prompt_or_fallback(
|
||||
"memory_audit_contradictions", _AUDIT_CONTRADICTIONS_FALLBACK
|
||||
)
|
||||
system_text = compile_prompt(template, prompt_obj, facts=numbered)
|
||||
|
||||
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
|
||||
llm = get_agent_llm("memory-auditor", temperature=0)
|
||||
lf = get_langfuse()
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Audit facts for contradictions."),
|
||||
]
|
||||
try:
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-audit-contradictions",
|
||||
model=model_for_agent("memory-auditor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
deletions = json.loads(text.strip())
|
||||
if not isinstance(deletions, list):
|
||||
return
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: _scan_associative_contradictions LLM/parse failed user=%s: %s",
|
||||
user_id, exc,
|
||||
)
|
||||
return
|
||||
|
||||
deleted = 0
|
||||
for item in deletions:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
rid = item.get("delete")
|
||||
if not rid or rid not in id_to_text:
|
||||
continue
|
||||
result2 = await db.execute(
|
||||
select(MemoryAssociative).where(
|
||||
MemoryAssociative.id == rid,
|
||||
MemoryAssociative.user_id == user_id,
|
||||
)
|
||||
)
|
||||
target = result2.scalar_one_or_none()
|
||||
if target:
|
||||
await db.delete(target)
|
||||
deleted += 1
|
||||
logger.info(
|
||||
"memory_maintenance: audit deleted contradiction id=%s user=%s reason=%s",
|
||||
rid, user_id, item.get("reason", ""),
|
||||
)
|
||||
|
||||
if deleted:
|
||||
try:
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: audit contradiction commit failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await db.rollback()
|
||||
|
||||
logger.info(
|
||||
"memory_maintenance: _scan_associative_contradictions user=%s deleted=%d", user_id, deleted
|
||||
)
|
||||
|
||||
|
||||
async def _canonicalize_relation_labels(db: AsyncSession, user_id: str) -> None:
|
||||
"""Group near-duplicate entity labels in memory_relations and unify to canonical form."""
|
||||
result = await db.execute(
|
||||
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
if not rows:
|
||||
return
|
||||
|
||||
all_labels: set[str] = set()
|
||||
for row in rows:
|
||||
all_labels.add(row.subject_label)
|
||||
all_labels.add(row.object_label)
|
||||
|
||||
labels_list = sorted(all_labels)[:_AUDIT_MAX_LABELS]
|
||||
if len(labels_list) < 2:
|
||||
return
|
||||
|
||||
labels_block = "\n".join(f"- {lbl}" for lbl in labels_list)
|
||||
template, prompt_obj = get_prompt_or_fallback(
|
||||
"memory_audit_canonicalize", _AUDIT_CANONICALIZE_FALLBACK
|
||||
)
|
||||
system_text = compile_prompt(template, prompt_obj, labels=labels_block)
|
||||
|
||||
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
|
||||
llm = get_agent_llm("memory-auditor", temperature=0)
|
||||
lf = get_langfuse()
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Canonicalize entity labels."),
|
||||
]
|
||||
try:
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-audit-canonicalize",
|
||||
model=model_for_agent("memory-auditor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
groups = json.loads(text.strip())
|
||||
if not isinstance(groups, list):
|
||||
return
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: _canonicalize_relation_labels LLM/parse failed user=%s: %s",
|
||||
user_id, exc,
|
||||
)
|
||||
return
|
||||
|
||||
# Build variant → canonical map
|
||||
remap: dict[str, str] = {}
|
||||
for group in groups:
|
||||
if not isinstance(group, dict):
|
||||
continue
|
||||
canonical = group.get("canonical", "")
|
||||
variants = group.get("variants") or []
|
||||
if not canonical:
|
||||
continue
|
||||
for v in variants:
|
||||
if isinstance(v, str) and v != canonical:
|
||||
remap[v] = canonical
|
||||
|
||||
if not remap:
|
||||
return
|
||||
|
||||
updated = 0
|
||||
for row in rows:
|
||||
changed = False
|
||||
if row.subject_label in remap:
|
||||
row.subject_label = remap[row.subject_label]
|
||||
changed = True
|
||||
if row.object_label in remap:
|
||||
row.object_label = remap[row.object_label]
|
||||
changed = True
|
||||
if changed:
|
||||
updated += 1
|
||||
|
||||
if updated:
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"memory_maintenance: _canonicalize_relation_labels user=%s updated=%d",
|
||||
user_id, updated,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: canonicalize commit failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await db.rollback()
|
||||
|
||||
28
app/main.py
28
app/main.py
@@ -16,6 +16,33 @@ from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
async def _memory_audit_cron_tick() -> None:
|
||||
"""Weekly cron: contradiction scan + label canonicalization for all users (Phase 7)."""
|
||||
import logging # noqa: PLC0415
|
||||
_log = logging.getLogger(__name__)
|
||||
_log.info("memory audit cron tick: starting")
|
||||
try:
|
||||
from app.db import async_session # noqa: PLC0415
|
||||
from app.core.memory_maintenance import audit_memory # noqa: PLC0415
|
||||
from app.models import User # noqa: PLC0415
|
||||
from sqlalchemy import select # noqa: PLC0415
|
||||
|
||||
async with async_session() as db:
|
||||
result = await db.execute(select(User.id))
|
||||
user_ids: list[str] = list(result.scalars().all())
|
||||
|
||||
for uid in user_ids:
|
||||
try:
|
||||
async with async_session() as db:
|
||||
await audit_memory(db, uid)
|
||||
except Exception as exc:
|
||||
_log.warning("memory audit cron tick: audit_memory failed user=%s: %s", uid, exc)
|
||||
|
||||
_log.info("memory audit cron tick: done users=%d", len(user_ids))
|
||||
except Exception as exc:
|
||||
_log.warning("memory audit cron tick: failed: %s", exc)
|
||||
|
||||
|
||||
async def _memory_cron_tick() -> None:
|
||||
"""Hourly cron: drain Free-tier extraction queue + mine proactive patterns for Power+ users."""
|
||||
import logging # noqa: PLC0415
|
||||
@@ -61,6 +88,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
|
||||
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
|
||||
scheduler.start()
|
||||
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
|
||||
|
||||
|
||||
405
tests/test_memory_audit.py
Normal file
405
tests/test_memory_audit.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""Tests for Phase 7 — weekly audit_memory job.
|
||||
|
||||
Coverage:
|
||||
1. audit_memory never raises even if inner work fails.
|
||||
2. _scan_associative_contradictions skips when < 2 decryptable facts.
|
||||
3. _scan_associative_contradictions calls LLM and deletes flagged rows.
|
||||
4. _scan_associative_contradictions is a no-op when LLM fails.
|
||||
5. _scan_associative_contradictions is a no-op when LLM returns non-list.
|
||||
6. _canonicalize_relation_labels skips when no relation rows.
|
||||
7. _canonicalize_relation_labels rewrites variant labels to canonical form.
|
||||
8. _canonicalize_relation_labels is a no-op when LLM fails.
|
||||
9. _canonicalize_relation_labels is a no-op when remap is empty.
|
||||
10. Both helpers work correctly when Langfuse is unavailable (lf=None).
|
||||
11. get_prompt_or_fallback called with correct Langfuse prompt names.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from contextlib import contextmanager, ExitStack
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.memory_maintenance import (
|
||||
_canonicalize_relation_labels,
|
||||
_scan_associative_contradictions,
|
||||
audit_memory,
|
||||
)
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import MemoryAssociative, MemoryRelation, User
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
PRO_USER_ID = TEST_USER_IDS["pro"]
|
||||
_FERNET_KEY = Fernet.generate_key().decode()
|
||||
_FERNET = Fernet(_FERNET_KEY.encode())
|
||||
|
||||
|
||||
# ── DB override ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _override_db(db_session):
|
||||
async def _gen():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = _gen
|
||||
yield
|
||||
app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def pro_user(db_session):
|
||||
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
|
||||
user = result.scalar_one()
|
||||
user.encryption_key = _FERNET_KEY
|
||||
await db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def _enc(text: str) -> str:
|
||||
return _FERNET.encrypt(text.encode()).decode()
|
||||
|
||||
|
||||
def _assoc_row(user_id: str, text: str) -> MemoryAssociative:
|
||||
return MemoryAssociative(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
content_encrypted=_enc(text),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _relation_row(user_id: str, subject: str, predicate: str, obj: str) -> MemoryRelation:
|
||||
return MemoryRelation(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
subject_label=subject,
|
||||
subject_type="person",
|
||||
predicate=predicate,
|
||||
object_label=obj,
|
||||
object_type="company",
|
||||
confidence=0.8,
|
||||
)
|
||||
|
||||
|
||||
def _llm_response(content: str) -> MagicMock:
|
||||
msg = MagicMock()
|
||||
msg.content = content
|
||||
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
return msg
|
||||
|
||||
|
||||
def _mock_llm(content: str) -> MagicMock:
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(return_value=_llm_response(content))
|
||||
return llm
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_audit(llm_mock, lf=None, prompt_text: str = "fallback {facts}"):
|
||||
"""Context manager that patches all external deps for audit helpers."""
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch("app.core.llm.get_agent_llm", return_value=llm_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("app.core.llm.model_for_agent", return_value="memory-auditor")
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("app.core.memory_maintenance.get_langfuse", return_value=lf)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"app.core.memory_maintenance.get_prompt_or_fallback",
|
||||
return_value=(prompt_text, None),
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"app.core.memory_maintenance.compile_prompt",
|
||||
side_effect=lambda tmpl, obj, **kw: tmpl.format(**kw) if "{" in tmpl else tmpl,
|
||||
)
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
# ── Test 1: audit_memory never raises ────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_memory_never_raises_on_missing_user(db_session):
|
||||
"""audit_memory with a non-existent user_id must not raise."""
|
||||
await audit_memory(db_session, str(uuid.uuid4()))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_memory_never_raises_on_llm_failure(db_session, pro_user):
|
||||
"""audit_memory must swallow inner exceptions."""
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
|
||||
with (
|
||||
patch("app.core.llm.get_agent_llm", return_value=llm),
|
||||
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
|
||||
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
|
||||
patch(
|
||||
"app.core.memory_maintenance.get_prompt_or_fallback",
|
||||
return_value=("p {facts}", None),
|
||||
),
|
||||
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
|
||||
):
|
||||
await audit_memory(db_session, PRO_USER_ID)
|
||||
|
||||
|
||||
# ── Test 2: _scan skips when < 2 facts ───────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_contradictions_skips_with_one_fact(db_session, pro_user):
|
||||
row = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(return_value=_llm_response("[]"))
|
||||
|
||||
with _patch_audit(llm):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
llm.ainvoke.assert_not_called()
|
||||
|
||||
|
||||
# ── Test 3: _scan deletes flagged contradiction ───────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_contradictions_deletes_flagged_row(db_session, pro_user):
|
||||
keep = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
|
||||
drop = _assoc_row(PRO_USER_ID, "Never schedules before noon")
|
||||
db_session.add(keep)
|
||||
db_session.add(drop)
|
||||
await db_session.commit()
|
||||
|
||||
deletion_payload = json.dumps([{"delete": drop.id, "reason": "contradicts morning pref"}])
|
||||
llm = _mock_llm(deletion_payload)
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {facts}"):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||
)
|
||||
remaining = result.scalars().all()
|
||||
remaining_ids = {r.id for r in remaining}
|
||||
assert keep.id in remaining_ids
|
||||
assert drop.id not in remaining_ids
|
||||
|
||||
|
||||
# ── Test 4: _scan is no-op on LLM failure ────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_contradictions_noop_on_llm_failure(db_session, pro_user):
|
||||
for text in ("Fact A", "Fact B"):
|
||||
db_session.add(_assoc_row(PRO_USER_ID, text))
|
||||
await db_session.commit()
|
||||
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {facts}"):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||
)
|
||||
assert len(result.scalars().all()) == 2
|
||||
|
||||
|
||||
# ── Test 5: _scan is no-op when LLM returns non-list ─────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_contradictions_noop_on_non_list_response(db_session, pro_user):
|
||||
for text in ("Fact A", "Fact B"):
|
||||
db_session.add(_assoc_row(PRO_USER_ID, text))
|
||||
await db_session.commit()
|
||||
|
||||
llm = _mock_llm('"unexpected string"')
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {facts}"):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||
)
|
||||
assert len(result.scalars().all()) == 2
|
||||
|
||||
|
||||
# ── Test 6: _canonicalize skips when no relations ────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_skips_when_no_relations(db_session, pro_user):
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(return_value=_llm_response("[]"))
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
llm.ainvoke.assert_not_called()
|
||||
|
||||
|
||||
# ── Test 7: _canonicalize rewrites variant labels ────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_rewrites_variant_labels(db_session, pro_user):
|
||||
row_a = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
|
||||
row_b = _relation_row(PRO_USER_ID, "Giulia R.", "reports_to", "Marco")
|
||||
row_c = _relation_row(PRO_USER_ID, "Marco", "manages", "Giulia")
|
||||
db_session.add(row_a)
|
||||
db_session.add(row_b)
|
||||
db_session.add(row_c)
|
||||
await db_session.commit()
|
||||
|
||||
groups = json.dumps([
|
||||
{"canonical": "Giulia", "variants": ["giulia", "Giulia R."]}
|
||||
])
|
||||
llm = _mock_llm(groups)
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
await db_session.refresh(row_a)
|
||||
await db_session.refresh(row_b)
|
||||
await db_session.refresh(row_c)
|
||||
|
||||
assert row_a.subject_label == "Giulia"
|
||||
assert row_b.subject_label == "Giulia"
|
||||
assert row_c.object_label == "Giulia"
|
||||
assert row_c.subject_label == "Marco"
|
||||
|
||||
|
||||
# ── Test 8: _canonicalize is no-op on LLM failure ────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_noop_on_llm_failure(db_session, pro_user):
|
||||
row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
await db_session.refresh(row)
|
||||
assert row.subject_label == "giulia"
|
||||
|
||||
|
||||
# ── Test 9: _canonicalize is no-op when remap is empty ───────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_noop_when_remap_empty(db_session, pro_user):
|
||||
row = _relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme")
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
|
||||
llm = _mock_llm("[]")
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
await db_session.refresh(row)
|
||||
assert row.subject_label == "Giulia"
|
||||
|
||||
|
||||
# ── Test 10: both helpers work without Langfuse ───────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_works_without_langfuse(db_session, pro_user):
|
||||
keep = _assoc_row(PRO_USER_ID, "Prefers dark mode")
|
||||
drop = _assoc_row(PRO_USER_ID, "Prefers light mode")
|
||||
db_session.add(keep)
|
||||
db_session.add(drop)
|
||||
await db_session.commit()
|
||||
|
||||
deletion_payload = json.dumps([{"delete": drop.id, "reason": "contradicts dark mode"}])
|
||||
llm = _mock_llm(deletion_payload)
|
||||
|
||||
with _patch_audit(llm, lf=None, prompt_text="p {facts}"):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||
)
|
||||
remaining_ids = {r.id for r in result.scalars().all()}
|
||||
assert keep.id in remaining_ids
|
||||
assert drop.id not in remaining_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_works_without_langfuse(db_session, pro_user):
|
||||
row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
|
||||
db_session.add(row)
|
||||
db_session.add(_relation_row(PRO_USER_ID, "Marco", "manages", "Giulia"))
|
||||
await db_session.commit()
|
||||
|
||||
groups = json.dumps([{"canonical": "Giulia", "variants": ["giulia"]}])
|
||||
llm = _mock_llm(groups)
|
||||
|
||||
with _patch_audit(llm, lf=None, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
await db_session.refresh(row)
|
||||
assert row.subject_label == "Giulia"
|
||||
|
||||
|
||||
# ── Test 11: correct Langfuse prompt names used ───────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_uses_correct_langfuse_prompt_name(db_session, pro_user):
|
||||
for text in ("Fact A", "Fact B"):
|
||||
db_session.add(_assoc_row(PRO_USER_ID, text))
|
||||
await db_session.commit()
|
||||
|
||||
llm = _mock_llm("[]")
|
||||
mock_get_prompt = MagicMock(return_value=("p {facts}", None))
|
||||
|
||||
with (
|
||||
patch("app.core.llm.get_agent_llm", return_value=llm),
|
||||
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
|
||||
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_maintenance.get_prompt_or_fallback", mock_get_prompt),
|
||||
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
|
||||
):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
mock_get_prompt.assert_called_once()
|
||||
assert mock_get_prompt.call_args[0][0] == "memory_audit_contradictions"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_uses_correct_langfuse_prompt_name(db_session, pro_user):
|
||||
db_session.add(_relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme"))
|
||||
db_session.add(_relation_row(PRO_USER_ID, "Marco", "manages", "Acme"))
|
||||
await db_session.commit()
|
||||
|
||||
llm = _mock_llm("[]")
|
||||
mock_get_prompt = MagicMock(return_value=("p {labels}", None))
|
||||
|
||||
with (
|
||||
patch("app.core.llm.get_agent_llm", return_value=llm),
|
||||
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
|
||||
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_maintenance.get_prompt_or_fallback", mock_get_prompt),
|
||||
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
|
||||
):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
mock_get_prompt.assert_called_once()
|
||||
assert mock_get_prompt.call_args[0][0] == "memory_audit_canonicalize"
|
||||
Reference in New Issue
Block a user