"""Tests for Phase 2 — Mem0-style Extract/Update pipeline. Coverage: 2.1 extract_candidates returns valid ExtractionResult with mocked LLM. 2.2 decide_action — all 4 branches (ADD/UPDATE/DELETE/NOOP + empty existing). 2.3 run_extraction end-to-end with mocked LLM writes expected rows. 2.4 _dispatch_extraction — Pro user triggers realtime task; Free enqueues row. """ from __future__ import annotations import json import uuid from unittest.mock import AsyncMock, MagicMock, patch import pytest import pytest_asyncio from cryptography.fernet import Fernet from sqlalchemy import select from app.core.memory_extraction import ( ExtractionResult, MemoryCandidate, decide_action, extract_candidates, run_extraction, ) from app.core.memory_middleware import MemoryMiddleware from app.db import get_session from app.main import app from app.models import ExtractionQueue, MemoryCore, User from tests.conftest import TEST_USER_IDS PRO_USER_ID = TEST_USER_IDS["pro"] FREE_USER_ID = TEST_USER_IDS["free"] _FERNET_KEY = Fernet.generate_key().decode() # ── DB override ─────────────────────────────────────────────────────────────── @pytest.fixture(autouse=True) def _override_db(db_session): async def _gen(): yield db_session app.dependency_overrides[get_session] = _gen yield app.dependency_overrides.pop(get_session, None) # ── Helpers ─────────────────────────────────────────────────────────────────── @pytest_asyncio.fixture async def pro_user(db_session): """Update the seeded pro user to have an encryption_key.""" result = await db_session.execute(select(User).where(User.id == PRO_USER_ID)) user = result.scalar_one() user.encryption_key = _FERNET_KEY await db_session.commit() return user @pytest_asyncio.fixture async def free_user(db_session): """Update the seeded free user to have an encryption_key.""" result = await db_session.execute(select(User).where(User.id == FREE_USER_ID)) user = result.scalar_one() user.encryption_key = _FERNET_KEY await db_session.commit() return user def _make_llm_response(content: str) -> MagicMock: msg = MagicMock() msg.content = content msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} return msg # ── TASK 2.1 — extract_candidates ──────────────────────────────────────────── @pytest.mark.asyncio async def test_extract_candidates_returns_valid_result(): payload = { "candidates": [ { "type": "fact", "content": "User's CFO is Giulia", "target_tier": "core", "subject": None, "predicate": None, "object": None, "confidence": 0.85, } ] } mock_response = _make_llm_response(json.dumps(payload)) with ( patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm, patch("app.core.memory_extraction.get_langfuse", return_value=None), patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt, ): mock_prompt.return_value = ( "system prompt {last_turn} {core_memory} {recent_episodes}", None, ) llm_instance = MagicMock() llm_instance.bind.return_value = llm_instance llm_instance.ainvoke = AsyncMock(return_value=mock_response) mock_get_llm.return_value = llm_instance result = await extract_candidates( last_turn="User: My CFO is Giulia\nAssistant: Noted.", core_memory={}, recent_episodes=[], ) assert isinstance(result, ExtractionResult) assert len(result.candidates) == 1 assert result.candidates[0].type == "fact" assert "Giulia" in result.candidates[0].content assert result.candidates[0].confidence == 0.85 @pytest.mark.asyncio async def test_extract_candidates_returns_empty_on_llm_failure(): with ( patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm, patch("app.core.memory_extraction.get_langfuse", return_value=None), patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt, ): mock_prompt.return_value = ("prompt {last_turn} {core_memory} {recent_episodes}", None) llm_instance = MagicMock() llm_instance.bind.return_value = llm_instance llm_instance.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down")) mock_get_llm.return_value = llm_instance result = await extract_candidates("turn", {}, []) assert isinstance(result, ExtractionResult) assert result.candidates == [] # ── TASK 2.2 — decide_action ───────────────────────────────────────────────── @pytest.mark.asyncio async def test_decide_action_add_when_no_existing(): candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core") action = await decide_action(candidate, existing=[]) assert action == "ADD" @pytest.mark.asyncio async def test_decide_action_noop(): candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core") mock_response = _make_llm_response("NOOP") with ( patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm, patch("app.core.memory_extraction.get_langfuse", return_value=None), patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt, ): mock_prompt.return_value = ("p {candidate} {existing_memories}", None) llm_instance = MagicMock() llm_instance.ainvoke = AsyncMock(return_value=mock_response) mock_get_llm.return_value = llm_instance action = await decide_action(candidate, existing=["CFO is Giulia"]) assert action == "NOOP" @pytest.mark.asyncio async def test_decide_action_update(): candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core") mock_response = _make_llm_response("UPDATE") with ( patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm, patch("app.core.memory_extraction.get_langfuse", return_value=None), patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt, ): mock_prompt.return_value = ("p {candidate} {existing_memories}", None) llm_instance = MagicMock() llm_instance.ainvoke = AsyncMock(return_value=mock_response) mock_get_llm.return_value = llm_instance action = await decide_action(candidate, existing=["CFO is Giulia"]) assert action == "UPDATE" @pytest.mark.asyncio async def test_decide_action_delete(): candidate = MemoryCandidate(type="fact", content="No longer have a CFO", target_tier="core") mock_response = _make_llm_response("DELETE") with ( patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm, patch("app.core.memory_extraction.get_langfuse", return_value=None), patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt, ): mock_prompt.return_value = ("p {candidate} {existing_memories}", None) llm_instance = MagicMock() llm_instance.ainvoke = AsyncMock(return_value=mock_response) mock_get_llm.return_value = llm_instance action = await decide_action(candidate, existing=["CFO is Giulia"]) assert action == "DELETE" @pytest.mark.asyncio async def test_decide_action_defaults_add_on_llm_failure(): candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core") with ( patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm, patch("app.core.memory_extraction.get_langfuse", return_value=None), patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt, ): mock_prompt.return_value = ("p {candidate} {existing_memories}", None) llm_instance = MagicMock() llm_instance.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down")) mock_get_llm.return_value = llm_instance action = await decide_action(candidate, existing=["old memory"]) assert action == "ADD" # ── TASK 2.3 — run_extraction end-to-end ───────────────────────────────────── @pytest.mark.asyncio async def test_run_extraction_writes_core_candidate(db_session, pro_user): """'My CFO is Giulia' → fact candidate → core row written.""" fact_payload = { "candidates": [ { "type": "fact", "content": "User prefers morning meetings", "target_tier": "core", "confidence": 0.8, } ] } def _mock_llm_response(content: str): msg = MagicMock() msg.content = content msg.usage_metadata = {} return msg call_count = 0 async def _ainvoke_side_effect(messages): nonlocal call_count call_count += 1 if call_count == 1: # extract_candidates call return _mock_llm_response(json.dumps(fact_payload)) # decide_action — no existing → short-circuits to ADD without LLM return _mock_llm_response("ADD") with ( patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm, patch("app.core.memory_extraction.get_langfuse", return_value=None), patch( "app.core.memory_extraction.get_prompt_or_fallback", side_effect=lambda name, fb: ( ("p {last_turn} {core_memory} {recent_episodes}", None) if name == "memory_extraction" else ("p {candidate} {existing_memories}", None) ), ), ): llm_instance = MagicMock() llm_instance.bind.return_value = llm_instance llm_instance.ainvoke = AsyncMock(side_effect=_ainvoke_side_effect) mock_get_llm.return_value = llm_instance await run_extraction( db=db_session, user_id=PRO_USER_ID, last_user_msg="My CFO is Giulia", last_assistant_msg="Noted, I will remember that.", session_id="test-session", ) # core row should exist result = await db_session.execute( select(MemoryCore).where(MemoryCore.user_id == PRO_USER_ID) ) rows = result.scalars().all() assert len(rows) >= 1 fernet = Fernet(_FERNET_KEY.encode()) values = [fernet.decrypt(r.value_encrypted.encode()).decode() for r in rows] assert any("morning meetings" in v for v in values) # ── TASK 2.4 — dispatch ─────────────────────────────────────────────────────── @pytest.mark.asyncio async def test_dispatch_realtime_for_pro(db_session, pro_user): """Pro user: asyncio.create_task called (not queue row).""" middleware = MemoryMiddleware(db_session) with ( patch("app.core.memory_middleware.asyncio.create_task") as mock_task, patch("app.billing.tier_manager.tier_manager.check_feature", return_value=True), ): await middleware._dispatch_extraction( user_id=PRO_USER_ID, episode_id=str(uuid.uuid4()), last_user_msg="hello", last_assistant_msg="hi", session_id=None, ) mock_task.assert_called_once() @pytest.mark.asyncio async def test_dispatch_queue_for_free(db_session, free_user): """Free user: ExtractionQueue row inserted.""" middleware = MemoryMiddleware(db_session) ep_id = str(uuid.uuid4()) with patch("app.billing.tier_manager.tier_manager.check_feature", return_value=False): await middleware._dispatch_extraction( user_id=FREE_USER_ID, episode_id=ep_id, last_user_msg="hello", last_assistant_msg="hi", session_id=None, ) result = await db_session.execute( select(ExtractionQueue).where(ExtractionQueue.user_id == FREE_USER_ID) ) rows = result.scalars().all() assert len(rows) == 1 assert rows[0].episode_id == ep_id