"""Tests for Phase 3 — relational tier (Mem0g-light). Coverage: 1. upsert_relation inserts a row and query_relations returns it 2. upsert_relation updates existing row on duplicate (subject/predicate/object) 3. tier gating: Free user gets empty list from query_relations + enrich_context 4. enrich_context includes relational_memory key for Pro user 5. decay_relations decays confidence and prunes rows below threshold """ from __future__ import annotations import uuid from datetime import datetime, timedelta, timezone from unittest.mock import patch import pytest import pytest_asyncio from cryptography.fernet import Fernet from sqlalchemy import select from app.core.memory_maintenance import decay_relations from app.core.memory_middleware import MemoryMiddleware from app.db import get_session from app.main import app from app.models import MemoryRelation, User from tests.conftest import TEST_USER_IDS PRO_USER_ID = TEST_USER_IDS["pro"] FREE_USER_ID = TEST_USER_IDS["free"] _FERNET_KEY = Fernet.generate_key().decode() # ── DB override ─────────────────────────────────────────────────────────────── @pytest.fixture(autouse=True) def _override_db(db_session): async def _gen(): yield db_session app.dependency_overrides[get_session] = _gen yield app.dependency_overrides.pop(get_session, None) @pytest_asyncio.fixture async def pro_user_with_key(db_session): """Set encryption_key on the pro test user so Fernet works.""" result = await db_session.execute(select(User).where(User.id == PRO_USER_ID)) user = result.scalar_one() user.encryption_key = _FERNET_KEY await db_session.commit() return user @pytest_asyncio.fixture async def free_user_with_key(db_session): """Set encryption_key on the free test user.""" result = await db_session.execute(select(User).where(User.id == FREE_USER_ID)) user = result.scalar_one() user.encryption_key = _FERNET_KEY await db_session.commit() return user # ── Tests ───────────────────────────────────────────────────────────────────── 
@pytest.mark.asyncio
async def test_upsert_relation_inserts_and_queries(db_session, pro_user_with_key):
    """upsert_relation inserts a row; query_relations returns it."""
    mm = MemoryMiddleware(db_session)
    await mm.upsert_relation(
        PRO_USER_ID,
        subject="Giulia",
        subject_type="person",
        predicate="works_at",
        object_="Acme Corp",
        object_type="company",
        confidence=0.9,
    )

    rows = await mm.query_relations(PRO_USER_ID, subject="Giulia")

    assert len(rows) == 1
    relation = rows[0]
    assert relation.subject_label == "Giulia"
    assert relation.predicate == "works_at"
    assert relation.object_label == "Acme Corp"
    assert relation.confidence == pytest.approx(0.9, abs=1e-3)


@pytest.mark.asyncio
async def test_upsert_relation_updates_on_duplicate(db_session, pro_user_with_key):
    """A second upsert of the same triple updates the existing row.

    No duplicate row is inserted, confidence takes the newer value, and
    last_confirmed_at is populated.
    """
    mm = MemoryMiddleware(db_session)
    # Identical subject/predicate/object for both calls; only confidence differs.
    triple = {
        "subject": "Marco",
        "subject_type": "person",
        "predicate": "stakeholder_of",
        "object_": "Project Nexus",
        "object_type": "project",
    }
    await mm.upsert_relation(PRO_USER_ID, **triple, confidence=0.7)
    await mm.upsert_relation(PRO_USER_ID, **triple, confidence=0.95)

    rows = await mm.query_relations(PRO_USER_ID, subject="Marco")

    # Only one row despite two upserts
    assert len(rows) == 1
    assert rows[0].confidence == pytest.approx(0.95, abs=1e-3)
    assert rows[0].last_confirmed_at is not None


@pytest.mark.asyncio
async def test_free_tier_relation_skipped(db_session, free_user_with_key):
    """Free user: upsert_relation is silently skipped (no row created)."""
    mm = MemoryMiddleware(db_session)
    await mm.upsert_relation(
        FREE_USER_ID,
        subject="Alice",
        subject_type="person",
        predicate="reports_to",
        object_="Bob",
        object_type="person",
        confidence=0.8,
    )

    rows = await mm.query_relations(FREE_USER_ID, subject="Alice")

    assert rows == []


@pytest.mark.asyncio
async def test_enrich_context_includes_relational_memory(db_session, pro_user_with_key):
    """enrich_context includes relational_memory key for Pro user."""
    mm = MemoryMiddleware(db_session)
    await mm.upsert_relation(
        PRO_USER_ID,
        subject="Elena",
        subject_type="person",
        predicate="cfo_of",
        object_="StartupXYZ",
        object_type="company",
        confidence=0.85,
    )

    # Stub out the associative tier so only the relational tier feeds the context.
    # patch.object (rather than a dotted string) fails loudly if the method is renamed.
    with patch.object(MemoryMiddleware, "_load_associative", return_value=[]):
        ctx = await mm.enrich_context(PRO_USER_ID, "who is Elena?")

    assert "relational_memory" in ctx
    assert any("Elena" in r for r in ctx["relational_memory"])


@pytest.mark.asyncio
async def test_enrich_context_relational_empty_for_free(db_session, free_user_with_key):
    """Free user: relational_memory is empty list in enrich_context."""
    mm = MemoryMiddleware(db_session)

    with patch.object(MemoryMiddleware, "_load_associative", return_value=[]):
        ctx = await mm.enrich_context(FREE_USER_ID, "test message")

    assert ctx.get("relational_memory") == []


async def _fetch_pro_relation(db_session, subject_label):
    """Re-read the pro user's relation with *subject_label* from the database.

    ``populate_existing`` overwrites any copy already held in the session's
    identity map, so assertions see what decay_relations actually persisted
    even when the session does not expire objects on commit or the decay is
    done with a bulk UPDATE. Returns ``None`` when no such row exists.
    """
    result = await db_session.execute(
        select(MemoryRelation)
        .where(
            MemoryRelation.user_id == PRO_USER_ID,
            MemoryRelation.subject_label == subject_label,
        )
        .execution_options(populate_existing=True)
    )
    return result.scalar_one_or_none()


@pytest.mark.asyncio
async def test_decay_relations_reduces_confidence(db_session, pro_user_with_key):
    """decay_relations reduces confidence on stale rows."""
    old_date = datetime.now(timezone.utc) - timedelta(days=35)
    db_session.add(
        MemoryRelation(
            id=str(uuid.uuid4()),
            user_id=PRO_USER_ID,
            subject_label="OldContact",
            subject_type="person",
            predicate="knows",
            object_label="SomeProject",
            object_type="project",
            confidence=0.8,
            last_confirmed_at=old_date,
        )
    )
    await db_session.commit()

    await decay_relations(db_session, PRO_USER_ID)

    updated = await _fetch_pro_relation(db_session, "OldContact")
    assert updated is not None
    assert updated.confidence < 0.8


@pytest.mark.asyncio
async def test_decay_relations_prunes_low_confidence(db_session, pro_user_with_key):
    """decay_relations deletes rows whose confidence drops below 0.2 threshold."""
    # Start at 0.21 with a 65-day-old last_confirmed_at. Under the decay schedule
    # this test assumes (0.95 per full 30-day period — confirm against
    # memory_maintenance), that is two periods with a 5-day margin:
    # 0.21 * 0.95**2 ≈ 0.19 < 0.2 → pruned.
    old_date = datetime.now(timezone.utc) - timedelta(days=65)
    db_session.add(
        MemoryRelation(
            id=str(uuid.uuid4()),
            user_id=PRO_USER_ID,
            subject_label="ExpiredContact",
            subject_type="person",
            predicate="used_to_work_with",
            object_label="OldCorp",
            object_type="company",
            confidence=0.21,
            last_confirmed_at=old_date,
        )
    )
    await db_session.commit()

    await decay_relations(db_session, PRO_USER_ID)

    assert await _fetch_pro_relation(db_session, "ExpiredContact") is None