PHASE 2 — Mem0-style Extract/Update pipeline
This commit is contained in:
@@ -53,6 +53,10 @@ LLM_MODEL_CLOUD_PROCESSOR=
|
||||
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
||||
LLM_MODEL_SETUP_AGENT=
|
||||
|
||||
# Memory-extractor — Mem0-style extract/decide pipeline (Phase 2).
|
||||
# Defaults to gpt-4o-mini when empty (fast + cheap, temperature=0).
|
||||
LLM_MODEL_MEMORY_EXTRACTOR=
|
||||
|
||||
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||
STRIPE_SECRET_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
|
||||
38
alembic/versions/1f5975a4f3f4_add_extraction_queue.py
Normal file
38
alembic/versions/1f5975a4f3f4_add_extraction_queue.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""add extraction_queue
|
||||
|
||||
Revision ID: 1f5975a4f3f4
|
||||
Revises: 005
|
||||
Create Date: 2026-04-16 17:26:25.790870
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1f5975a4f3f4'
|
||||
down_revision: Union[str, None] = '005'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``extraction_queue`` table and its ``user_id`` index."""
    op.create_table(
        'extraction_queue',
        sa.Column('id', sa.Uuid(as_uuid=False), nullable=False),
        sa.Column('user_id', sa.Uuid(as_uuid=False), nullable=False),
        # Nullable, and not constrained by a foreign key in this migration.
        sa.Column('episode_id', sa.Uuid(as_uuid=False), nullable=True),
        # Timestamp is filled in by the database via now().
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
        # Queue rows are removed together with the owning user row.
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id'),
    )
    op.create_index(op.f('ix_extraction_queue_user_id'), 'extraction_queue', ['user_id'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``user_id`` index and then the ``extraction_queue`` table."""
    op.drop_index(op.f('ix_extraction_queue_user_id'), table_name='extraction_queue')
    op.drop_table('extraction_queue')
|
||||
@@ -26,6 +26,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"batch_builder": False,
|
||||
"sso": False,
|
||||
"real_embeddings": False, # keyword fallback only
|
||||
"realtime_extraction": False, # batch queue (Phase 2)
|
||||
},
|
||||
"pro": {
|
||||
"agents": -1, # unlimited
|
||||
@@ -35,6 +36,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"batch_builder": False,
|
||||
"sso": False,
|
||||
"real_embeddings": True, # pgvector cosine search
|
||||
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
||||
},
|
||||
"power": {
|
||||
"agents": -1,
|
||||
@@ -44,6 +46,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"batch_builder": True,
|
||||
"sso": False,
|
||||
"real_embeddings": True,
|
||||
"realtime_extraction": True,
|
||||
},
|
||||
"team": {
|
||||
"agents": -1,
|
||||
@@ -53,6 +56,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"batch_builder": True,
|
||||
"sso": True,
|
||||
"real_embeddings": True,
|
||||
"realtime_extraction": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ class Settings(BaseSettings):
|
||||
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
||||
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
||||
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
||||
|
||||
# GitHub Copilot OAuth token storage directory.
|
||||
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||
|
||||
@@ -103,6 +103,7 @@ _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
||||
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
||||
}
|
||||
|
||||
|
||||
|
||||
456
app/core/memory_extraction.py
Normal file
456
app/core/memory_extraction.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""Mem0-style Extract/Update pipeline — Phase 2.
|
||||
|
||||
Runs after every ``store_episode`` call to distil durable facts, preferences,
|
||||
routines, and relations from the latest conversation turn.
|
||||
|
||||
Entry point: ``run_extraction(db, user_id, last_user_msg, last_assistant_msg, session_id)``
|
||||
|
||||
Design notes
|
||||
------------
|
||||
- One extraction LLM call per turn, plus one decide call per candidate that has existing neighbours.
|
||||
- Short-circuit: if no existing neighbours → ADD without a second LLM call (cost saving).
|
||||
- Zero-trust: never logs decrypted user content; relation subject/object labels are
|
||||
treated as identifiers (safe to log per spec).
|
||||
- Must not raise into the request path — caller wraps in asyncio.create_task().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.langfuse_client import get_langfuse, get_prompt_or_fallback, extract_usage, langfuse_context
|
||||
from app.core.llm import get_agent_llm, model_for_agent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Fallback prompts (used when Langfuse unavailable) ─────────────────────────
|
||||
|
||||
# Fallback system prompt for the extraction call. It is rendered with
# ``str.format``: ``{last_turn}``, ``{core_memory}`` and ``{recent_episodes}``
# are placeholders, while the doubled braces are literal JSON braces.
_EXTRACTION_FALLBACK = (
    "You are a memory extractor for a personal AI secretary. Given the last conversation "
    "turn, the user's core memory, and recent episode summaries, identify durable facts, "
    "preferences, routines, and person/project relations worth remembering.\n\n"
    "Output JSON matching this schema exactly:\n"
    '{{"candidates": [{{"type": "<fact|preference|relation|routine>", '
    '"content": "<short canonical statement>", '
    '"target_tier": "<core|associative|relational|proactive>", '
    '"subject": null, "predicate": null, "object": null, "confidence": 0.7}}]}}\n\n'
    "Rules:\n"
    "- Skip small talk, greetings, one-off questions.\n"
    "- Max 5 candidates per call.\n"
    "- Only extract durable information (still true next week).\n"
    "- For type=relation: subject/predicate/object required.\n"
    "- Default confidence=0.7.\n\n"
    "## Last turn\n{last_turn}\n\n"
    "## Core memory (current)\n{core_memory}\n\n"
    "## Recent episodes\n{recent_episodes}"
)
|
||||
|
||||
# Fallback system prompt for the decide call. It is rendered with
# ``str.format`` using the ``{candidate}`` and ``{existing_memories}``
# placeholders, and asks the model for a single-word verdict.
_DECIDE_FALLBACK = (
    "You are a memory update decision engine. Given a new memory candidate and a list of "
    "existing memories from the same tier, decide what action to take.\n\n"
    "Respond with exactly one word: ADD, UPDATE, DELETE, or NOOP.\n\n"
    "- ADD: new information not in existing memories.\n"
    "- UPDATE: contradicts or supersedes an existing memory.\n"
    "- DELETE: states something is no longer true.\n"
    "- NOOP: already captured accurately.\n\n"
    "## New candidate\n{candidate}\n\n"
    "## Existing memories (same tier, top neighbours)\n{existing_memories}"
)
|
||||
|
||||
|
||||
# ── Pydantic schemas ───────────────────────────────────────────────────────────
|
||||
|
||||
class MemoryCandidate(BaseModel):
    """One memory item proposed by the extraction LLM call."""

    # Kind of information the candidate carries.
    type: Literal["fact", "preference", "relation", "routine"]
    # Statement to remember.
    content: str
    # Memory tier the candidate should be written to.
    target_tier: Literal["core", "associative", "relational", "proactive"]
    # Triple fields; all default to None. The fallback extraction prompt asks
    # for them when type="relation", but the schema does not enforce that.
    subject: str | None = None
    predicate: str | None = None
    object: str | None = None
    # Validated to lie within [0.0, 1.0]; 0.7 when the model omits it.
    confidence: float = Field(default=0.7, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class ExtractionResult(BaseModel):
    """Top-level JSON object returned by the extraction LLM call."""

    # Empty list when the payload carries no "candidates" key.
    candidates: list[MemoryCandidate] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ── Task 2.1 — Extract candidates ─────────────────────────────────────────────
|
||||
|
||||
async def extract_candidates(
    last_turn: str,
    core_memory: dict[str, str],
    recent_episodes: list[str],
) -> ExtractionResult:
    """Ask the memory-extractor LLM for memory candidates from the latest turn.

    Args:
        last_turn: Text of the latest user/assistant exchange.
        core_memory: Current core memory as key/value pairs.
        recent_episodes: Episode summaries; only the last five are sent.

    Returns:
        The validated candidates, capped at five. An empty result is returned
        on any failure (prompt rendering, LLM call, JSON parsing, validation),
        so this coroutine never raises.
    """
    max_candidates = 5  # mirrors the "Max 5 candidates per call" prompt rule

    try:
        core_str = "\n".join(f"{k}: {v}" for k, v in core_memory.items()) or "(empty)"
        episodes_str = "\n---\n".join(recent_episodes[-5:]) or "(none)"
        variables = {
            "last_turn": last_turn,
            "core_memory": core_str,
            "recent_episodes": episodes_str,
        }

        template, prompt_obj = get_prompt_or_fallback("memory_extraction", _EXTRACTION_FALLBACK)

        # Prefer the Langfuse prompt object; fall back to str.format on the
        # template when there is no prompt object or compiling it fails.
        system_text = None
        if prompt_obj is not None:
            try:
                compiled = prompt_obj.compile(**variables)
                if isinstance(compiled, list):
                    compiled = "\n".join(
                        m.get("content", "") for m in compiled if isinstance(m, dict)
                    )
                system_text = compiled
            except Exception as exc:
                logger.warning("memory_extraction: compile failed: %s", exc)
        if system_text is None:
            system_text = template.format(**variables)

        llm = get_agent_llm("memory-extractor", temperature=0)
        # Bind JSON mode so the model always returns parseable output.
        llm_json = llm.bind(response_format={"type": "json_object"})  # type: ignore[attr-defined]

        lf = get_langfuse()

        from langchain_core.messages import HumanMessage, SystemMessage  # noqa: PLC0415

        messages = [
            SystemMessage(content=system_text),
            HumanMessage(content="Extract memory candidates as JSON."),
        ]

        if lf:
            with lf.start_as_current_observation(
                as_type="generation",
                name="memory-extraction",
                model=model_for_agent("memory-extractor"),
                prompt=prompt_obj,
                input=messages,
            ) as gen:
                response = await llm_json.ainvoke(messages)
                gen.update(output=response.content, usage=extract_usage(response))
        else:
            response = await llm_json.ainvoke(messages)

        raw = json.loads(response.content)
        result = ExtractionResult.model_validate(raw)
        if len(result.candidates) > max_candidates:
            # The model ignored the prompt limit; keep only the first few.
            result = ExtractionResult(candidates=result.candidates[:max_candidates])
        logger.info("memory_extraction: extracted %d candidates", len(result.candidates))
        return result

    except Exception as exc:
        # Log only the exception type: validation errors embed the offending
        # input values, which are derived from user content.
        logger.warning("memory_extraction: extract_candidates failed: %s", type(exc).__name__)
        return ExtractionResult(candidates=[])
|
||||
|
||||
|
||||
# ── Task 2.2 — Decide action ──────────────────────────────────────────────────
|
||||
|
||||
async def decide_action(
    candidate: MemoryCandidate,
    existing: list[str],
) -> Literal["ADD", "UPDATE", "DELETE", "NOOP"]:
    """Decide what to do with a candidate given existing memories in its tier.

    Args:
        candidate: The proposed memory item.
        existing: Neighbouring memories from the same tier.

    Returns:
        One of ``"ADD"``, ``"UPDATE"``, ``"DELETE"`` or ``"NOOP"``. When
        ``existing`` is empty the answer is ``"ADD"`` without any LLM call.
        Any failure, or a reply that is not one of the four verbs, also
        yields ``"ADD"``; this coroutine never raises.
    """
    if not existing:
        return "ADD"

    import string  # noqa: PLC0415

    try:
        variables = {
            "candidate": f"[{candidate.type}] {candidate.content}",
            "existing_memories": "\n".join(f"- {m}" for m in existing),
        }

        template, prompt_obj = get_prompt_or_fallback("memory_decide_action", _DECIDE_FALLBACK)

        # Prefer the Langfuse prompt object; fall back to str.format on the
        # template when there is no prompt object or compiling it fails.
        system_text = None
        if prompt_obj is not None:
            try:
                compiled = prompt_obj.compile(**variables)
                if isinstance(compiled, list):
                    compiled = "\n".join(
                        m.get("content", "") for m in compiled if isinstance(m, dict)
                    )
                system_text = compiled
            except Exception as exc:
                logger.warning("memory_extraction: decide compile failed: %s", exc)
        if system_text is None:
            system_text = template.format(**variables)

        llm = get_agent_llm("memory-extractor", temperature=0)
        lf = get_langfuse()

        from langchain_core.messages import HumanMessage, SystemMessage  # noqa: PLC0415

        messages = [
            SystemMessage(content=system_text),
            HumanMessage(content="Decide action."),
        ]

        if lf:
            with lf.start_as_current_observation(
                as_type="generation",
                name="memory-decide-action",
                model=model_for_agent("memory-extractor"),
                prompt=prompt_obj,
                input=messages,
            ) as gen:
                response = await llm.ainvoke(messages)
                gen.update(output=response.content, usage=extract_usage(response))
        else:
            response = await llm.ainvoke(messages)

        # Tolerate decoration around the single word ("ADD.", "**NOOP**",
        # quoted verbs) while still rejecting multi-word replies.
        verb = response.content.strip().upper().strip(string.punctuation + string.whitespace)
        if verb in ("ADD", "UPDATE", "DELETE", "NOOP"):
            return verb  # type: ignore[return-value]
        # Truncate: an off-script reply may echo memory content.
        logger.warning("memory_extraction: unexpected decide verb=%r, defaulting ADD", verb[:32])
        return "ADD"

    except Exception as exc:
        logger.warning("memory_extraction: decide_action failed: %s", exc)
        return "ADD"
|
||||
|
||||
|
||||
# ── Task 2.3 — Pipeline orchestrator ──────────────────────────────────────────
|
||||
|
||||
async def run_extraction(
    db: AsyncSession,
    user_id: str,
    last_user_msg: str,
    last_assistant_msg: str,
    session_id: str | None,
) -> None:
    """Run the extract/update pipeline for one conversation turn.

    Delegates all work to ``_run_extraction_inner`` and converts any
    exception it raises into a warning log entry, so nothing propagates to
    the caller.

    Args:
        db: Session used for all reads and writes of the pipeline.
        user_id: Owner of the memories.
        last_user_msg: Latest user message.
        last_assistant_msg: Latest assistant reply.
        session_id: Conversation session identifier, if any.
    """
    try:
        await _run_extraction_inner(
            db,
            user_id,
            last_user_msg,
            last_assistant_msg,
            session_id,
        )
    except Exception as err:
        logger.warning("memory_extraction: run_extraction failed user=%s: %s", user_id, err)
|
||||
|
||||
|
||||
async def _run_extraction_inner(
    db: AsyncSession,
    user_id: str,
    last_user_msg: str,
    last_assistant_msg: str,
    session_id: str | None,
) -> None:
    """Load context, extract candidates and apply each of them.

    Returns early (after a warning) when no Fernet key is available for the
    user. Failures while applying a single candidate are logged and do not
    stop the remaining candidates. When a Langfuse client is available the
    whole run is wrapped in a ``memory-extraction-pipeline`` span.

    Args:
        db: Session used for all reads and writes.
        user_id: Owner of the memories.
        last_user_msg: Latest user message.
        last_assistant_msg: Latest assistant reply.
        session_id: Conversation session identifier, if any.
    """
    from app.core.memory_middleware import MemoryMiddleware  # noqa: PLC0415

    middleware = MemoryMiddleware(db)
    fernet = await middleware._get_fernet(user_id)
    if fernet is None:
        logger.warning("memory_extraction: no fernet for user=%s, skipping", user_id)
        return

    # 1. Load context
    core: dict[str, str] = await middleware._load_core(user_id, fernet)
    episodes: list[str] = await middleware._load_episodic(user_id, fernet, session_id=session_id)

    last_turn = f"User: {last_user_msg}\nAssistant: {last_assistant_msg}"

    lf = get_langfuse()

    async def _run(trace_id: str | None) -> dict[str, Any]:
        # 2. Extract candidates
        result = await extract_candidates(last_turn, core, episodes)
        if not result.candidates:
            logger.info("memory_extraction: no candidates user=%s", user_id)
            return {"candidates": 0, "applied": 0}

        logger.info(
            "memory_extraction: processing %d candidates user=%s trace=%s",
            len(result.candidates),
            user_id,
            trace_id or "-",
        )

        # 3. Apply each candidate. "applied" counts candidates processed
        # without an exception, whatever action was decided for them.
        applied = 0
        actions: list[str] = []
        for candidate in result.candidates:
            label = f"{candidate.type}:{candidate.target_tier}"
            try:
                await _apply_candidate(middleware, db, user_id, fernet, candidate, trace_id)
                applied += 1
                actions.append(label)
            except Exception as exc:
                # Zero-trust: identify the candidate by type/tier only,
                # never by its (decrypted, user-derived) content.
                logger.warning(
                    "memory_extraction: apply failed candidate=%r user=%s: %s",
                    label,
                    user_id,
                    exc,
                )

        logger.info(
            "memory_extraction: applied %d/%d candidates user=%s",
            applied,
            len(result.candidates),
            user_id,
        )
        return {"candidates": len(result.candidates), "applied": applied, "actions": actions}

    with langfuse_context(user_id=user_id, session_id=session_id):
        if lf:
            with lf.start_as_current_observation(
                as_type="span",
                name="memory-extraction-pipeline",
                input={"last_turn_preview": last_turn[:200]},
            ) as span:
                summary = await _run(trace_id=span.id)
                span.update(output=summary)
            # Flush once the span has been closed; a tracing failure must not
            # fail the pipeline.
            try:
                lf.flush()
            except Exception as exc:
                logger.debug("memory_extraction: langfuse flush failed: %s", type(exc).__name__)
        else:
            await _run(trace_id=None)
|
||||
|
||||
|
||||
async def _apply_candidate(
    middleware: Any,
    db: AsyncSession,
    user_id: str,
    fernet: Any,
    candidate: MemoryCandidate,
    trace_id: str | None,
) -> None:
    """Fetch neighbours, decide an action and apply it to the candidate's tier.

    Args:
        middleware: Memory middleware providing the tier read/write methods.
        db: Session handed to the relation and proactive writers.
        user_id: Owner of the memories.
        fernet: Encryption handle used when storing proactive patterns.
        candidate: The memory item to apply.
        trace_id: Trace identifier forwarded to the writers that accept one.
    """
    tier = candidate.target_tier

    # Collect up to three existing memories for the decision step.
    neighbours: list[str] = []
    if tier == "core":
        # NOTE(review): these are simply the first three core blocks, not the
        # blocks most similar to the candidate.
        blocks = await middleware.list_core_blocks(user_id)
        neighbours = [b["value"] for b in blocks[:3]]
    elif tier == "associative":
        neighbours = await middleware.search_archival(user_id, candidate.content, top_k=3)
    elif tier == "proactive":
        neighbours = await middleware.search_recall(user_id, candidate.content, top_k=3)
    # Relational tier: no neighbour search, so decide_action short-circuits
    # to ADD without an LLM call.

    action = await decide_action(candidate, neighbours)
    logger.info(
        "memory_extraction: candidate type=%s tier=%s action=%s",
        candidate.type,
        candidate.target_tier,
        action,
    )

    if action == "NOOP":
        return

    if tier == "relational":
        # Relations are always upserted, provided the triple is complete.
        if candidate.subject and candidate.predicate and candidate.object:
            await _upsert_relation_stub(middleware, db, user_id, candidate, trace_id)
        else:
            logger.info(
                "memory_extraction: relation candidate missing subject/predicate/object, skipped"
            )
        return

    if action in ("ADD", "UPDATE"):
        if tier == "core":
            # NOTE(review): the key is derived from the new content, so an
            # UPDATE only overwrites a block whose key happens to match.
            key = _content_to_key(candidate.content)
            await middleware.update_core(user_id, key, candidate.content, trace_id=trace_id)
        elif tier == "associative":
            await middleware.store_associative(user_id, candidate.content)
        elif tier == "proactive":
            await _store_proactive_stub(middleware, db, user_id, candidate, fernet)

    elif action == "DELETE":
        if tier == "core":
            # NOTE(review): same key derivation as above; the delete only
            # hits a block stored under the identical derived key.
            key = _content_to_key(candidate.content)
            await middleware.delete_core(user_id, key)
        else:
            # Previously a silent no-op; make the skipped action visible.
            logger.info("memory_extraction: DELETE not supported for tier=%s, skipped", tier)
|
||||
|
||||
|
||||
def _content_to_key(content: str) -> str:
|
||||
"""Derive a short snake_case key from a content string (first 40 chars)."""
|
||||
import re # noqa: PLC0415
|
||||
slug = re.sub(r"[^a-z0-9]+", "_", content[:40].lower()).strip("_")
|
||||
return slug or "memory"
|
||||
|
||||
|
||||
async def _upsert_relation_stub(
|
||||
middleware: Any,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
candidate: MemoryCandidate,
|
||||
trace_id: str | None,
|
||||
) -> None:
|
||||
"""Stub: upsert_relation will be fully wired in Phase 3.
|
||||
|
||||
Called here so Phase 2 extraction pipeline already routes relation candidates
|
||||
correctly. Phase 3 replaces this with MemoryMiddleware.upsert_relation().
|
||||
"""
|
||||
if hasattr(middleware, "upsert_relation"):
|
||||
await middleware.upsert_relation(
|
||||
user_id=user_id,
|
||||
subject=candidate.subject,
|
||||
subject_type="unknown",
|
||||
predicate=candidate.predicate,
|
||||
object_=candidate.object,
|
||||
object_type="unknown",
|
||||
confidence=candidate.confidence,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"memory_extraction: relation stub (Phase 3 not yet wired) subject=%s predicate=%s object=%s",
|
||||
candidate.subject,
|
||||
candidate.predicate,
|
||||
candidate.object,
|
||||
)
|
||||
|
||||
|
||||
async def _store_proactive_stub(
    middleware: Any,
    db: AsyncSession,
    user_id: str,
    candidate: MemoryCandidate,
    fernet: Any,
) -> None:
    """Insert one encrypted ``MemoryProactive`` row for the candidate.

    The candidate content is encrypted with ``fernet`` before it is stored.
    A failed commit is logged and rolled back rather than raised.

    Args:
        middleware: Unused here.
        db: Session the row is added to and committed on.
        user_id: Owner of the pattern.
        candidate: Source of the pattern text and confidence.
        fernet: Encryption handle passed to ``_encrypt``.
    """
    import uuid  # noqa: PLC0415
    from app.models import MemoryProactive  # noqa: PLC0415
    from app.core.memory_middleware import _encrypt  # noqa: PLC0415

    ciphertext = _encrypt(fernet, candidate.content)
    db.add(
        MemoryProactive(
            id=str(uuid.uuid4()),
            user_id=user_id,
            pattern_encrypted=ciphertext,
            confidence=candidate.confidence,
            source="inferred",
        )
    )

    try:
        await db.commit()
    except Exception as err:
        logger.warning("memory_extraction: store proactive failed: %s", err)
        await db.rollback()
    else:
        logger.info("memory_extraction: stored proactive pattern user=%s", user_id)
|
||||
@@ -18,6 +18,7 @@ Usage:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
@@ -27,6 +28,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import (
|
||||
ExtractionQueue,
|
||||
MemoryAssociative,
|
||||
MemoryCore,
|
||||
MemoryEpisodic,
|
||||
@@ -106,7 +108,10 @@ class MemoryMiddleware:
|
||||
"""Summarise and store a completed interaction in episodic memory.
|
||||
|
||||
The summary is a simple heuristic concatenation (no LLM call) to keep
|
||||
latency low. Full LLM summarisation can be added in a later step.
|
||||
latency low. After committing the episode row, dispatches the Mem0-style
|
||||
extraction pipeline:
|
||||
- Pro/Power/Team → asyncio.create_task (fire-and-forget, realtime).
|
||||
- Free → enqueue an ExtractionQueue row for the daily cron.
|
||||
"""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
@@ -115,26 +120,95 @@ class MemoryMiddleware:
|
||||
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||
encrypted = _encrypt(fernet, summary)
|
||||
|
||||
row = MemoryEpisodic(
|
||||
episode = MemoryEpisodic(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
summary_encrypted=encrypted,
|
||||
session_id=session_id,
|
||||
)
|
||||
self._db.add(row)
|
||||
self._db.add(episode)
|
||||
episode_id: str = episode.id
|
||||
try:
|
||||
await self._db.commit()
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
tier = user_dbg.get("tier") or "free"
|
||||
logger.info(
|
||||
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
user_dbg.get("tier") or "-",
|
||||
tier,
|
||||
session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
return
|
||||
|
||||
# ── Dispatch extraction pipeline (Phase 2) ────────────────────────────
|
||||
await self._dispatch_extraction(
|
||||
user_id=user_id,
|
||||
episode_id=episode_id,
|
||||
last_user_msg=message,
|
||||
last_assistant_msg=response,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _dispatch_extraction(
    self,
    user_id: str,
    episode_id: str,
    last_user_msg: str,
    last_assistant_msg: str,
    session_id: str | None,
) -> None:
    """Route extraction to a realtime task or the batch queue by user tier.

    Tiers with the ``realtime_extraction`` feature get a background task
    that runs ``run_extraction`` on a fresh database session. All other
    tiers get an ``ExtractionQueue`` row committed on ``self._db``.

    Args:
        user_id: Owner of the episode.
        episode_id: Identifier of the episode that was just stored.
        last_user_msg: Latest user message.
        last_assistant_msg: Latest assistant reply.
        session_id: Conversation session identifier, if any.
    """
    from app.billing.tier_manager import tier_manager  # noqa: PLC0415

    tier = await tier_manager.get_tier(user_id, self._db)

    if tier_manager.check_feature(tier, "realtime_extraction"):
        # Pro/Power/Team: fire-and-forget in the background.
        # Must open a fresh session — request session closes after handler returns.
        from app.core.memory_extraction import run_extraction  # noqa: PLC0415
        from app.db import async_session  # noqa: PLC0415

        async def _task() -> None:
            try:
                async with async_session() as fresh_db:
                    await run_extraction(
                        db=fresh_db,
                        user_id=user_id,
                        last_user_msg=last_user_msg,
                        last_assistant_msg=last_assistant_msg,
                        session_id=session_id,
                    )
            except Exception as exc:
                logger.warning(
                    "memory: extraction task failed user=%s: %s", user_id, exc
                )

        # The event loop keeps only weak references to tasks, so a task with
        # no other reference can be garbage-collected before it finishes.
        # Hold a strong reference in a class-wide set until the task is done.
        owner = type(self)
        pending = getattr(owner, "_extraction_tasks", None)
        if pending is None:
            pending = set()
            owner._extraction_tasks = pending  # type: ignore[attr-defined]
        task = asyncio.create_task(_task())
        pending.add(task)
        task.add_done_callback(pending.discard)
        logger.info("memory: realtime extraction dispatched user=%s", user_id)
    else:
        # Free tier: enqueue for daily batch cron.
        queue_row = ExtractionQueue(
            id=str(uuid.uuid4()),
            user_id=user_id,
            episode_id=episode_id,
        )
        self._db.add(queue_row)
        try:
            await self._db.commit()
            logger.info(
                "memory: extraction enqueued (batch) user=%s episode=%s",
                user_id,
                episode_id,
            )
        except Exception as exc:
            logger.warning(
                "memory: extraction queue insert failed user=%s: %s", user_id, exc
            )
            await self._db.rollback()
|
||||
|
||||
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||
"""Upsert a core memory key/value for a user."""
|
||||
|
||||
@@ -351,6 +351,28 @@ class MemoryProactive(Base):
|
||||
)
|
||||
|
||||
|
||||
class ExtractionQueue(Base):
    """Batch extraction queue for Free-tier users (Phase 2).

    Pro/Power/Team users get realtime asyncio.create_task() extraction.
    Free users get a queue row here; a daily cron (Phase 5) drains it.
    """

    __tablename__ = "extraction_queue"

    # Primary key; generated by the ``_uuid`` default when not supplied.
    id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
    # Owning user; indexed, and rows are deleted with the user (CASCADE).
    user_id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False, index=True,
    )
    # Optional episode reference; no foreign key is declared on this column.
    episode_id: Mapped[str | None] = mapped_column(
        Uuid(as_uuid=False), nullable=True,
    )
    # Set by the database at insert time.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.now()
    )
|
||||
|
||||
|
||||
class Plugin(Base):
|
||||
"""Plugin marketplace catalog entry."""
|
||||
|
||||
|
||||
345
tests/test_memory_extraction.py
Normal file
345
tests/test_memory_extraction.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""Tests for Phase 2 — Mem0-style Extract/Update pipeline.
|
||||
|
||||
Coverage:
|
||||
2.1 extract_candidates returns valid ExtractionResult with mocked LLM.
|
||||
2.2 decide_action — all 4 branches (ADD/UPDATE/DELETE/NOOP + empty existing).
|
||||
2.3 run_extraction end-to-end with mocked LLM writes expected rows.
|
||||
2.4 _dispatch_extraction — Pro user triggers realtime task; Free enqueues row.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.memory_extraction import (
|
||||
ExtractionResult,
|
||||
MemoryCandidate,
|
||||
decide_action,
|
||||
extract_candidates,
|
||||
run_extraction,
|
||||
)
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import ExtractionQueue, MemoryCore, User
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
|
||||
PRO_USER_ID = TEST_USER_IDS["pro"]
|
||||
FREE_USER_ID = TEST_USER_IDS["free"]
|
||||
_FERNET_KEY = Fernet.generate_key().decode()
|
||||
|
||||
|
||||
# ── DB override ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Serve the test ``db_session`` through the app's ``get_session`` dependency.

    The override is installed before each test and removed afterwards.
    """

    async def _session_provider():
        yield db_session

    overrides = app.dependency_overrides
    overrides[get_session] = _session_provider
    yield
    overrides.pop(get_session, None)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture
async def pro_user(db_session):
    """Return the seeded pro user after giving it the test encryption key."""
    query = select(User).where(User.id == PRO_USER_ID)
    seeded = (await db_session.execute(query)).scalar_one()
    seeded.encryption_key = _FERNET_KEY
    await db_session.commit()
    return seeded
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def free_user(db_session):
    """Return the seeded free user after giving it the test encryption key."""
    query = select(User).where(User.id == FREE_USER_ID)
    seeded = (await db_session.execute(query)).scalar_one()
    seeded.encryption_key = _FERNET_KEY
    await db_session.commit()
    return seeded
|
||||
|
||||
|
||||
def _make_llm_response(content: str) -> MagicMock:
|
||||
msg = MagicMock()
|
||||
msg.content = content
|
||||
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
return msg
|
||||
|
||||
|
||||
# ── TASK 2.1 — extract_candidates ────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_candidates_returns_valid_result():
    """A well-formed JSON reply is parsed into one validated candidate."""
    candidate_json = {
        "type": "fact",
        "content": "User's CFO is Giulia",
        "target_tier": "core",
        "subject": None,
        "predicate": None,
        "object": None,
        "confidence": 0.85,
    }
    reply = _make_llm_response(json.dumps({"candidates": [candidate_json]}))

    # bind() hands back the same mock so ainvoke is reached on it.
    fake_llm = MagicMock()
    fake_llm.bind.return_value = fake_llm
    fake_llm.ainvoke = AsyncMock(return_value=reply)

    prompt_stub = (
        "system prompt {last_turn} {core_memory} {recent_episodes}",
        None,
    )

    with (
        patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
    ):
        mock_get_llm.return_value = fake_llm
        mock_prompt.return_value = prompt_stub

        result = await extract_candidates(
            last_turn="User: My CFO is Giulia\nAssistant: Noted.",
            core_memory={},
            recent_episodes=[],
        )

    assert isinstance(result, ExtractionResult)
    assert len(result.candidates) == 1
    first = result.candidates[0]
    assert first.type == "fact"
    assert "Giulia" in first.content
    assert first.confidence == 0.85
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_candidates_returns_empty_on_llm_failure():
    """An LLM error yields an empty result instead of an exception."""
    failing_llm = MagicMock()
    failing_llm.bind.return_value = failing_llm
    failing_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))

    with (
        patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
    ):
        mock_get_llm.return_value = failing_llm
        mock_prompt.return_value = ("prompt {last_turn} {core_memory} {recent_episodes}", None)

        result = await extract_candidates("turn", {}, [])

    assert isinstance(result, ExtractionResult)
    assert result.candidates == []
|
||||
|
||||
|
||||
# ── TASK 2.2 — decide_action ─────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
async def test_decide_action_add_when_no_existing():
    """With no existing memories, decide_action returns ADD without an LLM call.

    The collaborators are patched so the test can never reach a real model.
    The stub LLM answers "NOOP" on purpose: if the empty-``existing``
    short-circuit were missing, the result would be NOOP (or the call would
    be recorded) and the test would fail instead of passing by accident.
    """
    candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")

    llm_instance = MagicMock()
    llm_instance.ainvoke = AsyncMock(return_value=_make_llm_response("NOOP"))

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=llm_instance),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch(
            "app.core.memory_extraction.get_prompt_or_fallback",
            return_value=("p {candidate} {existing_memories}", None),
        ),
    ):
        action = await decide_action(candidate, existing=[])

    assert action == "ADD"
    # The ADD above must come from the short-circuit, not from a model reply.
    llm_instance.ainvoke.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_decide_action_noop():
    """decide_action returns NOOP when the mocked LLM replies "NOOP"."""
    duplicate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
    canned_reply = _make_llm_response("NOOP")

    stub_llm = MagicMock()
    stub_llm.ainvoke = AsyncMock(return_value=canned_reply)
    prompt_pair = ("p {candidate} {existing_memories}", None)

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=stub_llm),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch(
            "app.core.memory_extraction.get_prompt_or_fallback",
            return_value=prompt_pair,
        ),
    ):
        verdict = await decide_action(duplicate, existing=["CFO is Giulia"])

    assert verdict == "NOOP"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_decide_action_update():
    """decide_action returns UPDATE when the mocked LLM replies "UPDATE"."""
    revised = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
    canned_reply = _make_llm_response("UPDATE")

    stub_llm = MagicMock()
    stub_llm.ainvoke = AsyncMock(return_value=canned_reply)
    prompt_pair = ("p {candidate} {existing_memories}", None)

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=stub_llm),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch(
            "app.core.memory_extraction.get_prompt_or_fallback",
            return_value=prompt_pair,
        ),
    ):
        verdict = await decide_action(revised, existing=["CFO is Giulia"])

    assert verdict == "UPDATE"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_decide_action_delete():
    """decide_action returns DELETE when the mocked LLM replies "DELETE"."""
    retraction = MemoryCandidate(
        type="fact", content="No longer have a CFO", target_tier="core"
    )
    canned_reply = _make_llm_response("DELETE")

    stub_llm = MagicMock()
    stub_llm.ainvoke = AsyncMock(return_value=canned_reply)
    prompt_pair = ("p {candidate} {existing_memories}", None)

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=stub_llm),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch(
            "app.core.memory_extraction.get_prompt_or_fallback",
            return_value=prompt_pair,
        ),
    ):
        verdict = await decide_action(retraction, existing=["CFO is Giulia"])

    assert verdict == "DELETE"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_decide_action_defaults_add_on_llm_failure():
    """decide_action falls back to ADD when the LLM call raises."""
    incoming = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")

    broken_llm = MagicMock()
    broken_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
    prompt_pair = ("p {candidate} {existing_memories}", None)

    with (
        patch("app.core.memory_extraction.get_agent_llm", return_value=broken_llm),
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch(
            "app.core.memory_extraction.get_prompt_or_fallback",
            return_value=prompt_pair,
        ),
    ):
        # A non-empty ``existing`` list is passed so the decision is expected
        # to go through the (failing) LLM rather than an early return.
        verdict = await decide_action(incoming, existing=["old memory"])

    assert verdict == "ADD"
|
||||
|
||||
|
||||
# ── TASK 2.3 — run_extraction end-to-end ─────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_extraction_writes_core_candidate(db_session, pro_user):
    """Mocked core-tier fact candidate → encrypted core row written.

    The LLM is mocked, so the stored value comes from ``fact_payload``
    ("User prefers morning meetings") and not from the chat turn text that
    is passed to ``run_extraction``. The final assertion therefore looks for
    the payload content in the decrypted ``MemoryCore`` values.
    """
    fact_payload = {
        "candidates": [
            {
                "type": "fact",
                "content": "User prefers morning meetings",
                "target_tier": "core",
                "confidence": 0.8,
            }
        ]
    }

    def _mock_llm_response(content: str):
        # Minimal stand-in for an LLM message: only the attributes set here
        # are configured explicitly.
        msg = MagicMock()
        msg.content = content
        msg.usage_metadata = {}
        return msg

    call_count = 0

    async def _ainvoke_side_effect(messages):
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            # extract_candidates call
            return _mock_llm_response(json.dumps(fact_payload))
        # Fallback for any later call (i.e. decide_action). With no existing
        # memories decide_action is expected to short-circuit to ADD without
        # reaching the LLM, so this reply only matters if it does not.
        return _mock_llm_response("ADD")

    with (
        patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
        patch("app.core.memory_extraction.get_langfuse", return_value=None),
        patch(
            "app.core.memory_extraction.get_prompt_or_fallback",
            # Serve a template with the placeholders each prompt name uses.
            side_effect=lambda name, fb: (
                ("p {last_turn} {core_memory} {recent_episodes}", None)
                if name == "memory_extraction"
                else ("p {candidate} {existing_memories}", None)
            ),
        ),
    ):
        llm_instance = MagicMock()
        llm_instance.bind.return_value = llm_instance
        llm_instance.ainvoke = AsyncMock(side_effect=_ainvoke_side_effect)
        mock_get_llm.return_value = llm_instance

        await run_extraction(
            db=db_session,
            user_id=PRO_USER_ID,
            last_user_msg="My CFO is Giulia",
            last_assistant_msg="Noted, I will remember that.",
            session_id="test-session",
        )

    # core row should exist
    result = await db_session.execute(
        select(MemoryCore).where(MemoryCore.user_id == PRO_USER_ID)
    )
    rows = result.scalars().all()
    assert len(rows) >= 1
    # Values are stored encrypted; decrypt them before checking the content.
    fernet = Fernet(_FERNET_KEY.encode())
    values = [fernet.decrypt(r.value_encrypted.encode()).decode() for r in rows]
    assert any("morning meetings" in v for v in values)
|
||||
|
||||
|
||||
# ── TASK 2.4 — dispatch ───────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_realtime_for_pro(db_session, pro_user):
    """Pro user: extraction is handed to asyncio.create_task exactly once.

    ``check_feature`` is forced to True to select the realtime path. Only the
    ``create_task`` call is asserted; the extraction queue table is not
    inspected here.
    """
    import inspect

    middleware = MemoryMiddleware(db_session)

    with (
        patch("app.core.memory_middleware.asyncio.create_task") as mock_task,
        patch("app.billing.tier_manager.tier_manager.check_feature", return_value=True),
    ):
        await middleware._dispatch_extraction(
            user_id=PRO_USER_ID,
            episode_id=str(uuid.uuid4()),
            last_user_msg="hello",
            last_assistant_msg="hi",
            session_id=None,
        )

    mock_task.assert_called_once()

    # create_task is mocked, so a coroutine passed to it is never scheduled.
    # Close it explicitly to avoid a "coroutine ... was never awaited"
    # RuntimeWarning; this is a no-op when no coroutine object was passed.
    scheduled = mock_task.call_args
    for arg in (*scheduled.args, *scheduled.kwargs.values()):
        if inspect.iscoroutine(arg):
            arg.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dispatch_queue_for_free(db_session, free_user):
    """Free user: ExtractionQueue row inserted."""
    dispatcher = MemoryMiddleware(db_session)
    episode_ref = str(uuid.uuid4())

    # check_feature is forced to False to select the queued path.
    feature_off = patch(
        "app.billing.tier_manager.tier_manager.check_feature", return_value=False
    )
    with feature_off:
        await dispatcher._dispatch_extraction(
            user_id=FREE_USER_ID,
            episode_id=episode_ref,
            last_user_msg="hello",
            last_assistant_msg="hi",
            session_id=None,
        )

    queue_query = select(ExtractionQueue).where(ExtractionQueue.user_id == FREE_USER_ID)
    queued = (await db_session.execute(queue_query)).scalars().all()

    # Exactly one row, pointing at the episode that was dispatched.
    assert len(queued) == 1
    assert queued[0].episode_id == episode_ref
|
||||
Reference in New Issue
Block a user