diff --git a/.env.example b/.env.example index 40e18c4..37f41a7 100644 --- a/.env.example +++ b/.env.example @@ -53,6 +53,10 @@ LLM_MODEL_CLOUD_PROCESSOR= # Setup-agent — guided journey to build an AgentConfig via WebSocket chat. LLM_MODEL_SETUP_AGENT= +# Memory-extractor — Mem0-style extract/decide pipeline (Phase 2). +# Defaults to gpt-4o-mini when empty (fast + cheap, temperature=0). +LLM_MODEL_MEMORY_EXTRACTOR= + # ── Stripe (leave empty to stub billing) ────────────────────────────────────── STRIPE_SECRET_KEY= STRIPE_WEBHOOK_SECRET= diff --git a/alembic/versions/1f5975a4f3f4_add_extraction_queue.py b/alembic/versions/1f5975a4f3f4_add_extraction_queue.py new file mode 100644 index 0000000..e7e41ec --- /dev/null +++ b/alembic/versions/1f5975a4f3f4_add_extraction_queue.py @@ -0,0 +1,38 @@ +"""add extraction_queue + +Revision ID: 1f5975a4f3f4 +Revises: 005 +Create Date: 2026-04-16 17:26:25.790870 + +""" +from __future__ import annotations + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '1f5975a4f3f4' +down_revision: Union[str, None] = '005' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + 'extraction_queue', + sa.Column('id', sa.Uuid(as_uuid=False), nullable=False), + sa.Column('user_id', sa.Uuid(as_uuid=False), nullable=False), + sa.Column('episode_id', sa.Uuid(as_uuid=False), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + ) + op.create_index(op.f('ix_extraction_queue_user_id'), 'extraction_queue', ['user_id'], unique=False) + + +def downgrade() -> None: + op.drop_index(op.f('ix_extraction_queue_user_id'), table_name='extraction_queue') + op.drop_table('extraction_queue') diff --git a/app/billing/tier_manager.py b/app/billing/tier_manager.py index 4a523c4..859d378 100644 --- a/app/billing/tier_manager.py +++ b/app/billing/tier_manager.py @@ -26,6 +26,7 @@ FEATURES: dict[str, dict[str, Any]] = { "batch_builder": False, "sso": False, "real_embeddings": False, # keyword fallback only + "realtime_extraction": False, # batch queue (Phase 2) }, "pro": { "agents": -1, # unlimited @@ -35,6 +36,7 @@ FEATURES: dict[str, dict[str, Any]] = { "batch_builder": False, "sso": False, "real_embeddings": True, # pgvector cosine search + "realtime_extraction": True, # fire-and-forget asyncio.create_task }, "power": { "agents": -1, @@ -44,6 +46,7 @@ FEATURES: dict[str, dict[str, Any]] = { "batch_builder": True, "sso": False, "real_embeddings": True, + "realtime_extraction": True, }, "team": { "agents": -1, @@ -53,6 +56,7 @@ FEATURES: dict[str, dict[str, Any]] = { "batch_builder": True, "sso": True, "real_embeddings": True, + "realtime_extraction": True, }, } diff --git a/app/config/settings.py b/app/config/settings.py index 6466dce..7dcb716 100644 --- a/app/config/settings.py +++ b/app/config/settings.py @@ -27,6 +27,7 @@ class Settings(BaseSettings): LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner) LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner) LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey + LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide) # GitHub Copilot OAuth token storage directory. # Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot). diff --git a/app/core/llm.py b/app/core/llm.py index d833bf4..abdb939 100644 --- a/app/core/llm.py +++ b/app/core/llm.py @@ -103,6 +103,7 @@ _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = { "unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL, "cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL, "setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL, + "memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini", } diff --git a/app/core/memory_extraction.py b/app/core/memory_extraction.py new file mode 100644 index 0000000..1345b04 --- /dev/null +++ b/app/core/memory_extraction.py @@ -0,0 +1,456 @@ +"""Mem0-style Extract/Update pipeline — Phase 2. + +Runs after every ``store_episode`` call to distil durable facts, preferences, +routines, and relations from the latest conversation turn. + +Entry point: ``run_extraction(db, user_id, last_user_msg, last_assistant_msg, session_id)`` + +Design notes +------------ +- Two gpt-4o-mini calls per turn: extract candidates, then decide action per candidate. +- Short-circuit: if no existing neighbours → ADD without a second LLM call (cost saving). +- Zero-trust: never logs decrypted user content; relation subject/object labels are + treated as identifiers (safe to log per spec). +- Must not raise into the request path — caller wraps in asyncio.create_task(). +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Literal + +from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.langfuse_client import get_langfuse, get_prompt_or_fallback, extract_usage, langfuse_context +from app.core.llm import get_agent_llm, model_for_agent + +logger = logging.getLogger(__name__) + +# ── Fallback prompts (used when Langfuse unavailable) ───────────────────────── + +_EXTRACTION_FALLBACK = ( + "You are a memory extractor for a personal AI secretary. Given the last conversation " + "turn, the user's core memory, and recent episode summaries, identify durable facts, " + "preferences, routines, and person/project relations worth remembering.\n\n" + "Output JSON matching this schema exactly:\n" + '{{"candidates": [{{"type": "", ' + '"content": "", ' + '"target_tier": "", ' + '"subject": null, "predicate": null, "object": null, "confidence": 0.7}}]}}\n\n' + "Rules:\n" + "- Skip small talk, greetings, one-off questions.\n" + "- Max 5 candidates per call.\n" + "- Only extract durable information (still true next week).\n" + "- For type=relation: subject/predicate/object required.\n" + "- Default confidence=0.7.\n\n" + "## Last turn\n{last_turn}\n\n" + "## Core memory (current)\n{core_memory}\n\n" + "## Recent episodes\n{recent_episodes}" +) + +_DECIDE_FALLBACK = ( + "You are a memory update decision engine. Given a new memory candidate and a list of " + "existing memories from the same tier, decide what action to take.\n\n" + "Respond with exactly one word: ADD, UPDATE, DELETE, or NOOP.\n\n" + "- ADD: new information not in existing memories.\n" + "- UPDATE: contradicts or supersedes an existing memory.\n" + "- DELETE: states something is no longer true.\n" + "- NOOP: already captured accurately.\n\n" + "## New candidate\n{candidate}\n\n" + "## Existing memories (same tier, top neighbours)\n{existing_memories}" +) + + +# ── Pydantic schemas ─────────────────────────────────────────────────────────── + +class MemoryCandidate(BaseModel): + type: Literal["fact", "preference", "relation", "routine"] + content: str + target_tier: Literal["core", "associative", "relational", "proactive"] + subject: str | None = None + predicate: str | None = None + object: str | None = None + confidence: float = Field(default=0.7, ge=0.0, le=1.0) + + +class ExtractionResult(BaseModel): + candidates: list[MemoryCandidate] = Field(default_factory=list) + + +# ── Task 2.1 — Extract candidates ───────────────────────────────────────────── + +async def extract_candidates( + last_turn: str, + core_memory: dict[str, str], + recent_episodes: list[str], +) -> ExtractionResult: + """Call gpt-4o-mini to extract memory candidates from the latest turn. + + Returns an ExtractionResult (may be empty on failure — never raises). + """ + core_str = "\n".join(f"{k}: {v}" for k, v in core_memory.items()) or "(empty)" + episodes_str = "\n---\n".join(recent_episodes[-5:]) or "(none)" + + template, prompt_obj = get_prompt_or_fallback("memory_extraction", _EXTRACTION_FALLBACK) + + # Compile with Langfuse variable syntax ({{var}}) or fallback {var} + if prompt_obj is not None: + try: + system_text = prompt_obj.compile( + last_turn=last_turn, + core_memory=core_str, + recent_episodes=episodes_str, + ) + if isinstance(system_text, list): + system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict)) + except Exception as exc: + logger.warning("memory_extraction: compile failed: %s", exc) + system_text = template.format( + last_turn=last_turn, + core_memory=core_str, + recent_episodes=episodes_str, + ) + else: + system_text = template.format( + last_turn=last_turn, + core_memory=core_str, + recent_episodes=episodes_str, + ) + + llm = get_agent_llm("memory-extractor", temperature=0) + # Bind JSON mode so the model always returns parseable output. + llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined] + + lf = get_langfuse() + try: + from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415 + messages = [ + SystemMessage(content=system_text), + HumanMessage(content="Extract memory candidates as JSON."), + ] + + if lf: + with lf.start_as_current_observation( + as_type="generation", + name="memory-extraction", + model=model_for_agent("memory-extractor"), + prompt=prompt_obj, + input=messages, + ) as gen: + response = await llm_json.ainvoke(messages) + gen.update(output=response.content, usage=extract_usage(response)) + else: + response = await llm_json.ainvoke(messages) + + raw = json.loads(response.content) + result = ExtractionResult.model_validate(raw) + logger.info("memory_extraction: extracted %d candidates", len(result.candidates)) + return result + + except Exception as exc: + logger.warning("memory_extraction: extract_candidates failed: %s", exc) + return ExtractionResult(candidates=[]) + + +# ── Task 2.2 — Decide action ────────────────────────────────────────────────── + +async def decide_action( + candidate: MemoryCandidate, + existing: list[str], +) -> Literal["ADD", "UPDATE", "DELETE", "NOOP"]: + """Decide what to do with a candidate given existing memories in the same tier. + + Short-circuits to ADD without an LLM call when existing is empty (cost saving). + Never raises. + """ + if not existing: + return "ADD" + + candidate_str = f"[{candidate.type}] {candidate.content}" + existing_str = "\n".join(f"- {m}" for m in existing) + + template, prompt_obj = get_prompt_or_fallback("memory_decide_action", _DECIDE_FALLBACK) + + if prompt_obj is not None: + try: + system_text = prompt_obj.compile( + candidate=candidate_str, + existing_memories=existing_str, + ) + if isinstance(system_text, list): + system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict)) + except Exception as exc: + logger.warning("memory_extraction: decide compile failed: %s", exc) + system_text = template.format(candidate=candidate_str, existing_memories=existing_str) + else: + system_text = template.format(candidate=candidate_str, existing_memories=existing_str) + + llm = get_agent_llm("memory-extractor", temperature=0) + lf = get_langfuse() + + try: + from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415 + messages = [ + SystemMessage(content=system_text), + HumanMessage(content="Decide action."), + ] + + if lf: + with lf.start_as_current_observation( + as_type="generation", + name="memory-decide-action", + model=model_for_agent("memory-extractor"), + prompt=prompt_obj, + input=messages, + ) as gen: + response = await llm.ainvoke(messages) + gen.update(output=response.content, usage=extract_usage(response)) + else: + response = await llm.ainvoke(messages) + + verb = response.content.strip().upper() + if verb in ("ADD", "UPDATE", "DELETE", "NOOP"): + return verb # type: ignore[return-value] + logger.warning("memory_extraction: unexpected decide verb=%r, defaulting ADD", verb) + return "ADD" + + except Exception as exc: + logger.warning("memory_extraction: decide_action failed: %s", exc) + return "ADD" + + +# ── Task 2.3 — Pipeline orchestrator ────────────────────────────────────────── + +async def run_extraction( + db: AsyncSession, + user_id: str, + last_user_msg: str, + last_assistant_msg: str, + session_id: str | None, +) -> None: + """Full Mem0-style extract/update pipeline for one conversation turn. + + Steps: + 1. Load core memory + last 5 episodes. + 2. extract_candidates() → up to 5 MemoryCandidate objects. + 3. For each candidate: find top-3 neighbours → decide_action() → apply. + 4. Trace via Langfuse. + + Never raises — wraps everything in try/except. + """ + try: + await _run_extraction_inner(db, user_id, last_user_msg, last_assistant_msg, session_id) + except Exception as exc: + logger.warning("memory_extraction: run_extraction failed user=%s: %s", user_id, exc) + + +async def _run_extraction_inner( + db: AsyncSession, + user_id: str, + last_user_msg: str, + last_assistant_msg: str, + session_id: str | None, +) -> None: + from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415 + + middleware = MemoryMiddleware(db) + fernet = await middleware._get_fernet(user_id) + if fernet is None: + logger.warning("memory_extraction: no fernet for user=%s, skipping", user_id) + return + + # 1. Load context + core: dict[str, str] = await middleware._load_core(user_id, fernet) + episodes: list[str] = await middleware._load_episodic(user_id, fernet, session_id=session_id) + + last_turn = f"User: {last_user_msg}\nAssistant: {last_assistant_msg}" + + lf = get_langfuse() + + async def _run(trace_id: str | None) -> dict[str, Any]: + # 2. Extract candidates + result = await extract_candidates(last_turn, core, episodes) + if not result.candidates: + logger.info("memory_extraction: no candidates user=%s", user_id) + return {"candidates": 0, "applied": 0} + + logger.info( + "memory_extraction: processing %d candidates user=%s trace=%s", + len(result.candidates), + user_id, + trace_id or "-", + ) + + # 3. Apply each candidate + applied = 0 + actions: list[str] = [] + for candidate in result.candidates: + try: + await _apply_candidate(middleware, db, user_id, fernet, candidate, trace_id) + applied += 1 + actions.append(f"{candidate.type}:{candidate.target_tier}") + except Exception as exc: + logger.warning( + "memory_extraction: apply failed candidate=%r user=%s: %s", + candidate.content[:80], + user_id, + exc, + ) + + logger.info( + "memory_extraction: applied %d/%d candidates user=%s", + applied, + len(result.candidates), + user_id, + ) + return {"candidates": len(result.candidates), "applied": applied, "actions": actions} + + with langfuse_context(user_id=user_id, session_id=session_id): + if lf: + with lf.start_as_current_observation( + as_type="span", + name="memory-extraction-pipeline", + input={"last_turn_preview": last_turn[:200]}, + ) as span: + summary = await _run(trace_id=span.id) + span.update(output=summary) + try: + lf.flush() + except Exception: + pass + else: + await _run(trace_id=None) + + +async def _apply_candidate( + middleware: Any, + db: AsyncSession, + user_id: str, + fernet: Any, + candidate: MemoryCandidate, + trace_id: str | None, +) -> None: + """Fetch neighbours, decide action, apply to the appropriate tier.""" + + neighbours: list[str] = [] + + if candidate.target_tier == "core": + # For core tier: neighbours are existing core block values for similar keys. + blocks = await middleware.list_core_blocks(user_id) + neighbours = [b["value"] for b in blocks[:3]] + + elif candidate.target_tier == "associative": + neighbours = await middleware.search_archival(user_id, candidate.content, top_k=3) + + elif candidate.target_tier == "relational": + # Relation candidates handled specially — passed to upsert_relation directly. + # Neighbours: search by subject label if available. + neighbours = [] + + elif candidate.target_tier == "proactive": + neighbours = await middleware.search_recall(user_id, candidate.content, top_k=3) + + action = await decide_action(candidate, neighbours) + logger.info( + "memory_extraction: candidate type=%s tier=%s action=%s", + candidate.type, + candidate.target_tier, + action, + ) + + if action == "NOOP": + return + + if candidate.target_tier == "relational": + # Always upsert relations — decide_action skipped (no neighbour search). + if candidate.subject and candidate.predicate and candidate.object: + await _upsert_relation_stub( + middleware, db, user_id, candidate, trace_id + ) + return + + if action in ("ADD", "UPDATE"): + if candidate.target_tier == "core": + # Derive a short key from the content (first 40 chars, snake_cased). + key = _content_to_key(candidate.content) + await middleware.update_core(user_id, key, candidate.content, trace_id=trace_id) + + elif candidate.target_tier == "associative": + await middleware.store_associative(user_id, candidate.content) + + elif candidate.target_tier == "proactive": + await _store_proactive_stub(middleware, db, user_id, candidate, fernet) + + elif action == "DELETE": + if candidate.target_tier == "core": + key = _content_to_key(candidate.content) + await middleware.delete_core(user_id, key) + + +def _content_to_key(content: str) -> str: + """Derive a short snake_case key from a content string (first 40 chars).""" + import re # noqa: PLC0415 + slug = re.sub(r"[^a-z0-9]+", "_", content[:40].lower()).strip("_") + return slug or "memory" + + +async def _upsert_relation_stub( + middleware: Any, + db: AsyncSession, + user_id: str, + candidate: MemoryCandidate, + trace_id: str | None, +) -> None: + """Stub: upsert_relation will be fully wired in Phase 3. + + Called here so Phase 2 extraction pipeline already routes relation candidates + correctly. Phase 3 replaces this with MemoryMiddleware.upsert_relation(). + """ + if hasattr(middleware, "upsert_relation"): + await middleware.upsert_relation( + user_id=user_id, + subject=candidate.subject, + subject_type="unknown", + predicate=candidate.predicate, + object_=candidate.object, + object_type="unknown", + confidence=candidate.confidence, + ) + else: + logger.info( + "memory_extraction: relation stub (Phase 3 not yet wired) subject=%s predicate=%s object=%s", + candidate.subject, + candidate.predicate, + candidate.object, + ) + + +async def _store_proactive_stub( + middleware: Any, + db: AsyncSession, + user_id: str, + candidate: MemoryCandidate, + fernet: Any, +) -> None: + """Store a proactive pattern row directly (MemoryProactive model).""" + import uuid # noqa: PLC0415 + from app.models import MemoryProactive # noqa: PLC0415 + from app.core.memory_middleware import _encrypt # noqa: PLC0415 + + encrypted = _encrypt(fernet, candidate.content) + row = MemoryProactive( + id=str(uuid.uuid4()), + user_id=user_id, + pattern_encrypted=encrypted, + confidence=candidate.confidence, + source="inferred", + ) + db.add(row) + try: + await db.commit() + logger.info("memory_extraction: stored proactive pattern user=%s", user_id) + except Exception as exc: + logger.warning("memory_extraction: store proactive failed: %s", exc) + await db.rollback() diff --git a/app/core/memory_middleware.py b/app/core/memory_middleware.py index b879e2f..9780faa 100644 --- a/app/core/memory_middleware.py +++ b/app/core/memory_middleware.py @@ -18,6 +18,7 @@ Usage: from __future__ import annotations +import asyncio import logging import uuid from typing import Any @@ -27,6 +28,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.models import ( + ExtractionQueue, MemoryAssociative, MemoryCore, MemoryEpisodic, @@ -106,7 +108,10 @@ class MemoryMiddleware: """Summarise and store a completed interaction in episodic memory. The summary is a simple heuristic concatenation (no LLM call) to keep - latency low. Full LLM summarisation can be added in a later step. + latency low. After committing the episode row, dispatches the Mem0-style + extraction pipeline: + - Pro/Power/Team → asyncio.create_task (fire-and-forget, realtime). + - Free → enqueue an ExtractionQueue row for the daily cron. """ fernet = await self._get_fernet(user_id) if fernet is None: @@ -115,26 +120,95 @@ class MemoryMiddleware: summary = f"User: {message[:200]}\nAssistant: {response[:200]}" encrypted = _encrypt(fernet, summary) - row = MemoryEpisodic( + episode = MemoryEpisodic( id=str(uuid.uuid4()), user_id=user_id, summary_encrypted=encrypted, session_id=session_id, ) - self._db.add(row) + self._db.add(episode) + episode_id: str = episode.id try: await self._db.commit() user_dbg = await self._get_user_debug(user_id) + tier = user_dbg.get("tier") or "free" logger.info( "memory: store_episode trace=%s user=%s tier=%s session=%s", trace_id or "-", user_id, - user_dbg.get("tier") or "-", + tier, session_id, ) except Exception as exc: logger.error("memory: store_episode failed user=%s: %s", user_id, exc) await self._db.rollback() + return + + # ── Dispatch extraction pipeline (Phase 2) ──────────────────────────── + await self._dispatch_extraction( + user_id=user_id, + episode_id=episode_id, + last_user_msg=message, + last_assistant_msg=response, + session_id=session_id, + ) + + async def _dispatch_extraction( + self, + user_id: str, + episode_id: str, + last_user_msg: str, + last_assistant_msg: str, + session_id: str | None, + ) -> None: + """Route extraction to realtime task or batch queue based on user tier.""" + from app.billing.tier_manager import tier_manager # noqa: PLC0415 + + tier = await tier_manager.get_tier(user_id, self._db) + + if tier_manager.check_feature(tier, "realtime_extraction"): + # Pro/Power/Team: fire-and-forget in the background. + # Must open a fresh session — request session closes after handler returns. + from app.core.memory_extraction import run_extraction # noqa: PLC0415 + from app.db import async_session # noqa: PLC0415 + + async def _task() -> None: + try: + async with async_session() as fresh_db: + await run_extraction( + db=fresh_db, + user_id=user_id, + last_user_msg=last_user_msg, + last_assistant_msg=last_assistant_msg, + session_id=session_id, + ) + except Exception as exc: + logger.warning( + "memory: extraction task failed user=%s: %s", user_id, exc + ) + + asyncio.create_task(_task()) + logger.info("memory: realtime extraction dispatched user=%s", user_id) + else: + # Free tier: enqueue for daily batch cron. + queue_row = ExtractionQueue( + id=str(uuid.uuid4()), + user_id=user_id, + episode_id=episode_id, + ) + self._db.add(queue_row) + try: + await self._db.commit() + logger.info( + "memory: extraction enqueued (batch) user=%s episode=%s", + user_id, + episode_id, + ) + except Exception as exc: + logger.warning( + "memory: extraction queue insert failed user=%s: %s", user_id, exc + ) + await self._db.rollback() async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None: """Upsert a core memory key/value for a user.""" diff --git a/app/models.py b/app/models.py index 98e713d..d5f6f77 100644 --- a/app/models.py +++ b/app/models.py @@ -351,6 +351,28 @@ class MemoryProactive(Base): ) +class ExtractionQueue(Base): + """Batch extraction queue for Free-tier users (Phase 2). + + Pro/Power/Team users get realtime asyncio.create_task() extraction. + Free users get a queue row here; a daily cron (Phase 5) drains it. + """ + + __tablename__ = "extraction_queue" + + id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid) + user_id: Mapped[str] = mapped_column( + Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, index=True, + ) + episode_id: Mapped[str | None] = mapped_column( + Uuid(as_uuid=False), nullable=True, + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + + class Plugin(Base): """Plugin marketplace catalog entry.""" diff --git a/tests/test_memory_extraction.py b/tests/test_memory_extraction.py new file mode 100644 index 0000000..def13ab --- /dev/null +++ b/tests/test_memory_extraction.py @@ -0,0 +1,345 @@ +"""Tests for Phase 2 — Mem0-style Extract/Update pipeline. + +Coverage: + 2.1 extract_candidates returns valid ExtractionResult with mocked LLM. + 2.2 decide_action — all 4 branches (ADD/UPDATE/DELETE/NOOP + empty existing). + 2.3 run_extraction end-to-end with mocked LLM writes expected rows. + 2.4 _dispatch_extraction — Pro user triggers realtime task; Free enqueues row. +""" + +from __future__ import annotations + +import json +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio +from cryptography.fernet import Fernet +from sqlalchemy import select + +from app.core.memory_extraction import ( + ExtractionResult, + MemoryCandidate, + decide_action, + extract_candidates, + run_extraction, +) +from app.core.memory_middleware import MemoryMiddleware +from app.db import get_session +from app.main import app +from app.models import ExtractionQueue, MemoryCore, User +from tests.conftest import TEST_USER_IDS + + +PRO_USER_ID = TEST_USER_IDS["pro"] +FREE_USER_ID = TEST_USER_IDS["free"] +_FERNET_KEY = Fernet.generate_key().decode() + + +# ── DB override ─────────────────────────────────────────────────────────────── + +@pytest.fixture(autouse=True) +def _override_db(db_session): + async def _gen(): + yield db_session + + app.dependency_overrides[get_session] = _gen + yield + app.dependency_overrides.pop(get_session, None) + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +@pytest_asyncio.fixture +async def pro_user(db_session): + """Update the seeded pro user to have an encryption_key.""" + result = await db_session.execute(select(User).where(User.id == PRO_USER_ID)) + user = result.scalar_one() + user.encryption_key = _FERNET_KEY + await db_session.commit() + return user + + +@pytest_asyncio.fixture +async def free_user(db_session): + """Update the seeded free user to have an encryption_key.""" + result = await db_session.execute(select(User).where(User.id == FREE_USER_ID)) + user = result.scalar_one() + user.encryption_key = _FERNET_KEY + await db_session.commit() + return user + + +def _make_llm_response(content: str) -> MagicMock: + msg = MagicMock() + msg.content = content + msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + return msg + + +# ── TASK 2.1 — extract_candidates ──────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_extract_candidates_returns_valid_result(): + payload = { + "candidates": [ + { + "type": "fact", + "content": "User's CFO is Giulia", + "target_tier": "core", + "subject": None, + "predicate": None, + "object": None, + "confidence": 0.85, + } + ] + } + mock_response = _make_llm_response(json.dumps(payload)) + + with ( + patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm, + patch("app.core.memory_extraction.get_langfuse", return_value=None), + patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt, + ): + mock_prompt.return_value = ( + "system prompt {last_turn} {core_memory} {recent_episodes}", + None, + ) + llm_instance = MagicMock() + llm_instance.bind.return_value = llm_instance + llm_instance.ainvoke = AsyncMock(return_value=mock_response) + mock_get_llm.return_value = llm_instance + + result = await extract_candidates( + last_turn="User: My CFO is Giulia\nAssistant: Noted.", + core_memory={}, + recent_episodes=[], + ) + + assert isinstance(result, ExtractionResult) + assert len(result.candidates) == 1 + assert result.candidates[0].type == "fact" + assert "Giulia" in result.candidates[0].content + assert result.candidates[0].confidence == 0.85 + + +@pytest.mark.asyncio +async def test_extract_candidates_returns_empty_on_llm_failure(): + with ( + patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm, + patch("app.core.memory_extraction.get_langfuse", return_value=None), + patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt, + ): + mock_prompt.return_value = ("prompt {last_turn} {core_memory} {recent_episodes}", None) + llm_instance = MagicMock() + llm_instance.bind.return_value = llm_instance + llm_instance.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down")) + mock_get_llm.return_value = llm_instance + + result = await extract_candidates("turn", {}, []) + + assert isinstance(result, ExtractionResult) + assert result.candidates == [] + + +# ── TASK 2.2 — decide_action ───────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_decide_action_add_when_no_existing(): + candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core") + action = await decide_action(candidate, existing=[]) + assert action == "ADD" + + +@pytest.mark.asyncio +async def test_decide_action_noop(): + candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core") + mock_response = _make_llm_response("NOOP") + + with ( + patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm, + patch("app.core.memory_extraction.get_langfuse", return_value=None), + patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt, + ): + mock_prompt.return_value = ("p {candidate} {existing_memories}", None) + llm_instance = MagicMock() + llm_instance.ainvoke = AsyncMock(return_value=mock_response) + mock_get_llm.return_value = llm_instance + + action = await decide_action(candidate, existing=["CFO is Giulia"]) + + assert action == "NOOP" + + +@pytest.mark.asyncio +async def test_decide_action_update(): + candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core") + mock_response = _make_llm_response("UPDATE") + + with ( + patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm, + patch("app.core.memory_extraction.get_langfuse", return_value=None), + patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt, + ): + mock_prompt.return_value = ("p {candidate} {existing_memories}", None) + llm_instance = MagicMock() + llm_instance.ainvoke = AsyncMock(return_value=mock_response) + mock_get_llm.return_value = llm_instance + + action = await decide_action(candidate, existing=["CFO is Giulia"]) + + assert action == "UPDATE" + + +@pytest.mark.asyncio +async def test_decide_action_delete(): + candidate = MemoryCandidate(type="fact", content="No longer have a CFO", target_tier="core") + mock_response = _make_llm_response("DELETE") + + with ( + patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm, + patch("app.core.memory_extraction.get_langfuse", return_value=None), + patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt, + ): + mock_prompt.return_value = ("p {candidate} {existing_memories}", None) + llm_instance = MagicMock() + llm_instance.ainvoke = AsyncMock(return_value=mock_response) + mock_get_llm.return_value = llm_instance + + action = await decide_action(candidate, existing=["CFO is Giulia"]) + + assert action == "DELETE" + + +@pytest.mark.asyncio +async def test_decide_action_defaults_add_on_llm_failure(): + candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core") + + with ( + patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm, + patch("app.core.memory_extraction.get_langfuse", return_value=None), + patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt, + ): + mock_prompt.return_value = ("p {candidate} {existing_memories}", None) + llm_instance = MagicMock() + llm_instance.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down")) + mock_get_llm.return_value = llm_instance + + action = await decide_action(candidate, existing=["old memory"]) + + assert action == "ADD" + + +# ── TASK 2.3 — run_extraction end-to-end ───────────────────────────────────── + +@pytest.mark.asyncio +async def test_run_extraction_writes_core_candidate(db_session, pro_user): + """'My CFO is Giulia' → fact candidate → core row written.""" + fact_payload = { + "candidates": [ + { + "type": "fact", + "content": "User prefers morning meetings", + "target_tier": "core", + "confidence": 0.8, + } + ] + } + + def _mock_llm_response(content: str): + msg = MagicMock() + msg.content = content + msg.usage_metadata = {} + return msg + + call_count = 0 + + async def _ainvoke_side_effect(messages): + nonlocal call_count + call_count += 1 + if call_count == 1: + # extract_candidates call + return _mock_llm_response(json.dumps(fact_payload)) + # decide_action — no existing → short-circuits to ADD without LLM + return _mock_llm_response("ADD") + + with ( + patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm, + patch("app.core.memory_extraction.get_langfuse", return_value=None), + patch( + "app.core.memory_extraction.get_prompt_or_fallback", + side_effect=lambda name, fb: ( + ("p {last_turn} {core_memory} {recent_episodes}", None) + if name == "memory_extraction" + else ("p {candidate} {existing_memories}", None) + ), + ), + ): + llm_instance = MagicMock() + llm_instance.bind.return_value = llm_instance + llm_instance.ainvoke = AsyncMock(side_effect=_ainvoke_side_effect) + mock_get_llm.return_value = llm_instance + + await run_extraction( + db=db_session, + user_id=PRO_USER_ID, + last_user_msg="My CFO is Giulia", + last_assistant_msg="Noted, I will remember that.", + session_id="test-session", + ) + + # core row should exist + result = await db_session.execute( + select(MemoryCore).where(MemoryCore.user_id == PRO_USER_ID) + ) + rows = result.scalars().all() + assert len(rows) >= 1 + fernet = Fernet(_FERNET_KEY.encode()) + values = [fernet.decrypt(r.value_encrypted.encode()).decode() for r in rows] + assert any("morning meetings" in v for v in values) + + +# ── TASK 2.4 — dispatch ─────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_dispatch_realtime_for_pro(db_session, pro_user): + """Pro user: asyncio.create_task called (not queue row).""" + middleware = MemoryMiddleware(db_session) + + with ( + patch("app.core.memory_middleware.asyncio.create_task") as mock_task, + patch("app.billing.tier_manager.tier_manager.check_feature", return_value=True), + ): + await middleware._dispatch_extraction( + user_id=PRO_USER_ID, + episode_id=str(uuid.uuid4()), + last_user_msg="hello", + last_assistant_msg="hi", + session_id=None, + ) + + mock_task.assert_called_once() + + +@pytest.mark.asyncio +async def test_dispatch_queue_for_free(db_session, free_user): + """Free user: ExtractionQueue row inserted.""" + middleware = MemoryMiddleware(db_session) + ep_id = str(uuid.uuid4()) + + with patch("app.billing.tier_manager.tier_manager.check_feature", return_value=False): + await middleware._dispatch_extraction( + user_id=FREE_USER_ID, + episode_id=ep_id, + last_user_msg="hello", + last_assistant_msg="hi", + session_id=None, + ) + + result = await db_session.execute( + select(ExtractionQueue).where(ExtractionQueue.user_id == FREE_USER_ID) + ) + rows = result.scalars().all() + assert len(rows) == 1 + assert rows[0].episode_id == ep_id