Files
api/tests/test_memory_relations.py
2026-04-17 17:04:27 +02:00

221 lines
7.4 KiB
Python

"""Tests for Phase 3 — relational tier (Mem0g-light).
Coverage:
1. upsert_relation inserts a row and query_relations returns it
2. upsert_relation updates existing row on duplicate (subject/predicate/object)
3. tier gating: Free user gets empty list from query_relations + enrich_context
4. enrich_context includes relational_memory key for Pro user
5. decay_relations decays confidence and prunes rows below threshold
"""
from __future__ import annotations
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import patch
import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select
from app.core.memory_maintenance import decay_relations
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session
from app.main import app
from app.models import MemoryRelation, User
from tests.conftest import TEST_USER_IDS
# Ids of the pro-tier and free-tier test users, looked up in the shared
# conftest mapping.
PRO_USER_ID, FREE_USER_ID = TEST_USER_IDS["pro"], TEST_USER_IDS["free"]

# One Fernet key (as text) generated at import time; the user fixtures below
# assign it to ``User.encryption_key``.
_FERNET_KEY = str(Fernet.generate_key(), "utf-8")
# ── DB override ───────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Point the app's ``get_session`` dependency at the test session.

    Autouse, so it applies to every test in this module. The override is
    registered before the test runs and removed again afterwards.
    """

    async def _yield_test_session():
        # Async-generator dependency that hands out the fixture's session.
        yield db_session

    app.dependency_overrides[get_session] = _yield_test_session
    yield
    # pop() with a default: teardown must not raise if the key is already gone.
    app.dependency_overrides.pop(get_session, None)
@pytest_asyncio.fixture
async def pro_user_with_key(db_session):
    """Store the module's Fernet key on the pro test user and return the user.

    The user row is looked up by ``PRO_USER_ID``; ``scalar_one()`` requires
    that exactly one such row exists. The change is committed before the
    fixture value is handed to the test.
    """
    lookup = select(User).where(User.id == PRO_USER_ID)
    pro_user = (await db_session.execute(lookup)).scalar_one()
    pro_user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return pro_user
@pytest_asyncio.fixture
async def free_user_with_key(db_session):
    """Store the module's Fernet key on the free test user and return the user.

    The user row is looked up by ``FREE_USER_ID``; ``scalar_one()`` requires
    that exactly one such row exists. The change is committed before the
    fixture value is handed to the test.
    """
    lookup = select(User).where(User.id == FREE_USER_ID)
    free_user = (await db_session.execute(lookup)).scalar_one()
    free_user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return free_user
# ── Tests ─────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_upsert_relation_inserts_and_queries(db_session, pro_user_with_key):
    """A single upsert produces exactly one row that query_relations returns."""
    middleware = MemoryMiddleware(db_session)
    triple = {
        "subject": "Giulia",
        "subject_type": "person",
        "predicate": "works_at",
        "object_": "Acme Corp",
        "object_type": "company",
        "confidence": 0.9,
    }
    await middleware.upsert_relation(PRO_USER_ID, **triple)

    found = await middleware.query_relations(PRO_USER_ID, subject="Giulia")

    assert len(found) == 1
    relation = found[0]
    assert relation.subject_label == "Giulia"
    assert relation.predicate == "works_at"
    assert relation.object_label == "Acme Corp"
    # Confidence is compared with a small tolerance rather than exactly.
    assert abs(relation.confidence - 0.9) < 0.001
@pytest.mark.asyncio
async def test_upsert_relation_updates_on_duplicate(db_session, pro_user_with_key):
    """Upserting the same triple twice keeps one row carrying the newer confidence."""
    middleware = MemoryMiddleware(db_session)
    marco_triple = {
        "subject": "Marco",
        "subject_type": "person",
        "predicate": "stakeholder_of",
        "object_": "Project Nexus",
        "object_type": "project",
    }
    # Same triple both times; only the confidence differs between the calls.
    for confidence in (0.7, 0.95):
        await middleware.upsert_relation(
            PRO_USER_ID, confidence=confidence, **marco_triple
        )

    found = await middleware.query_relations(PRO_USER_ID, subject="Marco")

    # Two upserts must not have produced two rows.
    assert len(found) == 1
    relation = found[0]
    assert abs(relation.confidence - 0.95) < 0.001
    assert relation.last_confirmed_at is not None
@pytest.mark.asyncio
async def test_free_tier_relation_skipped(db_session, free_user_with_key):
    """Free user: upsert_relation is silently skipped (no row created).

    Checking ``query_relations`` alone cannot prove this, because the module's
    coverage notes say it returns an empty list for Free users anyway. The
    table is therefore also inspected directly, comparing the Free user's
    relation ids before and after the upsert.
    """

    async def _free_relation_ids():
        # Scoped by user_id only (not by label) so the check does not depend
        # on how labels are stored, and tolerates rows that already existed.
        stmt = select(MemoryRelation.id).where(MemoryRelation.user_id == FREE_USER_ID)
        return set((await db_session.execute(stmt)).scalars().all())

    ids_before = await _free_relation_ids()

    mm = MemoryMiddleware(db_session)
    await mm.upsert_relation(
        FREE_USER_ID,
        subject="Alice",
        subject_type="person",
        predicate="reports_to",
        object_="Bob",
        object_type="person",
        confidence=0.8,
    )

    rows = await mm.query_relations(FREE_USER_ID, subject="Alice")
    assert rows == []
    # The upsert itself must not have written anything for the Free user.
    assert await _free_relation_ids() == ids_before
@pytest.mark.asyncio
async def test_enrich_context_includes_relational_memory(db_session, pro_user_with_key):
    """For a Pro user, enrich_context exposes stored relations under relational_memory."""
    middleware = MemoryMiddleware(db_session)
    elena_triple = {
        "subject": "Elena",
        "subject_type": "person",
        "predicate": "cfo_of",
        "object_": "StartupXYZ",
        "object_type": "company",
        "confidence": 0.85,
    }
    await middleware.upsert_relation(PRO_USER_ID, **elena_triple)

    # _load_associative is stubbed to return [] while enrich_context runs.
    with patch.object(MemoryMiddleware, "_load_associative", return_value=[]):
        context = await middleware.enrich_context(PRO_USER_ID, "who is Elena?")

    assert "relational_memory" in context
    relational = context["relational_memory"]
    assert any("Elena" in entry for entry in relational)
@pytest.mark.asyncio
async def test_enrich_context_relational_empty_for_free(db_session, free_user_with_key):
    """For a Free user, enrich_context reports relational_memory as an empty list."""
    middleware = MemoryMiddleware(db_session)

    # _load_associative is stubbed to return [] while enrich_context runs.
    with patch.object(MemoryMiddleware, "_load_associative", return_value=[]):
        context = await middleware.enrich_context(FREE_USER_ID, "test message")

    # .get() so a missing key fails the comparison instead of raising KeyError.
    relational = context.get("relational_memory")
    assert relational == []
@pytest.mark.asyncio
async def test_decay_relations_reduces_confidence(db_session, pro_user_with_key):
    """decay_relations reduces confidence on stale rows.

    A relation last confirmed 35 days ago is inserted directly, decay is run
    for the Pro user, and the row is re-read from the database to confirm its
    confidence dropped below the starting value.
    """
    initial_confidence = 0.8
    old_date = datetime.now(timezone.utc) - timedelta(days=35)
    row = MemoryRelation(
        id=str(uuid.uuid4()),
        user_id=PRO_USER_ID,
        subject_label="OldContact",
        subject_type="person",
        predicate="knows",
        object_label="SomeProject",
        object_type="project",
        confidence=initial_confidence,
        last_confirmed_at=old_date,
    )
    db_session.add(row)
    await db_session.commit()

    await decay_relations(db_session, PRO_USER_ID)

    stmt = (
        select(MemoryRelation)
        .where(
            # Scope to this user so a same-label row owned by someone else
            # cannot make scalar_one_or_none() raise.
            MemoryRelation.user_id == PRO_USER_ID,
            MemoryRelation.subject_label == "OldContact",
        )
        # `row` is still in this session's identity map; force its attributes
        # to be reloaded so the assertion sees what is actually in the
        # database, however decay_relations performed the write.
        .execution_options(populate_existing=True)
    )
    result = await db_session.execute(stmt)
    updated = result.scalar_one_or_none()
    assert updated is not None
    assert updated.confidence < initial_confidence
@pytest.mark.asyncio
async def test_decay_relations_prunes_low_confidence(db_session, pro_user_with_key):
    """decay_relations deletes rows whose confidence drops below 0.2 threshold.

    A relation just above the threshold and last confirmed 65 days ago is
    inserted directly; after decay runs for the Pro user the row must be gone.
    """
    # Start at 0.21 with a 65-day-old last_confirmed_at → two decay periods
    # → 0.21 * 0.95^2 ≈ 0.19 → pruned.
    # NOTE(review): period length and the 0.95 factor live in decay_relations;
    # confirm they still match this arithmetic.
    old_date = datetime.now(timezone.utc) - timedelta(days=65)
    row = MemoryRelation(
        id=str(uuid.uuid4()),
        user_id=PRO_USER_ID,
        subject_label="ExpiredContact",
        subject_type="person",
        predicate="used_to_work_with",
        object_label="OldCorp",
        object_type="company",
        confidence=0.21,
        last_confirmed_at=old_date,
    )
    db_session.add(row)
    await db_session.commit()

    await decay_relations(db_session, PRO_USER_ID)

    result = await db_session.execute(
        select(MemoryRelation).where(
            # Scope to this user so a same-label row owned by someone else
            # cannot cause a false failure.
            MemoryRelation.user_id == PRO_USER_ID,
            MemoryRelation.subject_label == "ExpiredContact",
        )
    )
    pruned = result.scalar_one_or_none()
    assert pruned is None