PHASE 3 — relational tier (Mem0g-light)
This commit is contained in:
220
tests/test_memory_relations.py
Normal file
220
tests/test_memory_relations.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Tests for Phase 3 — relational tier (Mem0g-light).
|
||||
|
||||
Coverage:
|
||||
1. upsert_relation inserts a row and query_relations returns it
|
||||
2. upsert_relation updates existing row on duplicate (subject/predicate/object)
|
||||
3. tier gating: Free user gets empty list from query_relations + enrich_context
|
||||
4. enrich_context includes relational_memory key for Pro user
|
||||
5. decay_relations decays confidence and prunes rows below threshold
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.memory_maintenance import decay_relations
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import MemoryRelation, User
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
PRO_USER_ID = TEST_USER_IDS["pro"]
|
||||
FREE_USER_ID = TEST_USER_IDS["free"]
|
||||
_FERNET_KEY = Fernet.generate_key().decode()
|
||||
|
||||
|
||||
# ── DB override ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _override_db(db_session):
|
||||
async def _gen():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = _gen
|
||||
yield
|
||||
app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def pro_user_with_key(db_session):
|
||||
"""Set encryption_key on the pro test user so Fernet works."""
|
||||
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
|
||||
user = result.scalar_one()
|
||||
user.encryption_key = _FERNET_KEY
|
||||
await db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def free_user_with_key(db_session):
|
||||
"""Set encryption_key on the free test user."""
|
||||
result = await db_session.execute(select(User).where(User.id == FREE_USER_ID))
|
||||
user = result.scalar_one()
|
||||
user.encryption_key = _FERNET_KEY
|
||||
await db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_relation_inserts_and_queries(db_session, pro_user_with_key):
|
||||
"""upsert_relation inserts a row; query_relations returns it."""
|
||||
mm = MemoryMiddleware(db_session)
|
||||
await mm.upsert_relation(
|
||||
PRO_USER_ID,
|
||||
subject="Giulia",
|
||||
subject_type="person",
|
||||
predicate="works_at",
|
||||
object_="Acme Corp",
|
||||
object_type="company",
|
||||
confidence=0.9,
|
||||
)
|
||||
rows = await mm.query_relations(PRO_USER_ID, subject="Giulia")
|
||||
assert len(rows) == 1
|
||||
assert rows[0].subject_label == "Giulia"
|
||||
assert rows[0].predicate == "works_at"
|
||||
assert rows[0].object_label == "Acme Corp"
|
||||
assert abs(rows[0].confidence - 0.9) < 0.001
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_relation_updates_on_duplicate(db_session, pro_user_with_key):
|
||||
"""Second upsert on same triple updates confidence and last_confirmed_at."""
|
||||
mm = MemoryMiddleware(db_session)
|
||||
await mm.upsert_relation(
|
||||
PRO_USER_ID,
|
||||
subject="Marco",
|
||||
subject_type="person",
|
||||
predicate="stakeholder_of",
|
||||
object_="Project Nexus",
|
||||
object_type="project",
|
||||
confidence=0.7,
|
||||
)
|
||||
await mm.upsert_relation(
|
||||
PRO_USER_ID,
|
||||
subject="Marco",
|
||||
subject_type="person",
|
||||
predicate="stakeholder_of",
|
||||
object_="Project Nexus",
|
||||
object_type="project",
|
||||
confidence=0.95,
|
||||
)
|
||||
rows = await mm.query_relations(PRO_USER_ID, subject="Marco")
|
||||
# Only one row despite two upserts
|
||||
assert len(rows) == 1
|
||||
assert abs(rows[0].confidence - 0.95) < 0.001
|
||||
assert rows[0].last_confirmed_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_tier_relation_skipped(db_session, free_user_with_key):
|
||||
"""Free user: upsert_relation is silently skipped (no row created)."""
|
||||
mm = MemoryMiddleware(db_session)
|
||||
await mm.upsert_relation(
|
||||
FREE_USER_ID,
|
||||
subject="Alice",
|
||||
subject_type="person",
|
||||
predicate="reports_to",
|
||||
object_="Bob",
|
||||
object_type="person",
|
||||
confidence=0.8,
|
||||
)
|
||||
rows = await mm.query_relations(FREE_USER_ID, subject="Alice")
|
||||
assert rows == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enrich_context_includes_relational_memory(db_session, pro_user_with_key):
|
||||
"""enrich_context includes relational_memory key for Pro user."""
|
||||
mm = MemoryMiddleware(db_session)
|
||||
await mm.upsert_relation(
|
||||
PRO_USER_ID,
|
||||
subject="Elena",
|
||||
subject_type="person",
|
||||
predicate="cfo_of",
|
||||
object_="StartupXYZ",
|
||||
object_type="company",
|
||||
confidence=0.85,
|
||||
)
|
||||
|
||||
with patch("app.core.memory_middleware.MemoryMiddleware._load_associative", return_value=[]):
|
||||
ctx = await mm.enrich_context(PRO_USER_ID, "who is Elena?")
|
||||
|
||||
assert "relational_memory" in ctx
|
||||
assert any("Elena" in r for r in ctx["relational_memory"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enrich_context_relational_empty_for_free(db_session, free_user_with_key):
|
||||
"""Free user: relational_memory is empty list in enrich_context."""
|
||||
mm = MemoryMiddleware(db_session)
|
||||
|
||||
with patch("app.core.memory_middleware.MemoryMiddleware._load_associative", return_value=[]):
|
||||
ctx = await mm.enrich_context(FREE_USER_ID, "test message")
|
||||
|
||||
assert ctx.get("relational_memory") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decay_relations_reduces_confidence(db_session, pro_user_with_key):
|
||||
"""decay_relations reduces confidence on stale rows."""
|
||||
old_date = datetime.now(timezone.utc) - timedelta(days=35)
|
||||
row = MemoryRelation(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=PRO_USER_ID,
|
||||
subject_label="OldContact",
|
||||
subject_type="person",
|
||||
predicate="knows",
|
||||
object_label="SomeProject",
|
||||
object_type="project",
|
||||
confidence=0.8,
|
||||
last_confirmed_at=old_date,
|
||||
)
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
|
||||
await decay_relations(db_session, PRO_USER_ID)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryRelation).where(MemoryRelation.subject_label == "OldContact")
|
||||
)
|
||||
updated = result.scalar_one_or_none()
|
||||
assert updated is not None
|
||||
assert updated.confidence < 0.8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decay_relations_prunes_low_confidence(db_session, pro_user_with_key):
|
||||
"""decay_relations deletes rows whose confidence drops below 0.2 threshold."""
|
||||
# Start at 0.21 with 60-day-old last_confirmed_at → two decay periods → 0.21 * 0.95^2 ≈ 0.19 → pruned
|
||||
old_date = datetime.now(timezone.utc) - timedelta(days=65)
|
||||
row = MemoryRelation(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=PRO_USER_ID,
|
||||
subject_label="ExpiredContact",
|
||||
subject_type="person",
|
||||
predicate="used_to_work_with",
|
||||
object_label="OldCorp",
|
||||
object_type="company",
|
||||
confidence=0.21,
|
||||
last_confirmed_at=old_date,
|
||||
)
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
|
||||
await decay_relations(db_session, PRO_USER_ID)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryRelation).where(MemoryRelation.subject_label == "ExpiredContact")
|
||||
)
|
||||
pruned = result.scalar_one_or_none()
|
||||
assert pruned is None
|
||||
Reference in New Issue
Block a user