221 lines
7.4 KiB
Python
221 lines
7.4 KiB
Python
"""Tests for Phase 3 — relational tier (Mem0g-light).
|
|
|
|
Coverage:
|
|
1. upsert_relation inserts a row and query_relations returns it
|
|
2. upsert_relation updates existing row on duplicate (subject/predicate/object)
|
|
3. tier gating: Free user gets empty list from query_relations + enrich_context
|
|
4. enrich_context includes relational_memory key for Pro user
|
|
5. decay_relations decays confidence and prunes rows below threshold
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from cryptography.fernet import Fernet
|
|
from sqlalchemy import select
|
|
|
|
from app.core.memory_maintenance import decay_relations
|
|
from app.core.memory_middleware import MemoryMiddleware
|
|
from app.db import get_session
|
|
from app.main import app
|
|
from app.models import MemoryRelation, User
|
|
from tests.conftest import TEST_USER_IDS
|
|
|
|
# Test-user IDs provided by tests.conftest, looked up by tier key.
PRO_USER_ID, FREE_USER_ID = (TEST_USER_IDS[tier] for tier in ("pro", "free"))

# One random Fernet key per import of this module. generate_key() returns
# bytes; it is decoded to str so it can be assigned to User.encryption_key.
_FERNET_KEY = Fernet.generate_key().decode("utf-8")
|
|
|
|
|
|
# ── DB override ───────────────────────────────────────────────────────────────
|
|
|
|
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Point the app's ``get_session`` dependency at the per-test ``db_session``.

    Applied to every test in this module (autouse). The override is installed
    before the test body runs and removed again during fixture teardown.
    """

    async def _yield_test_session():
        # Every dependency resolution during the test gets the same session.
        yield db_session

    app.dependency_overrides[get_session] = _yield_test_session
    yield
    # Teardown: drop the override (no error if it is already gone).
    app.dependency_overrides.pop(get_session, None)
|
|
|
|
|
|
async def _attach_fernet_key(db_session, user_id):
    """Store the module's shared Fernet key on an existing test user and commit.

    Args:
        db_session: The async session supplied by the ``db_session`` fixture.
        user_id: Primary key of the ``User`` row to update.

    Returns:
        The updated ``User`` instance.

    Raises:
        Whatever ``Result.scalar_one`` raises when the user row is missing
        (or duplicated), so a mis-seeded database fails loudly.
    """
    result = await db_session.execute(select(User).where(User.id == user_id))
    user = result.scalar_one()
    user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return user


@pytest_asyncio.fixture
async def pro_user_with_key(db_session):
    """Set encryption_key on the pro test user so Fernet works."""
    return await _attach_fernet_key(db_session, PRO_USER_ID)


@pytest_asyncio.fixture
async def free_user_with_key(db_session):
    """Set encryption_key on the free test user."""
    return await _attach_fernet_key(db_session, FREE_USER_ID)
|
|
|
|
|
|
# ── Tests ─────────────────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_upsert_relation_inserts_and_queries(db_session, pro_user_with_key):
    """A freshly upserted relation is returned by query_relations for its subject."""
    relation = {
        "subject": "Giulia",
        "subject_type": "person",
        "predicate": "works_at",
        "object_": "Acme Corp",
        "object_type": "company",
        "confidence": 0.9,
    }
    middleware = MemoryMiddleware(db_session)
    await middleware.upsert_relation(PRO_USER_ID, **relation)

    found = await middleware.query_relations(PRO_USER_ID, subject="Giulia")

    assert len(found) == 1
    row = found[0]
    # The stored rows expose subject/object text as *_label attributes.
    assert row.subject_label == "Giulia"
    assert row.predicate == "works_at"
    assert row.object_label == "Acme Corp"
    assert abs(row.confidence - 0.9) < 0.001
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_upsert_relation_updates_on_duplicate(db_session, pro_user_with_key):
    """Upserting the same (subject, predicate, object) triple twice keeps one row.

    The second call's confidence (0.95) must replace the first (0.7), and the
    row must carry a non-null last_confirmed_at.
    """
    middleware = MemoryMiddleware(db_session)

    # Same triple both times; only the confidence changes between calls.
    for confidence in (0.7, 0.95):
        await middleware.upsert_relation(
            PRO_USER_ID,
            subject="Marco",
            subject_type="person",
            predicate="stakeholder_of",
            object_="Project Nexus",
            object_type="project",
            confidence=confidence,
        )

    found = await middleware.query_relations(PRO_USER_ID, subject="Marco")

    # Two upserts of one triple must collapse into a single row.
    assert len(found) == 1
    only = found[0]
    assert abs(only.confidence - 0.95) < 0.001
    assert only.last_confirmed_at is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_free_tier_relation_skipped(db_session, free_user_with_key):
    """Free user: upsert_relation is silently skipped (no row created).

    ``query_relations`` is itself tier-gated for Free users, so an empty result
    from it cannot prove on its own that no row was written. The table is
    therefore also queried directly, bypassing the middleware.
    """
    mm = MemoryMiddleware(db_session)

    await mm.upsert_relation(
        FREE_USER_ID,
        subject="Alice",
        subject_type="person",
        predicate="reports_to",
        object_="Bob",
        object_type="person",
        confidence=0.8,
    )

    # Gated read path: must be empty for a Free user.
    rows = await mm.query_relations(FREE_USER_ID, subject="Alice")
    assert rows == []

    # Direct table check: the upsert must not have persisted (or left pending,
    # which autoflush would surface) any row for this user.
    result = await db_session.execute(
        select(MemoryRelation).where(
            MemoryRelation.user_id == FREE_USER_ID,
            MemoryRelation.subject_label == "Alice",
        )
    )
    assert result.scalars().all() == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_includes_relational_memory(db_session, pro_user_with_key):
    """For a Pro user, enrich_context surfaces stored relations under relational_memory."""
    middleware = MemoryMiddleware(db_session)
    await middleware.upsert_relation(
        PRO_USER_ID,
        subject="Elena",
        subject_type="person",
        predicate="cfo_of",
        object_="StartupXYZ",
        object_type="company",
        confidence=0.85,
    )

    # Stub _load_associative to return [] so the assertions below do not
    # depend on it. patch.object picks an AsyncMock if the target is async.
    with patch.object(MemoryMiddleware, "_load_associative", return_value=[]):
        ctx = await middleware.enrich_context(PRO_USER_ID, "who is Elena?")

    assert "relational_memory" in ctx
    relational = ctx["relational_memory"]
    assert any("Elena" in entry for entry in relational)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_relational_empty_for_free(db_session, free_user_with_key):
    """Free user: relational_memory is empty list in enrich_context.

    A relation row is inserted directly, bypassing the tier-gated
    ``upsert_relation``. An empty ``relational_memory`` then demonstrates tier
    gating rather than merely an empty table. The subject ("Carla") is unique
    to this test so it cannot collide with other tests' data.
    """
    db_session.add(
        MemoryRelation(
            id=str(uuid.uuid4()),
            user_id=FREE_USER_ID,
            subject_label="Carla",
            subject_type="person",
            predicate="manages",
            object_label="Team Orion",
            object_type="project",
            confidence=0.8,
            last_confirmed_at=datetime.now(timezone.utc),
        )
    )
    await db_session.commit()

    mm = MemoryMiddleware(db_session)

    # The message names the seeded subject, so an ungated implementation would
    # have a reason to return the relation.
    with patch("app.core.memory_middleware.MemoryMiddleware._load_associative", return_value=[]):
        ctx = await mm.enrich_context(FREE_USER_ID, "who is Carla?")

    assert ctx.get("relational_memory") == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_decay_relations_reduces_confidence(db_session, pro_user_with_key):
    """decay_relations reduces confidence on stale rows.

    The row is seeded with confidence 0.8 and last_confirmed_at 35 days ago.
    The test expects that age to be stale enough for decay_relations to lower
    the confidence without pruning the row.
    """
    old_date = datetime.now(timezone.utc) - timedelta(days=35)
    row = MemoryRelation(
        id=str(uuid.uuid4()),
        user_id=PRO_USER_ID,
        subject_label="OldContact",
        subject_type="person",
        predicate="knows",
        object_label="SomeProject",
        object_type="project",
        confidence=0.8,
        last_confirmed_at=old_date,
    )
    db_session.add(row)
    await db_session.commit()

    await decay_relations(db_session, PRO_USER_ID)

    # ``row`` is still in this session's identity map. A plain SELECT hands
    # back that same object without overwriting its already-loaded attributes,
    # so a decay applied via a bulk/Core UPDATE could go unnoticed and the
    # assertion would read the stale 0.8. populate_existing refreshes the
    # object from the fetched row.
    result = await db_session.execute(
        select(MemoryRelation)
        .where(
            MemoryRelation.user_id == PRO_USER_ID,
            MemoryRelation.subject_label == "OldContact",
        )
        .execution_options(populate_existing=True)
    )
    updated = result.scalar_one_or_none()
    assert updated is not None
    assert updated.confidence < 0.8
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_decay_relations_prunes_low_confidence(db_session, pro_user_with_key):
    """decay_relations deletes rows whose confidence drops below 0.2 threshold."""
    # Start at 0.21 with last_confirmed_at 65 days ago. Assuming 30-day decay
    # periods, that is two full periods plus a few days of margin:
    # 0.21 * 0.95^2 ≈ 0.19 < 0.2, so the row should be pruned.
    old_date = datetime.now(timezone.utc) - timedelta(days=65)
    row = MemoryRelation(
        id=str(uuid.uuid4()),
        user_id=PRO_USER_ID,
        subject_label="ExpiredContact",
        subject_type="person",
        predicate="used_to_work_with",
        object_label="OldCorp",
        object_type="company",
        confidence=0.21,
        last_confirmed_at=old_date,
    )
    db_session.add(row)
    await db_session.commit()

    await decay_relations(db_session, PRO_USER_ID)

    # Query the database directly, scoped to this user, and refresh any
    # identity-mapped instance rather than trusting in-session state.
    result = await db_session.execute(
        select(MemoryRelation)
        .where(
            MemoryRelation.user_id == PRO_USER_ID,
            MemoryRelation.subject_label == "ExpiredContact",
        )
        .execution_options(populate_existing=True)
    )
    pruned = result.scalar_one_or_none()
    assert pruned is None
|