Files
api/tests/test_memory_extraction.py
2026-04-16 17:57:49 +02:00

346 lines
12 KiB
Python

"""Tests for Phase 2 — Mem0-style Extract/Update pipeline.
Coverage:
2.1 extract_candidates returns valid ExtractionResult with mocked LLM.
2.2 decide_action — all 4 branches (ADD/UPDATE/DELETE/NOOP + empty existing).
2.3 run_extraction end-to-end with mocked LLM writes expected rows.
2.4 _dispatch_extraction — Pro user triggers realtime task; Free enqueues row.
"""
from __future__ import annotations
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select
from app.core.memory_extraction import (
ExtractionResult,
MemoryCandidate,
decide_action,
extract_candidates,
run_extraction,
)
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session
from app.main import app
from app.models import ExtractionQueue, MemoryCore, User
from tests.conftest import TEST_USER_IDS
# Seeded user ids, one per billing tier under test.
PRO_USER_ID, FREE_USER_ID = TEST_USER_IDS["pro"], TEST_USER_IDS["free"]
# Throwaway Fernet key generated once when this module is imported; the user
# fixtures assign it as the users' encryption_key and the run_extraction test
# reuses it to decrypt stored core memories.
_FERNET_KEY = Fernet.generate_key().decode()
# ── DB override ───────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Point the app's ``get_session`` dependency at this test's ``db_session``.

    Installed for every test in the module (autouse) and removed again on
    teardown so the override does not leak into other test modules.
    """

    async def _yield_test_session():
        # Async generator mirroring the shape of the real dependency.
        yield db_session

    overrides = app.dependency_overrides
    overrides[get_session] = _yield_test_session
    yield
    overrides.pop(get_session, None)
# ── Helpers ───────────────────────────────────────────────────────────────────
async def _give_encryption_key(db_session, user_id):
    """Give the seeded user ``user_id`` the module's Fernet key and commit.

    Args:
        db_session: the test's async SQLAlchemy session.
        user_id: id of a user row that must already exist (``scalar_one`` raises
            otherwise).

    Returns:
        The updated ``User`` instance.
    """
    result = await db_session.execute(select(User).where(User.id == user_id))
    user = result.scalar_one()
    user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return user


@pytest_asyncio.fixture
async def pro_user(db_session):
    """Update the seeded pro user to have an encryption_key."""
    return await _give_encryption_key(db_session, PRO_USER_ID)


@pytest_asyncio.fixture
async def free_user(db_session):
    """Update the seeded free user to have an encryption_key."""
    return await _give_encryption_key(db_session, FREE_USER_ID)
def _make_llm_response(content: str) -> MagicMock:
msg = MagicMock()
msg.content = content
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
return msg
# ── TASK 2.1 — extract_candidates ────────────────────────────────────────────
@pytest.mark.asyncio
async def test_extract_candidates_returns_valid_result():
    """A well-formed JSON reply from the LLM is parsed into one fact candidate."""
    llm_reply = {
        "candidates": [
            dict(
                type="fact",
                content="User's CFO is Giulia",
                target_tier="core",
                subject=None,
                predicate=None,
                object=None,
                confidence=0.85,
            )
        ]
    }
    canned_message = _make_llm_response(json.dumps(llm_reply))
    template = "system prompt {last_turn} {core_memory} {recent_episodes}"

    fake_llm = MagicMock()
    # If extract_candidates binds extra options, keep handing back the same fake.
    fake_llm.bind.return_value = fake_llm
    fake_llm.ainvoke = AsyncMock(return_value=canned_message)

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=fake_llm),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch(
            "app.core.memory_extraction.get_prompt_or_fallback",
            return_value=(template, None),
        ),
    ):
        result = await extract_candidates(
            last_turn="User: My CFO is Giulia\nAssistant: Noted.",
            core_memory={},
            recent_episodes=[],
        )

    assert isinstance(result, ExtractionResult)
    assert len(result.candidates) == 1
    first = result.candidates[0]
    assert first.type == "fact"
    assert "Giulia" in first.content
    assert first.confidence == 0.85
@pytest.mark.asyncio
async def test_extract_candidates_returns_empty_on_llm_failure():
    """When the LLM call raises, extract_candidates yields an empty result."""
    broken_llm = MagicMock()
    broken_llm.bind.return_value = broken_llm
    broken_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=broken_llm),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch(
            "app.core.memory_extraction.get_prompt_or_fallback",
            return_value=("prompt {last_turn} {core_memory} {recent_episodes}", None),
        ),
    ):
        outcome = await extract_candidates("turn", {}, [])

    assert isinstance(outcome, ExtractionResult)
    assert outcome.candidates == []
# ── TASK 2.2 — decide_action ─────────────────────────────────────────────────
async def _decide_with_mocked_llm(candidate, existing, *, reply="", error=None):
    """Run ``decide_action`` with the LLM, Langfuse and prompt lookup patched out.

    Args:
        candidate: the ``MemoryCandidate`` to classify.
        existing: existing memory strings passed through to ``decide_action``.
        reply: text the fake LLM answers with (ignored when ``error`` is set).
        error: exception the fake LLM raises instead of replying.

    Returns:
        ``(action, llm_instance)`` so callers can also inspect how the fake LLM
        was used.
    """
    llm_instance = MagicMock()
    # Return the same fake from .bind(...) so the mock stays wired up whether or
    # not decide_action binds extra options before calling ainvoke.
    llm_instance.bind.return_value = llm_instance
    if error is not None:
        llm_instance.ainvoke = AsyncMock(side_effect=error)
    else:
        llm_instance.ainvoke = AsyncMock(return_value=_make_llm_response(reply))
    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=llm_instance),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch(
            "app.core.memory_extraction.get_prompt_or_fallback",
            return_value=("p {candidate} {existing_memories}", None),
        ),
    ):
        action = await decide_action(candidate, existing=existing)
    return action, llm_instance


@pytest.mark.asyncio
async def test_decide_action_add_when_no_existing():
    """With no existing memories the answer is ADD and the LLM is never consulted."""
    candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
    # The fake would answer NOOP if asked, so a consulted LLM changes the outcome.
    action, llm_instance = await _decide_with_mocked_llm(candidate, [], reply="NOOP")
    assert action == "ADD"
    llm_instance.ainvoke.assert_not_awaited()


@pytest.mark.asyncio
async def test_decide_action_noop():
    """An identical existing memory is left alone when the LLM says NOOP."""
    candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
    action, _ = await _decide_with_mocked_llm(candidate, ["CFO is Giulia"], reply="NOOP")
    assert action == "NOOP"


@pytest.mark.asyncio
async def test_decide_action_update():
    """A changed fact is reported as UPDATE when the LLM says so."""
    candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
    action, _ = await _decide_with_mocked_llm(candidate, ["CFO is Giulia"], reply="UPDATE")
    assert action == "UPDATE"


@pytest.mark.asyncio
async def test_decide_action_delete():
    """A retracted fact is reported as DELETE when the LLM says so."""
    candidate = MemoryCandidate(type="fact", content="No longer have a CFO", target_tier="core")
    action, _ = await _decide_with_mocked_llm(candidate, ["CFO is Giulia"], reply="DELETE")
    assert action == "DELETE"


@pytest.mark.asyncio
async def test_decide_action_defaults_add_on_llm_failure():
    """If the LLM call raises, decide_action falls back to ADD."""
    candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
    action, _ = await _decide_with_mocked_llm(
        candidate, ["old memory"], error=RuntimeError("LLM down")
    )
    assert action == "ADD"
# ── TASK 2.3 — run_extraction end-to-end ─────────────────────────────────────
@pytest.mark.asyncio
async def test_run_extraction_writes_core_candidate(db_session, pro_user):
    """'My CFO is Giulia' → fact candidate → encrypted core row written.

    The LLM is mocked. The first ``ainvoke`` call (extract_candidates) returns a
    single core-tier fact matching the user's message. Any later call (decide_action)
    answers ``ADD``.
    """
    fact_payload = {
        "candidates": [
            {
                "type": "fact",
                "content": "User's CFO is Giulia",
                "target_tier": "core",
                "confidence": 0.8,
            }
        ]
    }
    call_count = 0

    async def _ainvoke_side_effect(*args, **kwargs):
        # Accept any call signature so extra ainvoke kwargs don't break the mock.
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            # extract_candidates call
            return _make_llm_response(json.dumps(fact_payload))
        # decide_action — with no existing core memory it short-circuits to ADD
        # without the LLM; answer ADD in case it is consulted anyway.
        return _make_llm_response("ADD")

    def _prompt_for(name, *args, **kwargs):
        # Extraction and decision prompts use different placeholders.
        if name == "memory_extraction":
            return ("p {last_turn} {core_memory} {recent_episodes}", None)
        return ("p {candidate} {existing_memories}", None)

    llm_instance = MagicMock()
    llm_instance.bind.return_value = llm_instance
    llm_instance.ainvoke = AsyncMock(side_effect=_ainvoke_side_effect)

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=llm_instance),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch(
            "app.core.memory_extraction.get_prompt_or_fallback",
            side_effect=_prompt_for,
        ),
    ):
        await run_extraction(
            db=db_session,
            user_id=PRO_USER_ID,
            last_user_msg="My CFO is Giulia",
            last_assistant_msg="Noted, I will remember that.",
            session_id="test-session",
        )

    # A core row should exist, decryptable with the key the pro_user fixture set.
    result = await db_session.execute(
        select(MemoryCore).where(MemoryCore.user_id == PRO_USER_ID)
    )
    rows = result.scalars().all()
    assert len(rows) >= 1
    fernet = Fernet(_FERNET_KEY.encode())
    values = [fernet.decrypt(r.value_encrypted.encode()).decode() for r in rows]
    assert any("CFO is Giulia" in v for v in values)
# ── TASK 2.4 — dispatch ───────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_dispatch_realtime_for_pro(db_session, pro_user):
    """Pro user: asyncio.create_task called (not queue row)."""
    import inspect

    middleware = MemoryMiddleware(db_session)
    # NOTE(review): this target resolves to the asyncio module's own create_task,
    # so while active the patch is process-wide, not local to memory_middleware.
    with (
        patch("app.core.memory_middleware.asyncio.create_task") as mock_task,
        patch("app.billing.tier_manager.tier_manager.check_feature", return_value=True),
    ):
        await middleware._dispatch_extraction(
            user_id=PRO_USER_ID,
            episode_id=str(uuid.uuid4()),
            last_user_msg="hello",
            last_assistant_msg="hi",
            session_id=None,
        )
    mock_task.assert_called_once()

    # The mocked create_task never schedules what it receives. Close any coroutine
    # handed to it so it does not emit "coroutine ... was never awaited".
    task_args = mock_task.call_args.args
    if task_args and inspect.iscoroutine(task_args[0]):
        task_args[0].close()

    # The realtime path must not also enqueue the work.
    queued = await db_session.execute(
        select(ExtractionQueue).where(ExtractionQueue.user_id == PRO_USER_ID)
    )
    assert queued.scalars().all() == []
@pytest.mark.asyncio
async def test_dispatch_queue_for_free(db_session, free_user):
    """Free user: exactly one ExtractionQueue row is inserted for the episode."""
    middleware = MemoryMiddleware(db_session)
    ep_id = str(uuid.uuid4())
    with patch("app.billing.tier_manager.tier_manager.check_feature", return_value=False):
        await middleware._dispatch_extraction(
            user_id=FREE_USER_ID,
            episode_id=ep_id,
            last_user_msg="hello",
            last_assistant_msg="hi",
            session_id=None,
        )
    result = await db_session.execute(
        select(ExtractionQueue).where(ExtractionQueue.user_id == FREE_USER_ID)
    )
    rows = result.scalars().all()
    assert len(rows) == 1
    # Compare as strings so the check holds whether the column yields str or
    # uuid.UUID (str(uuid4()) is the canonical hyphenated form either way).
    assert str(rows[0].episode_id) == ep_id