PHASE 2 — Mem0-style Extract/Update pipeline

This commit is contained in:
Roberto Musso
2026-04-16 17:57:49 +02:00
parent 2d8abb6311
commit 741b9b87fb
9 changed files with 949 additions and 4 deletions

View File

@@ -53,6 +53,10 @@ LLM_MODEL_CLOUD_PROCESSOR=
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
LLM_MODEL_SETUP_AGENT=
# Memory-extractor — Mem0-style extract/decide pipeline (Phase 2).
# Defaults to gpt-4o-mini when empty (fast + cheap, temperature=0).
LLM_MODEL_MEMORY_EXTRACTOR=
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
STRIPE_SECRET_KEY=
STRIPE_WEBHOOK_SECRET=

View File

@@ -0,0 +1,38 @@
"""add extraction_queue
Revision ID: 1f5975a4f3f4
Revises: 005
Create Date: 2026-04-16 17:26:25.790870
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '1f5975a4f3f4'
down_revision: Union[str, None] = '005'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the ``extraction_queue`` table and a non-unique index on ``user_id``."""
    table = 'extraction_queue'
    columns = [
        sa.Column('id', sa.Uuid(as_uuid=False), nullable=False),
        sa.Column('user_id', sa.Uuid(as_uuid=False), nullable=False),
        # Nullable and without a foreign key: the queue row only records the id.
        sa.Column('episode_id', sa.Uuid(as_uuid=False), nullable=True),
        sa.Column(
            'created_at',
            sa.DateTime(timezone=True),
            server_default=sa.text('now()'),
            nullable=False,
        ),
    ]
    constraints = [
        # Queue rows are removed together with their user.
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id'),
    ]
    op.create_table(table, *columns, *constraints)
    op.create_index(op.f('ix_extraction_queue_user_id'), table, ['user_id'], unique=False)
def downgrade() -> None:
    """Drop the ``user_id`` index first, then the ``extraction_queue`` table."""
    table = 'extraction_queue'
    op.drop_index(op.f('ix_extraction_queue_user_id'), table_name=table)
    op.drop_table(table)

View File

@@ -26,6 +26,7 @@ FEATURES: dict[str, dict[str, Any]] = {
"batch_builder": False,
"sso": False,
"real_embeddings": False, # keyword fallback only
"realtime_extraction": False, # batch queue (Phase 2)
},
"pro": {
"agents": -1, # unlimited
@@ -35,6 +36,7 @@ FEATURES: dict[str, dict[str, Any]] = {
"batch_builder": False,
"sso": False,
"real_embeddings": True, # pgvector cosine search
"realtime_extraction": True, # fire-and-forget asyncio.create_task
},
"power": {
"agents": -1,
@@ -44,6 +46,7 @@ FEATURES: dict[str, dict[str, Any]] = {
"batch_builder": True,
"sso": False,
"real_embeddings": True,
"realtime_extraction": True,
},
"team": {
"agents": -1,
@@ -53,6 +56,7 @@ FEATURES: dict[str, dict[str, Any]] = {
"batch_builder": True,
"sso": True,
"real_embeddings": True,
"realtime_extraction": True,
},
}

View File

@@ -27,6 +27,7 @@ class Settings(BaseSettings):
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
# GitHub Copilot OAuth token storage directory.
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).

View File

@@ -103,6 +103,7 @@ _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
}

View File

@@ -0,0 +1,456 @@
"""Mem0-style Extract/Update pipeline — Phase 2.
Runs after every ``store_episode`` call to distil durable facts, preferences,
routines, and relations from the latest conversation turn.
Entry point: ``run_extraction(db, user_id, last_user_msg, last_assistant_msg, session_id)``
Design notes
------------
- LLM calls per turn: one to extract candidates, then at most one decide call per candidate (model defaults to gpt-4o-mini).
- Short-circuit: if no existing neighbours → ADD without a second LLM call (cost saving).
- Zero-trust: never logs decrypted user content; relation subject/object labels are
treated as identifiers (safe to log per spec).
- Must not raise into the request path — caller wraps in asyncio.create_task().
"""
from __future__ import annotations
import json
import logging
from typing import Any, Literal
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.langfuse_client import get_langfuse, get_prompt_or_fallback, extract_usage, langfuse_context
from app.core.llm import get_agent_llm, model_for_agent
logger = logging.getLogger(__name__)
# ── Fallback prompts (used when Langfuse unavailable) ─────────────────────────
# Fallback system prompt for the extraction call, passed to
# ``get_prompt_or_fallback("memory_extraction", ...)`` and rendered with ``str.format``.
# Literal JSON braces in the schema example are doubled (``{{`` / ``}}``) so that
# ``format`` emits them verbatim; the only substitution fields are ``{last_turn}``,
# ``{core_memory}`` and ``{recent_episodes}``.
_EXTRACTION_FALLBACK = (
"You are a memory extractor for a personal AI secretary. Given the last conversation "
"turn, the user's core memory, and recent episode summaries, identify durable facts, "
"preferences, routines, and person/project relations worth remembering.\n\n"
"Output JSON matching this schema exactly:\n"
'{{"candidates": [{{"type": "<fact|preference|relation|routine>", '
'"content": "<short canonical statement>", '
'"target_tier": "<core|associative|relational|proactive>", '
'"subject": null, "predicate": null, "object": null, "confidence": 0.7}}]}}\n\n'
"Rules:\n"
"- Skip small talk, greetings, one-off questions.\n"
"- Max 5 candidates per call.\n"
"- Only extract durable information (still true next week).\n"
"- For type=relation: subject/predicate/object required.\n"
"- Default confidence=0.7.\n\n"
"## Last turn\n{last_turn}\n\n"
"## Core memory (current)\n{core_memory}\n\n"
"## Recent episodes\n{recent_episodes}"
)
# Fallback system prompt for the decide call, passed to
# ``get_prompt_or_fallback("memory_decide_action", ...)`` and rendered with
# ``str.format``. Substitution fields: ``{candidate}`` and ``{existing_memories}``.
# The model is asked for a single verb, which ``decide_action`` compares against
# ADD / UPDATE / DELETE / NOOP.
_DECIDE_FALLBACK = (
"You are a memory update decision engine. Given a new memory candidate and a list of "
"existing memories from the same tier, decide what action to take.\n\n"
"Respond with exactly one word: ADD, UPDATE, DELETE, or NOOP.\n\n"
"- ADD: new information not in existing memories.\n"
"- UPDATE: contradicts or supersedes an existing memory.\n"
"- DELETE: states something is no longer true.\n"
"- NOOP: already captured accurately.\n\n"
"## New candidate\n{candidate}\n\n"
"## Existing memories (same tier, top neighbours)\n{existing_memories}"
)
# ── Pydantic schemas ───────────────────────────────────────────────────────────
class MemoryCandidate(BaseModel):
    """One memory item proposed by the extraction LLM call.

    ``subject`` / ``predicate`` / ``object`` are optional at schema level; the
    extraction prompt asks for them only when ``type`` is ``"relation"``.
    ``confidence`` is validated to lie in the closed range [0.0, 1.0].
    """

    type: Literal["fact", "preference", "relation", "routine"]
    content: str
    target_tier: Literal["core", "associative", "relational", "proactive"]
    subject: str | None = Field(default=None)
    predicate: str | None = Field(default=None)
    object: str | None = Field(default=None)
    confidence: float = Field(default=0.7, ge=0.0, le=1.0)
class ExtractionResult(BaseModel):
    """Top-level JSON object returned by the extraction call.

    ``candidates`` defaults to a fresh empty list when the key is absent.
    """

    candidates: list[MemoryCandidate] = Field(default_factory=list)
# ── Task 2.1 — Extract candidates ─────────────────────────────────────────────
# Hard cap mirroring the "Max 5 candidates per call" rule stated in the prompt.
_MAX_CANDIDATES = 5


async def extract_candidates(
    last_turn: str,
    core_memory: dict[str, str],
    recent_episodes: list[str],
) -> ExtractionResult:
    """Ask the memory-extractor LLM for memory candidates in the latest turn.

    Args:
        last_turn: Text of the latest user/assistant exchange.
        core_memory: Current core memory as key/value pairs.
        recent_episodes: Episode summaries; only the last five are sent.

    Returns:
        The parsed ``ExtractionResult``, truncated to ``_MAX_CANDIDATES``.
        Any failure (prompt rendering, LLM call, JSON parsing, schema
        validation) yields an empty result — this function never raises.
    """
    from pydantic import ValidationError  # noqa: PLC0415

    # Everything lives inside one try so the "never raises" contract also covers
    # prompt rendering and LLM construction, not just the network call.
    try:
        variables = {
            "last_turn": last_turn,
            "core_memory": "\n".join(f"{k}: {v}" for k, v in core_memory.items()) or "(empty)",
            "recent_episodes": "\n---\n".join(recent_episodes[-5:]) or "(none)",
        }
        template, prompt_obj = get_prompt_or_fallback("memory_extraction", _EXTRACTION_FALLBACK)

        # Prefer the Langfuse-compiled prompt; fall back to str.format on the template.
        system_text: str | None = None
        if prompt_obj is not None:
            try:
                compiled = prompt_obj.compile(**variables)
                if isinstance(compiled, list):
                    # Chat-style prompt: flatten the message contents into one string.
                    compiled = "\n".join(
                        m.get("content", "") for m in compiled if isinstance(m, dict)
                    )
                system_text = compiled
            except Exception as exc:
                logger.warning("memory_extraction: compile failed: %s", exc)
        if system_text is None:
            system_text = template.format(**variables)

        llm = get_agent_llm("memory-extractor", temperature=0)
        # Bind JSON mode so the model always returns parseable output.
        llm_json = llm.bind(response_format={"type": "json_object"})  # type: ignore[attr-defined]
        lf = get_langfuse()

        from langchain_core.messages import HumanMessage, SystemMessage  # noqa: PLC0415

        messages = [
            SystemMessage(content=system_text),
            HumanMessage(content="Extract memory candidates as JSON."),
        ]
        if lf:
            with lf.start_as_current_observation(
                as_type="generation",
                name="memory-extraction",
                model=model_for_agent("memory-extractor"),
                prompt=prompt_obj,
                input=messages,
            ) as gen:
                response = await llm_json.ainvoke(messages)
                gen.update(output=response.content, usage=extract_usage(response))
        else:
            response = await llm_json.ainvoke(messages)

        result = ExtractionResult.model_validate(json.loads(response.content))
        if len(result.candidates) > _MAX_CANDIDATES:
            # The prompt asks for at most five; enforce it so one turn cannot
            # trigger an unbounded number of decide calls.
            logger.info(
                "memory_extraction: truncating %d candidates to %d",
                len(result.candidates),
                _MAX_CANDIDATES,
            )
            result = ExtractionResult(candidates=result.candidates[:_MAX_CANDIDATES])
        logger.info("memory_extraction: extracted %d candidates", len(result.candidates))
        return result
    except ValidationError as exc:
        # The exception text embeds the offending input values (user content),
        # so only the error count is logged (zero-trust rule).
        logger.warning(
            "memory_extraction: extract_candidates failed: %d schema validation error(s)",
            exc.error_count(),
        )
        return ExtractionResult(candidates=[])
    except Exception as exc:
        logger.warning("memory_extraction: extract_candidates failed: %s", exc)
        return ExtractionResult(candidates=[])
# ── Task 2.2 — Decide action ──────────────────────────────────────────────────
async def decide_action(
    candidate: MemoryCandidate,
    existing: list[str],
) -> Literal["ADD", "UPDATE", "DELETE", "NOOP"]:
    """Decide what to do with a candidate given existing memories in its tier.

    Args:
        candidate: The proposed memory.
        existing: Neighbouring memories from the same tier.

    Returns:
        One of ``"ADD"``, ``"UPDATE"``, ``"DELETE"``, ``"NOOP"``. Returns
        ``"ADD"`` without an LLM call when ``existing`` is empty, and also
        falls back to ``"ADD"`` on any failure or unrecognised reply — this
        function never raises.
    """
    if not existing:
        return "ADD"

    import re  # noqa: PLC0415

    # One try around prompt rendering, LLM construction and the call itself so
    # the "never raises" contract holds for all of them.
    try:
        variables = {
            "candidate": f"[{candidate.type}] {candidate.content}",
            "existing_memories": "\n".join(f"- {m}" for m in existing),
        }
        template, prompt_obj = get_prompt_or_fallback("memory_decide_action", _DECIDE_FALLBACK)

        system_text: str | None = None
        if prompt_obj is not None:
            try:
                compiled = prompt_obj.compile(**variables)
                if isinstance(compiled, list):
                    compiled = "\n".join(
                        m.get("content", "") for m in compiled if isinstance(m, dict)
                    )
                system_text = compiled
            except Exception as exc:
                logger.warning("memory_extraction: decide compile failed: %s", exc)
        if system_text is None:
            system_text = template.format(**variables)

        llm = get_agent_llm("memory-extractor", temperature=0)
        lf = get_langfuse()

        from langchain_core.messages import HumanMessage, SystemMessage  # noqa: PLC0415

        messages = [
            SystemMessage(content=system_text),
            HumanMessage(content="Decide action."),
        ]
        if lf:
            with lf.start_as_current_observation(
                as_type="generation",
                name="memory-decide-action",
                model=model_for_agent("memory-extractor"),
                prompt=prompt_obj,
                input=messages,
            ) as gen:
                response = await llm.ainvoke(messages)
                gen.update(output=response.content, usage=extract_usage(response))
        else:
            response = await llm.ainvoke(messages)

        # Take the first alphabetic word so replies such as "NOOP." or "**UPDATE**"
        # are still recognised instead of silently degrading to ADD.
        first_word = re.search(r"[A-Za-z]+", response.content)
        verb = first_word.group(0).upper() if first_word else ""
        if verb in ("ADD", "UPDATE", "DELETE", "NOOP"):
            return verb  # type: ignore[return-value]
        # Only the first word is logged, never the full model output.
        logger.warning("memory_extraction: unexpected decide verb=%r, defaulting ADD", verb)
        return "ADD"
    except Exception as exc:
        logger.warning("memory_extraction: decide_action failed: %s", exc)
        return "ADD"
# ── Task 2.3 — Pipeline orchestrator ──────────────────────────────────────────
async def run_extraction(
    db: AsyncSession,
    user_id: str,
    last_user_msg: str,
    last_assistant_msg: str,
    session_id: str | None,
) -> None:
    """Run the extract/update pipeline for one conversation turn.

    Thin safety wrapper: delegates to ``_run_extraction_inner`` and turns any
    ``Exception`` into a warning log, so nothing propagates to the caller.
    """
    turn_args = (db, user_id, last_user_msg, last_assistant_msg, session_id)
    try:
        await _run_extraction_inner(*turn_args)
    except Exception as err:
        logger.warning("memory_extraction: run_extraction failed user=%s: %s", user_id, err)
async def _run_extraction_inner(
    db: AsyncSession,
    user_id: str,
    last_user_msg: str,
    last_assistant_msg: str,
    session_id: str | None,
) -> None:
    """Load context, extract candidates and apply each one.

    Skips silently (with a warning) when the user has no Fernet key. A failure
    while applying one candidate is logged and does not stop the others.
    When Langfuse is available the whole run is wrapped in a
    ``memory-extraction-pipeline`` span whose id is passed down as ``trace_id``.
    May raise; ``run_extraction`` is the non-raising entry point.
    """
    from app.core.memory_middleware import MemoryMiddleware  # noqa: PLC0415

    middleware = MemoryMiddleware(db)
    fernet = await middleware._get_fernet(user_id)
    if fernet is None:
        logger.warning("memory_extraction: no fernet for user=%s, skipping", user_id)
        return

    # 1. Load context
    core: dict[str, str] = await middleware._load_core(user_id, fernet)
    episodes: list[str] = await middleware._load_episodic(user_id, fernet, session_id=session_id)
    last_turn = f"User: {last_user_msg}\nAssistant: {last_assistant_msg}"
    lf = get_langfuse()

    async def _run(trace_id: str | None) -> dict[str, Any]:
        # 2. Extract candidates
        result = await extract_candidates(last_turn, core, episodes)
        if not result.candidates:
            logger.info("memory_extraction: no candidates user=%s", user_id)
            return {"candidates": 0, "applied": 0}

        logger.info(
            "memory_extraction: processing %d candidates user=%s trace=%s",
            len(result.candidates),
            user_id,
            trace_id or "-",
        )

        # 3. Apply each candidate
        applied = 0
        actions: list[str] = []
        for candidate in result.candidates:
            try:
                await _apply_candidate(middleware, db, user_id, fernet, candidate, trace_id)
                applied += 1
                actions.append(f"{candidate.type}:{candidate.target_tier}")
            except Exception as exc:
                # Zero-trust: log type/tier only, never the decrypted content.
                logger.warning(
                    "memory_extraction: apply failed type=%s tier=%s user=%s: %s",
                    candidate.type,
                    candidate.target_tier,
                    user_id,
                    exc,
                )

        logger.info(
            "memory_extraction: applied %d/%d candidates user=%s",
            applied,
            len(result.candidates),
            user_id,
        )
        return {"candidates": len(result.candidates), "applied": applied, "actions": actions}

    with langfuse_context(user_id=user_id, session_id=session_id):
        if lf:
            with lf.start_as_current_observation(
                as_type="span",
                name="memory-extraction-pipeline",
                input={"last_turn_preview": last_turn[:200]},
            ) as span:
                summary = await _run(trace_id=span.id)
                span.update(output=summary)
            try:
                lf.flush()
            except Exception as exc:
                # Tracing is best-effort; record the failure instead of hiding it.
                logger.debug("memory_extraction: langfuse flush failed: %s", exc)
        else:
            await _run(trace_id=None)
async def _apply_candidate(
    middleware: Any,
    db: AsyncSession,
    user_id: str,
    fernet: Any,
    candidate: MemoryCandidate,
    trace_id: str | None,
) -> None:
    """Fetch neighbours, decide an action and apply it to the candidate's tier.

    Tier handling:
      - relational: always upserted (no neighbour search, no decide call);
        skipped with a warning when subject/predicate/object is incomplete.
      - core: neighbours are the values of the first three core blocks;
        ADD/UPDATE call ``update_core``, DELETE calls ``delete_core``.
      - associative: neighbours from ``search_archival``; ADD/UPDATE store it.
      - proactive: neighbours from ``search_recall``; ADD/UPDATE store it.

    DELETE is only implemented for the core tier; elsewhere it is logged and
    skipped. Exceptions from the middleware propagate to the caller.
    """
    tier = candidate.target_tier

    if tier == "relational":
        # No neighbour search exists for relations, so the decision is always ADD;
        # skip the decide step entirely rather than calling it with no neighbours.
        logger.info(
            "memory_extraction: candidate type=%s tier=%s action=%s",
            candidate.type,
            tier,
            "ADD",
        )
        if candidate.subject and candidate.predicate and candidate.object:
            await _upsert_relation_stub(middleware, db, user_id, candidate, trace_id)
        else:
            logger.warning(
                "memory_extraction: relation candidate missing subject/predicate/object, skipped"
            )
        return

    neighbours: list[str] = []
    if tier == "core":
        blocks = await middleware.list_core_blocks(user_id)
        neighbours = [b["value"] for b in blocks[:3]]
    elif tier == "associative":
        neighbours = await middleware.search_archival(user_id, candidate.content, top_k=3)
    elif tier == "proactive":
        neighbours = await middleware.search_recall(user_id, candidate.content, top_k=3)

    action = await decide_action(candidate, neighbours)
    logger.info(
        "memory_extraction: candidate type=%s tier=%s action=%s",
        candidate.type,
        tier,
        action,
    )

    if action == "NOOP":
        return

    if action == "DELETE":
        if tier == "core":
            await middleware.delete_core(user_id, _content_to_key(candidate.content))
        else:
            # Previously dropped without a trace; make the gap visible in logs.
            logger.info("memory_extraction: DELETE not supported for tier=%s, skipped", tier)
        return

    # ADD / UPDATE
    if tier == "core":
        # The core key is derived from the content itself.
        key = _content_to_key(candidate.content)
        await middleware.update_core(user_id, key, candidate.content, trace_id=trace_id)
    elif tier == "associative":
        await middleware.store_associative(user_id, candidate.content)
    elif tier == "proactive":
        await _store_proactive_stub(middleware, db, user_id, candidate, fernet)
def _content_to_key(content: str) -> str:
    """Build a snake_case key from the first 40 characters of ``content``.

    Runs of ASCII letters and digits (after lower-casing) are joined with
    underscores; everything else acts as a separator. Falls back to
    ``"memory"`` when no such run exists.
    """
    import re  # noqa: PLC0415

    head = content[:40].lower()
    words = re.findall(r"[a-z0-9]+", head)
    return "_".join(words) or "memory"
async def _upsert_relation_stub(
    middleware: Any,
    db: AsyncSession,
    user_id: str,
    candidate: MemoryCandidate,
    trace_id: str | None,
) -> None:
    """Forward a relation candidate to ``middleware.upsert_relation`` when it exists.

    Phase 2 placeholder: if the middleware has no ``upsert_relation`` attribute
    the triple is only logged. Subject and object types are sent as
    ``"unknown"``. ``db`` and ``trace_id`` are accepted but not used here.
    """
    try:
        upsert = middleware.upsert_relation
    except AttributeError:
        logger.info(
            "memory_extraction: relation stub (Phase 3 not yet wired) subject=%s predicate=%s object=%s",
            candidate.subject,
            candidate.predicate,
            candidate.object,
        )
        return

    await upsert(
        user_id=user_id,
        subject=candidate.subject,
        subject_type="unknown",
        predicate=candidate.predicate,
        object_=candidate.object,
        object_type="unknown",
        confidence=candidate.confidence,
    )
async def _store_proactive_stub(
    middleware: Any,
    db: AsyncSession,
    user_id: str,
    candidate: MemoryCandidate,
    fernet: Any,
) -> None:
    """Insert one ``MemoryProactive`` row holding the encrypted candidate content.

    The row is committed on ``db`` immediately. A commit failure is logged and
    rolled back rather than raised. ``middleware`` is accepted but not used.
    """
    import uuid  # noqa: PLC0415
    from app.models import MemoryProactive  # noqa: PLC0415
    from app.core.memory_middleware import _encrypt  # noqa: PLC0415

    ciphertext = _encrypt(fernet, candidate.content)
    db.add(
        MemoryProactive(
            id=str(uuid.uuid4()),
            user_id=user_id,
            pattern_encrypted=ciphertext,
            confidence=candidate.confidence,
            source="inferred",
        )
    )
    try:
        await db.commit()
    except Exception as err:
        logger.warning("memory_extraction: store proactive failed: %s", err)
        await db.rollback()
    else:
        logger.info("memory_extraction: stored proactive pattern user=%s", user_id)

View File

@@ -18,6 +18,7 @@ Usage:
from __future__ import annotations
import asyncio
import logging
import uuid
from typing import Any
@@ -27,6 +28,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import (
ExtractionQueue,
MemoryAssociative,
MemoryCore,
MemoryEpisodic,
@@ -106,7 +108,10 @@ class MemoryMiddleware:
"""Summarise and store a completed interaction in episodic memory.
The summary is a simple heuristic concatenation (no LLM call) to keep
latency low. Full LLM summarisation can be added in a later step.
latency low. After committing the episode row, dispatches the Mem0-style
extraction pipeline:
- Pro/Power/Team → asyncio.create_task (fire-and-forget, realtime).
- Free → enqueue an ExtractionQueue row for the daily cron.
"""
fernet = await self._get_fernet(user_id)
if fernet is None:
@@ -115,26 +120,95 @@ class MemoryMiddleware:
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
encrypted = _encrypt(fernet, summary)
row = MemoryEpisodic(
episode = MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=user_id,
summary_encrypted=encrypted,
session_id=session_id,
)
self._db.add(row)
self._db.add(episode)
episode_id: str = episode.id
try:
await self._db.commit()
user_dbg = await self._get_user_debug(user_id)
tier = user_dbg.get("tier") or "free"
logger.info(
"memory: store_episode trace=%s user=%s tier=%s session=%s",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
tier,
session_id,
)
except Exception as exc:
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
await self._db.rollback()
return
# ── Dispatch extraction pipeline (Phase 2) ────────────────────────────
await self._dispatch_extraction(
user_id=user_id,
episode_id=episode_id,
last_user_msg=message,
last_assistant_msg=response,
session_id=session_id,
)
async def _dispatch_extraction(
    self,
    user_id: str,
    episode_id: str,
    last_user_msg: str,
    last_assistant_msg: str,
    session_id: str | None,
) -> None:
    """Route extraction to a realtime background task or the batch queue.

    Users whose tier has the ``realtime_extraction`` feature get a background
    task running ``run_extraction`` on a fresh DB session. Everyone else gets
    an ``ExtractionQueue`` row committed on the request session; a failed
    insert is logged and rolled back, not raised.
    """
    from app.billing.tier_manager import tier_manager  # noqa: PLC0415

    tier = await tier_manager.get_tier(user_id, self._db)
    if tier_manager.check_feature(tier, "realtime_extraction"):
        # Pro/Power/Team: fire-and-forget in the background.
        # Must open a fresh session — request session closes after handler returns.
        from app.core.memory_extraction import run_extraction  # noqa: PLC0415
        from app.db import async_session  # noqa: PLC0415

        async def _task() -> None:
            try:
                async with async_session() as fresh_db:
                    await run_extraction(
                        db=fresh_db,
                        user_id=user_id,
                        last_user_msg=last_user_msg,
                        last_assistant_msg=last_assistant_msg,
                        session_id=session_id,
                    )
            except Exception as exc:
                logger.warning(
                    "memory: extraction task failed user=%s: %s", user_id, exc
                )

        # The event loop only keeps weak references to tasks. Hold a strong
        # reference on the class until the task finishes so it cannot be
        # garbage-collected mid-run.
        owner = type(self)
        pending = getattr(owner, "_extraction_tasks", None)
        if pending is None:
            pending = set()
            owner._extraction_tasks = pending
        task = asyncio.create_task(_task())
        pending.add(task)
        task.add_done_callback(pending.discard)
        logger.info("memory: realtime extraction dispatched user=%s", user_id)
    else:
        # Free tier: enqueue for daily batch cron.
        queue_row = ExtractionQueue(
            id=str(uuid.uuid4()),
            user_id=user_id,
            episode_id=episode_id,
        )
        self._db.add(queue_row)
        try:
            await self._db.commit()
            logger.info(
                "memory: extraction enqueued (batch) user=%s episode=%s",
                user_id,
                episode_id,
            )
        except Exception as exc:
            logger.warning(
                "memory: extraction queue insert failed user=%s: %s", user_id, exc
            )
            await self._db.rollback()
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
"""Upsert a core memory key/value for a user."""

View File

@@ -351,6 +351,28 @@ class MemoryProactive(Base):
)
class ExtractionQueue(Base):
    """Pending-extraction queue row for Free-tier users (Phase 2).

    Pro/Power/Team users are handled in realtime via ``asyncio.create_task()``;
    Free users get one row here per stored episode, to be drained by a daily
    cron (Phase 5).
    """

    __tablename__ = "extraction_queue"

    id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        primary_key=True,
        default=_uuid,
    )
    # Indexed owner reference; rows are deleted together with the user.
    user_id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Optional episode id; no foreign key is declared on this column.
    episode_id: Mapped[str | None] = mapped_column(
        Uuid(as_uuid=False),
        nullable=True,
    )
    # Filled in by the database at insert time.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
class Plugin(Base):
"""Plugin marketplace catalog entry."""

View File

@@ -0,0 +1,345 @@
"""Tests for Phase 2 — Mem0-style Extract/Update pipeline.
Coverage:
2.1 extract_candidates returns valid ExtractionResult with mocked LLM.
2.2 decide_action — all 4 branches (ADD/UPDATE/DELETE/NOOP + empty existing).
2.3 run_extraction end-to-end with mocked LLM writes expected rows.
2.4 _dispatch_extraction — Pro user triggers realtime task; Free enqueues row.
"""
from __future__ import annotations
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select
from app.core.memory_extraction import (
ExtractionResult,
MemoryCandidate,
decide_action,
extract_candidates,
run_extraction,
)
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session
from app.main import app
from app.models import ExtractionQueue, MemoryCore, User
from tests.conftest import TEST_USER_IDS
PRO_USER_ID = TEST_USER_IDS["pro"]
FREE_USER_ID = TEST_USER_IDS["free"]
_FERNET_KEY = Fernet.generate_key().decode()
# ── DB override ───────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Point the app's ``get_session`` dependency at the test session for each test."""

    async def _session_override():
        yield db_session

    overrides = app.dependency_overrides
    overrides[get_session] = _session_override
    yield
    # Remove the override again so other test modules are unaffected.
    overrides.pop(get_session, None)
# ── Helpers ───────────────────────────────────────────────────────────────────
@pytest_asyncio.fixture
async def pro_user(db_session):
    """Give the seeded pro user the module's test encryption key and return it."""
    query = select(User).where(User.id == PRO_USER_ID)
    seeded = (await db_session.execute(query)).scalar_one()
    seeded.encryption_key = _FERNET_KEY
    await db_session.commit()
    return seeded
@pytest_asyncio.fixture
async def free_user(db_session):
    """Give the seeded free user the module's test encryption key and return it."""
    query = select(User).where(User.id == FREE_USER_ID)
    seeded = (await db_session.execute(query)).scalar_one()
    seeded.encryption_key = _FERNET_KEY
    await db_session.commit()
    return seeded
def _make_llm_response(content: str) -> MagicMock:
    """Return a mock LLM message carrying ``content`` and fixed token-usage figures."""
    return MagicMock(
        content=content,
        usage_metadata={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
    )
# ── TASK 2.1 — extract_candidates ────────────────────────────────────────────
@pytest.mark.asyncio
async def test_extract_candidates_returns_valid_result():
    """A well-formed JSON reply is parsed into a one-candidate ExtractionResult."""
    candidate = {
        "type": "fact",
        "content": "User's CFO is Giulia",
        "target_tier": "core",
        "subject": None,
        "predicate": None,
        "object": None,
        "confidence": 0.85,
    }
    reply = _make_llm_response(json.dumps({"candidates": [candidate]}))

    fake_llm = MagicMock()
    fake_llm.bind.return_value = fake_llm  # JSON-mode binding returns the same mock
    fake_llm.ainvoke = AsyncMock(return_value=reply)
    prompt = ("system prompt {last_turn} {core_memory} {recent_episodes}", None)

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=fake_llm),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch("app.core.memory_extraction.get_prompt_or_fallback", return_value=prompt),
    ):
        result = await extract_candidates(
            last_turn="User: My CFO is Giulia\nAssistant: Noted.",
            core_memory={},
            recent_episodes=[],
        )

    assert isinstance(result, ExtractionResult)
    assert len(result.candidates) == 1
    only = result.candidates[0]
    assert only.type == "fact"
    assert "Giulia" in only.content
    assert only.confidence == 0.85
@pytest.mark.asyncio
async def test_extract_candidates_returns_empty_on_llm_failure():
    """An LLM error is swallowed and yields an empty ExtractionResult."""
    fake_llm = MagicMock()
    fake_llm.bind.return_value = fake_llm
    fake_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
    prompt = ("prompt {last_turn} {core_memory} {recent_episodes}", None)

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=fake_llm),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch("app.core.memory_extraction.get_prompt_or_fallback", return_value=prompt),
    ):
        result = await extract_candidates("turn", {}, [])

    assert isinstance(result, ExtractionResult)
    assert result.candidates == []
# ── TASK 2.2 — decide_action ─────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_decide_action_add_when_no_existing():
    """With no existing memories the decision is ADD (nothing is patched here)."""
    new_fact = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
    assert await decide_action(new_fact, existing=[]) == "ADD"
@pytest.mark.asyncio
async def test_decide_action_noop():
    """An LLM reply of NOOP is returned unchanged."""
    new_fact = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
    fake_llm = MagicMock()
    fake_llm.ainvoke = AsyncMock(return_value=_make_llm_response("NOOP"))
    prompt = ("p {candidate} {existing_memories}", None)

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=fake_llm),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch("app.core.memory_extraction.get_prompt_or_fallback", return_value=prompt),
    ):
        decision = await decide_action(new_fact, existing=["CFO is Giulia"])

    assert decision == "NOOP"
@pytest.mark.asyncio
async def test_decide_action_update():
    """An LLM reply of UPDATE is returned unchanged."""
    new_fact = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
    fake_llm = MagicMock()
    fake_llm.ainvoke = AsyncMock(return_value=_make_llm_response("UPDATE"))
    prompt = ("p {candidate} {existing_memories}", None)

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=fake_llm),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch("app.core.memory_extraction.get_prompt_or_fallback", return_value=prompt),
    ):
        decision = await decide_action(new_fact, existing=["CFO is Giulia"])

    assert decision == "UPDATE"
@pytest.mark.asyncio
async def test_decide_action_delete():
    """An LLM reply of DELETE is returned unchanged."""
    new_fact = MemoryCandidate(type="fact", content="No longer have a CFO", target_tier="core")
    fake_llm = MagicMock()
    fake_llm.ainvoke = AsyncMock(return_value=_make_llm_response("DELETE"))
    prompt = ("p {candidate} {existing_memories}", None)

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=fake_llm),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch("app.core.memory_extraction.get_prompt_or_fallback", return_value=prompt),
    ):
        decision = await decide_action(new_fact, existing=["CFO is Giulia"])

    assert decision == "DELETE"
@pytest.mark.asyncio
async def test_decide_action_defaults_add_on_llm_failure():
    """An LLM error makes decide_action fall back to ADD instead of raising."""
    new_fact = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
    fake_llm = MagicMock()
    fake_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
    prompt = ("p {candidate} {existing_memories}", None)

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=fake_llm),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch("app.core.memory_extraction.get_prompt_or_fallback", return_value=prompt),
    ):
        decision = await decide_action(new_fact, existing=["old memory"])

    assert decision == "ADD"
# ── TASK 2.3 — run_extraction end-to-end ─────────────────────────────────────
@pytest.mark.asyncio
async def test_run_extraction_writes_core_candidate(db_session, pro_user):
    """A mocked core-tier fact candidate ends up as an encrypted MemoryCore row."""
    fact_payload = {
        "candidates": [
            {
                "type": "fact",
                "content": "User prefers morning meetings",
                "target_tier": "core",
                "confidence": 0.8,
            }
        ]
    }

    # The first LLM call (extraction) gets the JSON payload; any later call
    # (a decide step, if one happens) gets "ADD".
    scripted = iter([json.dumps(fact_payload)])

    async def _reply(messages):
        message = MagicMock()
        message.content = next(scripted, "ADD")
        message.usage_metadata = {}
        return message

    def _prompt_for(name, fb):
        if name == "memory_extraction":
            return ("p {last_turn} {core_memory} {recent_episodes}", None)
        return ("p {candidate} {existing_memories}", None)

    fake_llm = MagicMock()
    fake_llm.bind.return_value = fake_llm
    fake_llm.ainvoke = AsyncMock(side_effect=_reply)

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=fake_llm),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch("app.core.memory_extraction.get_prompt_or_fallback", side_effect=_prompt_for),
    ):
        await run_extraction(
            db=db_session,
            user_id=PRO_USER_ID,
            last_user_msg="My CFO is Giulia",
            last_assistant_msg="Noted, I will remember that.",
            session_id="test-session",
        )

    # core row should exist
    stored = (
        await db_session.execute(select(MemoryCore).where(MemoryCore.user_id == PRO_USER_ID))
    ).scalars().all()
    assert len(stored) >= 1

    cipher = Fernet(_FERNET_KEY.encode())
    plaintexts = [cipher.decrypt(row.value_encrypted.encode()).decode() for row in stored]
    assert any("morning meetings" in text for text in plaintexts)
# ── TASK 2.4 — dispatch ───────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_dispatch_realtime_for_pro(db_session, pro_user):
    """Pro user: asyncio.create_task called (not queue row)."""
    middleware = MemoryMiddleware(db_session)
    with (
        patch("app.core.memory_middleware.asyncio.create_task") as mock_task,
        patch("app.billing.tier_manager.tier_manager.check_feature", return_value=True),
    ):
        await middleware._dispatch_extraction(
            user_id=PRO_USER_ID,
            episode_id=str(uuid.uuid4()),
            last_user_msg="hello",
            last_assistant_msg="hi",
            session_id=None,
        )

    mock_task.assert_called_once()
    # create_task is mocked, so the coroutine handed to it is never scheduled.
    # Close it explicitly to avoid a "coroutine was never awaited" RuntimeWarning.
    mock_task.call_args.args[0].close()
@pytest.mark.asyncio
async def test_dispatch_queue_for_free(db_session, free_user):
    """Without the realtime feature, exactly one ExtractionQueue row is written."""
    episode = str(uuid.uuid4())
    dispatcher = MemoryMiddleware(db_session)

    with patch("app.billing.tier_manager.tier_manager.check_feature", return_value=False):
        await dispatcher._dispatch_extraction(
            user_id=FREE_USER_ID,
            episode_id=episode,
            last_user_msg="hello",
            last_assistant_msg="hi",
            session_id=None,
        )

    query = select(ExtractionQueue).where(ExtractionQueue.user_id == FREE_USER_ID)
    queued = (await db_session.execute(query)).scalars().all()
    assert len(queued) == 1
    assert queued[0].episode_id == episode