PHASE 2 — Mem0-style Extract/Update pipeline
This commit is contained in:
345
tests/test_memory_extraction.py
Normal file
345
tests/test_memory_extraction.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""Tests for Phase 2 — Mem0-style Extract/Update pipeline.
|
||||
|
||||
Coverage:
|
||||
2.1 extract_candidates returns valid ExtractionResult with mocked LLM.
|
||||
2.2 decide_action — all 4 branches (ADD/UPDATE/DELETE/NOOP + empty existing).
|
||||
2.3 run_extraction end-to-end with mocked LLM writes expected rows.
|
||||
2.4 _dispatch_extraction — Pro user triggers realtime task; Free enqueues row.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.memory_extraction import (
|
||||
ExtractionResult,
|
||||
MemoryCandidate,
|
||||
decide_action,
|
||||
extract_candidates,
|
||||
run_extraction,
|
||||
)
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import ExtractionQueue, MemoryCore, User
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
|
||||
PRO_USER_ID = TEST_USER_IDS["pro"]
|
||||
FREE_USER_ID = TEST_USER_IDS["free"]
|
||||
_FERNET_KEY = Fernet.generate_key().decode()
|
||||
|
||||
|
||||
# ── DB override ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _override_db(db_session):
|
||||
async def _gen():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = _gen
|
||||
yield
|
||||
app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def pro_user(db_session):
|
||||
"""Update the seeded pro user to have an encryption_key."""
|
||||
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
|
||||
user = result.scalar_one()
|
||||
user.encryption_key = _FERNET_KEY
|
||||
await db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def free_user(db_session):
|
||||
"""Update the seeded free user to have an encryption_key."""
|
||||
result = await db_session.execute(select(User).where(User.id == FREE_USER_ID))
|
||||
user = result.scalar_one()
|
||||
user.encryption_key = _FERNET_KEY
|
||||
await db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def _make_llm_response(content: str) -> MagicMock:
|
||||
msg = MagicMock()
|
||||
msg.content = content
|
||||
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
return msg
|
||||
|
||||
|
||||
# ── TASK 2.1 — extract_candidates ────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_candidates_returns_valid_result():
|
||||
payload = {
|
||||
"candidates": [
|
||||
{
|
||||
"type": "fact",
|
||||
"content": "User's CFO is Giulia",
|
||||
"target_tier": "core",
|
||||
"subject": None,
|
||||
"predicate": None,
|
||||
"object": None,
|
||||
"confidence": 0.85,
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_response = _make_llm_response(json.dumps(payload))
|
||||
|
||||
with (
|
||||
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||
):
|
||||
mock_prompt.return_value = (
|
||||
"system prompt {last_turn} {core_memory} {recent_episodes}",
|
||||
None,
|
||||
)
|
||||
llm_instance = MagicMock()
|
||||
llm_instance.bind.return_value = llm_instance
|
||||
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
|
||||
mock_get_llm.return_value = llm_instance
|
||||
|
||||
result = await extract_candidates(
|
||||
last_turn="User: My CFO is Giulia\nAssistant: Noted.",
|
||||
core_memory={},
|
||||
recent_episodes=[],
|
||||
)
|
||||
|
||||
assert isinstance(result, ExtractionResult)
|
||||
assert len(result.candidates) == 1
|
||||
assert result.candidates[0].type == "fact"
|
||||
assert "Giulia" in result.candidates[0].content
|
||||
assert result.candidates[0].confidence == 0.85
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_candidates_returns_empty_on_llm_failure():
|
||||
with (
|
||||
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||
):
|
||||
mock_prompt.return_value = ("prompt {last_turn} {core_memory} {recent_episodes}", None)
|
||||
llm_instance = MagicMock()
|
||||
llm_instance.bind.return_value = llm_instance
|
||||
llm_instance.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
mock_get_llm.return_value = llm_instance
|
||||
|
||||
result = await extract_candidates("turn", {}, [])
|
||||
|
||||
assert isinstance(result, ExtractionResult)
|
||||
assert result.candidates == []
|
||||
|
||||
|
||||
# ── TASK 2.2 — decide_action ─────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_action_add_when_no_existing():
|
||||
candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
|
||||
action = await decide_action(candidate, existing=[])
|
||||
assert action == "ADD"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_action_noop():
|
||||
candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
|
||||
mock_response = _make_llm_response("NOOP")
|
||||
|
||||
with (
|
||||
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||
):
|
||||
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
|
||||
llm_instance = MagicMock()
|
||||
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
|
||||
mock_get_llm.return_value = llm_instance
|
||||
|
||||
action = await decide_action(candidate, existing=["CFO is Giulia"])
|
||||
|
||||
assert action == "NOOP"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_action_update():
|
||||
candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
|
||||
mock_response = _make_llm_response("UPDATE")
|
||||
|
||||
with (
|
||||
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||
):
|
||||
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
|
||||
llm_instance = MagicMock()
|
||||
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
|
||||
mock_get_llm.return_value = llm_instance
|
||||
|
||||
action = await decide_action(candidate, existing=["CFO is Giulia"])
|
||||
|
||||
assert action == "UPDATE"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_action_delete():
|
||||
candidate = MemoryCandidate(type="fact", content="No longer have a CFO", target_tier="core")
|
||||
mock_response = _make_llm_response("DELETE")
|
||||
|
||||
with (
|
||||
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||
):
|
||||
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
|
||||
llm_instance = MagicMock()
|
||||
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
|
||||
mock_get_llm.return_value = llm_instance
|
||||
|
||||
action = await decide_action(candidate, existing=["CFO is Giulia"])
|
||||
|
||||
assert action == "DELETE"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_action_defaults_add_on_llm_failure():
|
||||
candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
|
||||
|
||||
with (
|
||||
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||
):
|
||||
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
|
||||
llm_instance = MagicMock()
|
||||
llm_instance.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
mock_get_llm.return_value = llm_instance
|
||||
|
||||
action = await decide_action(candidate, existing=["old memory"])
|
||||
|
||||
assert action == "ADD"
|
||||
|
||||
|
||||
# ── TASK 2.3 — run_extraction end-to-end ─────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_extraction_writes_core_candidate(db_session, pro_user):
|
||||
"""'My CFO is Giulia' → fact candidate → core row written."""
|
||||
fact_payload = {
|
||||
"candidates": [
|
||||
{
|
||||
"type": "fact",
|
||||
"content": "User prefers morning meetings",
|
||||
"target_tier": "core",
|
||||
"confidence": 0.8,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
def _mock_llm_response(content: str):
|
||||
msg = MagicMock()
|
||||
msg.content = content
|
||||
msg.usage_metadata = {}
|
||||
return msg
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def _ainvoke_side_effect(messages):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# extract_candidates call
|
||||
return _mock_llm_response(json.dumps(fact_payload))
|
||||
# decide_action — no existing → short-circuits to ADD without LLM
|
||||
return _mock_llm_response("ADD")
|
||||
|
||||
with (
|
||||
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||
patch(
|
||||
"app.core.memory_extraction.get_prompt_or_fallback",
|
||||
side_effect=lambda name, fb: (
|
||||
("p {last_turn} {core_memory} {recent_episodes}", None)
|
||||
if name == "memory_extraction"
|
||||
else ("p {candidate} {existing_memories}", None)
|
||||
),
|
||||
),
|
||||
):
|
||||
llm_instance = MagicMock()
|
||||
llm_instance.bind.return_value = llm_instance
|
||||
llm_instance.ainvoke = AsyncMock(side_effect=_ainvoke_side_effect)
|
||||
mock_get_llm.return_value = llm_instance
|
||||
|
||||
await run_extraction(
|
||||
db=db_session,
|
||||
user_id=PRO_USER_ID,
|
||||
last_user_msg="My CFO is Giulia",
|
||||
last_assistant_msg="Noted, I will remember that.",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
# core row should exist
|
||||
result = await db_session.execute(
|
||||
select(MemoryCore).where(MemoryCore.user_id == PRO_USER_ID)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
assert len(rows) >= 1
|
||||
fernet = Fernet(_FERNET_KEY.encode())
|
||||
values = [fernet.decrypt(r.value_encrypted.encode()).decode() for r in rows]
|
||||
assert any("morning meetings" in v for v in values)
|
||||
|
||||
|
||||
# ── TASK 2.4 — dispatch ───────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_realtime_for_pro(db_session, pro_user):
|
||||
"""Pro user: asyncio.create_task called (not queue row)."""
|
||||
middleware = MemoryMiddleware(db_session)
|
||||
|
||||
with (
|
||||
patch("app.core.memory_middleware.asyncio.create_task") as mock_task,
|
||||
patch("app.billing.tier_manager.tier_manager.check_feature", return_value=True),
|
||||
):
|
||||
await middleware._dispatch_extraction(
|
||||
user_id=PRO_USER_ID,
|
||||
episode_id=str(uuid.uuid4()),
|
||||
last_user_msg="hello",
|
||||
last_assistant_msg="hi",
|
||||
session_id=None,
|
||||
)
|
||||
|
||||
mock_task.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_queue_for_free(db_session, free_user):
|
||||
"""Free user: ExtractionQueue row inserted."""
|
||||
middleware = MemoryMiddleware(db_session)
|
||||
ep_id = str(uuid.uuid4())
|
||||
|
||||
with patch("app.billing.tier_manager.tier_manager.check_feature", return_value=False):
|
||||
await middleware._dispatch_extraction(
|
||||
user_id=FREE_USER_ID,
|
||||
episode_id=ep_id,
|
||||
last_user_msg="hello",
|
||||
last_assistant_msg="hi",
|
||||
session_id=None,
|
||||
)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(ExtractionQueue).where(ExtractionQueue.user_id == FREE_USER_ID)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
assert len(rows) == 1
|
||||
assert rows[0].episode_id == ep_id
|
||||
Reference in New Issue
Block a user