346 lines
12 KiB
Python
346 lines
12 KiB
Python
"""Tests for Phase 2 — Mem0-style Extract/Update pipeline.
|
|
|
|
Coverage:
|
|
2.1 extract_candidates returns a valid ExtractionResult with a mocked LLM, and an empty one when the LLM call fails.
|
|
2.2 decide_action — ADD short-circuit on empty existing, NOOP/UPDATE/DELETE from the LLM, and the ADD fallback on LLM failure.
|
|
2.3 run_extraction end-to-end with mocked LLM writes expected rows.
|
|
2.4 _dispatch_extraction — Pro user triggers realtime task; Free enqueues row.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from cryptography.fernet import Fernet
|
|
from sqlalchemy import select
|
|
|
|
from app.core.memory_extraction import (
|
|
ExtractionResult,
|
|
MemoryCandidate,
|
|
decide_action,
|
|
extract_candidates,
|
|
run_extraction,
|
|
)
|
|
from app.core.memory_middleware import MemoryMiddleware
|
|
from app.db import get_session
|
|
from app.main import app
|
|
from app.models import ExtractionQueue, MemoryCore, User
|
|
from tests.conftest import TEST_USER_IDS
|
|
|
|
|
|
# Seeded user ids (from tests.conftest) for the two billing tiers under test.
PRO_USER_ID, FREE_USER_ID = TEST_USER_IDS["pro"], TEST_USER_IDS["free"]

# One throwaway Fernet key, generated at import time and shared by this module;
# the user fixtures assign it and the tests use it to decrypt stored values.
_FERNET_KEY = str(Fernet.generate_key(), "utf-8")
|
|
|
|
|
|
# ── DB override ───────────────────────────────────────────────────────────────
|
|
|
|
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Route the app's ``get_session`` dependency to the test ``db_session``.

    Any ``get_session`` override that was already installed before this
    fixture ran is restored on teardown instead of being dropped.
    """

    async def _gen():
        yield db_session

    missing = object()  # sentinel: "no override was installed"
    previous = app.dependency_overrides.get(get_session, missing)
    app.dependency_overrides[get_session] = _gen
    try:
        yield
    finally:
        if previous is missing:
            app.dependency_overrides.pop(get_session, None)
        else:
            app.dependency_overrides[get_session] = previous
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
async def _seed_encryption_key(db_session, user_id):
    """Give the seeded user ``user_id`` the module's ``_FERNET_KEY`` and commit.

    Returns the updated ``User`` row; ``scalar_one()`` raises if the user is
    not present exactly once.
    """
    result = await db_session.execute(select(User).where(User.id == user_id))
    user = result.scalar_one()
    user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return user


@pytest_asyncio.fixture
async def pro_user(db_session):
    """Update the seeded pro user to have an encryption_key."""
    return await _seed_encryption_key(db_session, PRO_USER_ID)


@pytest_asyncio.fixture
async def free_user(db_session):
    """Update the seeded free user to have an encryption_key."""
    return await _seed_encryption_key(db_session, FREE_USER_ID)
|
|
|
|
|
|
def _make_llm_response(content: str) -> MagicMock:
|
|
msg = MagicMock()
|
|
msg.content = content
|
|
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
|
return msg
|
|
|
|
|
|
# ── TASK 2.1 — extract_candidates ────────────────────────────────────────────
|
|
|
|
async def _extract_with_mocked_llm(ainvoke, last_turn, core_memory, recent_episodes):
    """Run ``extract_candidates`` with the LLM, Langfuse and prompt lookups patched.

    ``ainvoke`` (an AsyncMock) is installed as the fake LLM's ``ainvoke``;
    ``bind()`` returns the same fake, so a bound model hits it too. Callers can
    inspect ``ainvoke`` afterwards to see whether the LLM was consulted.
    """
    with (
        patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch(
            "app.core.memory_extraction.get_prompt_or_fallback",
            return_value=(
                "system prompt {last_turn} {core_memory} {recent_episodes}",
                None,
            ),
        ),
    ):
        llm_instance = MagicMock()
        llm_instance.bind.return_value = llm_instance
        llm_instance.ainvoke = ainvoke
        mock_get_llm.return_value = llm_instance
        return await extract_candidates(
            last_turn=last_turn,
            core_memory=core_memory,
            recent_episodes=recent_episodes,
        )


@pytest.mark.asyncio
async def test_extract_candidates_returns_valid_result():
    payload = {
        "candidates": [
            {
                "type": "fact",
                "content": "User's CFO is Giulia",
                "target_tier": "core",
                "subject": None,
                "predicate": None,
                "object": None,
                "confidence": 0.85,
            }
        ]
    }
    ainvoke = AsyncMock(return_value=_make_llm_response(json.dumps(payload)))

    result = await _extract_with_mocked_llm(
        ainvoke,
        last_turn="User: My CFO is Giulia\nAssistant: Noted.",
        core_memory={},
        recent_episodes=[],
    )

    ainvoke.assert_awaited()
    assert isinstance(result, ExtractionResult)
    assert len(result.candidates) == 1
    assert result.candidates[0].type == "fact"
    assert "Giulia" in result.candidates[0].content
    assert result.candidates[0].confidence == 0.85


@pytest.mark.asyncio
async def test_extract_candidates_returns_empty_on_llm_failure():
    ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))

    result = await _extract_with_mocked_llm(
        ainvoke,
        last_turn="User: My CFO is Giulia\nAssistant: Noted.",
        core_memory={},
        recent_episodes=[],
    )

    # The empty result must come from handling the LLM error, not from the
    # LLM being skipped altogether.
    ainvoke.assert_awaited()
    assert isinstance(result, ExtractionResult)
    assert result.candidates == []
|
|
|
|
|
|
# ── TASK 2.2 — decide_action ─────────────────────────────────────────────────
|
|
|
|
async def _decide_with_mocked_llm(candidate, existing, ainvoke):
    """Run ``decide_action`` with the LLM, Langfuse and prompt lookups patched.

    ``ainvoke`` (an AsyncMock) is installed as the fake LLM's ``ainvoke``;
    ``bind()`` returns the same fake. Callers inspect ``ainvoke`` afterwards to
    check whether the LLM was consulted.
    """
    with (
        patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch(
            "app.core.memory_extraction.get_prompt_or_fallback",
            return_value=("p {candidate} {existing_memories}", None),
        ),
    ):
        llm_instance = MagicMock()
        llm_instance.bind.return_value = llm_instance
        llm_instance.ainvoke = ainvoke
        mock_get_llm.return_value = llm_instance
        return await decide_action(candidate, existing=existing)


@pytest.mark.asyncio
async def test_decide_action_add_when_no_existing():
    candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
    # The fake LLM would answer NOOP, so ADD can only come from the
    # empty-existing short-circuit (and no real LLM call can happen).
    ainvoke = AsyncMock(return_value=_make_llm_response("NOOP"))

    action = await _decide_with_mocked_llm(candidate, existing=[], ainvoke=ainvoke)

    assert action == "ADD"
    ainvoke.assert_not_awaited()


@pytest.mark.asyncio
async def test_decide_action_noop():
    candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
    ainvoke = AsyncMock(return_value=_make_llm_response("NOOP"))

    action = await _decide_with_mocked_llm(
        candidate, existing=["CFO is Giulia"], ainvoke=ainvoke
    )

    ainvoke.assert_awaited()
    assert action == "NOOP"


@pytest.mark.asyncio
async def test_decide_action_update():
    candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
    ainvoke = AsyncMock(return_value=_make_llm_response("UPDATE"))

    action = await _decide_with_mocked_llm(
        candidate, existing=["CFO is Giulia"], ainvoke=ainvoke
    )

    ainvoke.assert_awaited()
    assert action == "UPDATE"


@pytest.mark.asyncio
async def test_decide_action_delete():
    candidate = MemoryCandidate(type="fact", content="No longer have a CFO", target_tier="core")
    ainvoke = AsyncMock(return_value=_make_llm_response("DELETE"))

    action = await _decide_with_mocked_llm(
        candidate, existing=["CFO is Giulia"], ainvoke=ainvoke
    )

    ainvoke.assert_awaited()
    assert action == "DELETE"


@pytest.mark.asyncio
async def test_decide_action_defaults_add_on_llm_failure():
    candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
    ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))

    action = await _decide_with_mocked_llm(
        candidate, existing=["old memory"], ainvoke=ainvoke
    )

    # ADD must come from the failure fallback, not from skipping the LLM.
    ainvoke.assert_awaited()
    assert action == "ADD"
|
|
|
|
|
|
# ── TASK 2.3 — run_extraction end-to-end ─────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_run_extraction_writes_core_candidate(db_session, pro_user):
    """A core-tier fact from the extraction LLM ends up encrypted in MemoryCore.

    The conversation says "My CFO is Giulia", but the mocked extraction call
    always returns a single "User prefers morning meetings" fact, so that is
    the value expected after decrypting the pro user's core rows with
    ``_FERNET_KEY``.
    """
    fact_payload = {
        "candidates": [
            {
                "type": "fact",
                "content": "User prefers morning meetings",
                "target_tier": "core",
                "confidence": 0.8,
            }
        ]
    }

    def _mock_llm_response(content: str):
        # Same shape as _make_llm_response, but with empty usage metadata.
        msg = MagicMock()
        msg.content = content
        msg.usage_metadata = {}
        return msg

    call_count = 0

    async def _ainvoke_side_effect(messages):
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            # First LLM call comes from extract_candidates: return the canned payload.
            return _mock_llm_response(json.dumps(fact_payload))
        # decide_action is expected to short-circuit to ADD when there are no
        # existing memories; if it does consult the LLM, answer ADD so the
        # candidate is still written.
        return _mock_llm_response("ADD")

    with (
        patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch(
            "app.core.memory_extraction.get_prompt_or_fallback",
            side_effect=lambda name, fb: (
                ("p {last_turn} {core_memory} {recent_episodes}", None)
                if name == "memory_extraction"
                else ("p {candidate} {existing_memories}", None)
            ),
        ),
    ):
        llm_instance = MagicMock()
        llm_instance.bind.return_value = llm_instance
        llm_instance.ainvoke = AsyncMock(side_effect=_ainvoke_side_effect)
        mock_get_llm.return_value = llm_instance

        await run_extraction(
            db=db_session,
            user_id=PRO_USER_ID,
            last_user_msg="My CFO is Giulia",
            last_assistant_msg="Noted, I will remember that.",
            session_id="test-session",
        )

    # At least one core row should exist, and one of them must decrypt to the
    # extracted fact.
    result = await db_session.execute(
        select(MemoryCore).where(MemoryCore.user_id == PRO_USER_ID)
    )
    rows = result.scalars().all()
    assert len(rows) >= 1
    fernet = Fernet(_FERNET_KEY.encode())
    values = [fernet.decrypt(r.value_encrypted.encode()).decode() for r in rows]
    assert any("morning meetings" in v for v in values)
|
|
|
|
|
|
# ── TASK 2.4 — dispatch ───────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_dispatch_realtime_for_pro(db_session, pro_user):
    """Pro user: extraction is scheduled via asyncio.create_task, not queued.

    Asserts that create_task is called exactly once and that no
    ExtractionQueue row is written for the pro user.
    """
    middleware = MemoryMiddleware(db_session)

    def _fake_create_task(coro, *args, **kwargs):
        # The mocked task never runs; close the coroutine so it is not
        # reported as "never awaited" when it is garbage-collected.
        close = getattr(coro, "close", None)
        if callable(close):
            close()
        return MagicMock()

    with (
        patch(
            "app.core.memory_middleware.asyncio.create_task",
            side_effect=_fake_create_task,
        ) as mock_task,
        patch("app.billing.tier_manager.tier_manager.check_feature", return_value=True),
    ):
        await middleware._dispatch_extraction(
            user_id=PRO_USER_ID,
            episode_id=str(uuid.uuid4()),
            last_user_msg="hello",
            last_assistant_msg="hi",
            session_id=None,
        )

    mock_task.assert_called_once()

    # The realtime path must not also enqueue the extraction.
    queued = await db_session.execute(
        select(ExtractionQueue).where(ExtractionQueue.user_id == PRO_USER_ID)
    )
    assert queued.scalars().all() == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_dispatch_queue_for_free(db_session, free_user):
    """Free user: the extraction is persisted as one ExtractionQueue row.

    With the realtime feature check patched to False, dispatching one episode
    must leave exactly one queue row for that episode.
    """
    middleware = MemoryMiddleware(db_session)
    episode_id = str(uuid.uuid4())
    dispatch_kwargs = dict(
        user_id=FREE_USER_ID,
        episode_id=episode_id,
        last_user_msg="hello",
        last_assistant_msg="hi",
        session_id=None,
    )

    with patch("app.billing.tier_manager.tier_manager.check_feature", return_value=False):
        await middleware._dispatch_extraction(**dispatch_kwargs)

    queue_query = select(ExtractionQueue).where(ExtractionQueue.user_id == FREE_USER_ID)
    queued_rows = (await db_session.execute(queue_query)).scalars().all()

    assert len(queued_rows) == 1
    assert queued_rows[0].episode_id == episode_id
|