PHASE 2 — Mem0-style Extract/Update pipeline
This commit is contained in:
@@ -26,6 +26,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"batch_builder": False,
|
||||
"sso": False,
|
||||
"real_embeddings": False, # keyword fallback only
|
||||
"realtime_extraction": False, # batch queue (Phase 2)
|
||||
},
|
||||
"pro": {
|
||||
"agents": -1, # unlimited
|
||||
@@ -35,6 +36,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"batch_builder": False,
|
||||
"sso": False,
|
||||
"real_embeddings": True, # pgvector cosine search
|
||||
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
||||
},
|
||||
"power": {
|
||||
"agents": -1,
|
||||
@@ -44,6 +46,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"batch_builder": True,
|
||||
"sso": False,
|
||||
"real_embeddings": True,
|
||||
"realtime_extraction": True,
|
||||
},
|
||||
"team": {
|
||||
"agents": -1,
|
||||
@@ -53,6 +56,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"batch_builder": True,
|
||||
"sso": True,
|
||||
"real_embeddings": True,
|
||||
"realtime_extraction": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ class Settings(BaseSettings):
|
||||
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
||||
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
||||
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
||||
|
||||
# GitHub Copilot OAuth token storage directory.
|
||||
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||
|
||||
@@ -103,6 +103,7 @@ _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
||||
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
||||
}
|
||||
|
||||
|
||||
|
||||
456
app/core/memory_extraction.py
Normal file
456
app/core/memory_extraction.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""Mem0-style Extract/Update pipeline — Phase 2.
|
||||
|
||||
Runs after every ``store_episode`` call to distil durable facts, preferences,
|
||||
routines, and relations from the latest conversation turn.
|
||||
|
||||
Entry point: ``run_extraction(db, user_id, last_user_msg, last_assistant_msg, session_id)``
|
||||
|
||||
Design notes
|
||||
------------
|
||||
- Two gpt-4o-mini calls per turn: extract candidates, then decide action per candidate.
|
||||
- Short-circuit: if no existing neighbours → ADD without a second LLM call (cost saving).
|
||||
- Zero-trust: never logs decrypted user content; relation subject/object labels are
|
||||
treated as identifiers (safe to log per spec).
|
||||
- Must not raise into the request path — caller wraps in asyncio.create_task().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.langfuse_client import get_langfuse, get_prompt_or_fallback, extract_usage, langfuse_context
|
||||
from app.core.llm import get_agent_llm, model_for_agent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Fallback prompts (used when Langfuse unavailable) ─────────────────────────
|
||||
|
||||
_EXTRACTION_FALLBACK = (
|
||||
"You are a memory extractor for a personal AI secretary. Given the last conversation "
|
||||
"turn, the user's core memory, and recent episode summaries, identify durable facts, "
|
||||
"preferences, routines, and person/project relations worth remembering.\n\n"
|
||||
"Output JSON matching this schema exactly:\n"
|
||||
'{{"candidates": [{{"type": "<fact|preference|relation|routine>", '
|
||||
'"content": "<short canonical statement>", '
|
||||
'"target_tier": "<core|associative|relational|proactive>", '
|
||||
'"subject": null, "predicate": null, "object": null, "confidence": 0.7}}]}}\n\n'
|
||||
"Rules:\n"
|
||||
"- Skip small talk, greetings, one-off questions.\n"
|
||||
"- Max 5 candidates per call.\n"
|
||||
"- Only extract durable information (still true next week).\n"
|
||||
"- For type=relation: subject/predicate/object required.\n"
|
||||
"- Default confidence=0.7.\n\n"
|
||||
"## Last turn\n{last_turn}\n\n"
|
||||
"## Core memory (current)\n{core_memory}\n\n"
|
||||
"## Recent episodes\n{recent_episodes}"
|
||||
)
|
||||
|
||||
_DECIDE_FALLBACK = (
|
||||
"You are a memory update decision engine. Given a new memory candidate and a list of "
|
||||
"existing memories from the same tier, decide what action to take.\n\n"
|
||||
"Respond with exactly one word: ADD, UPDATE, DELETE, or NOOP.\n\n"
|
||||
"- ADD: new information not in existing memories.\n"
|
||||
"- UPDATE: contradicts or supersedes an existing memory.\n"
|
||||
"- DELETE: states something is no longer true.\n"
|
||||
"- NOOP: already captured accurately.\n\n"
|
||||
"## New candidate\n{candidate}\n\n"
|
||||
"## Existing memories (same tier, top neighbours)\n{existing_memories}"
|
||||
)
|
||||
|
||||
|
||||
# ── Pydantic schemas ───────────────────────────────────────────────────────────
|
||||
|
||||
class MemoryCandidate(BaseModel):
|
||||
type: Literal["fact", "preference", "relation", "routine"]
|
||||
content: str
|
||||
target_tier: Literal["core", "associative", "relational", "proactive"]
|
||||
subject: str | None = None
|
||||
predicate: str | None = None
|
||||
object: str | None = None
|
||||
confidence: float = Field(default=0.7, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class ExtractionResult(BaseModel):
|
||||
candidates: list[MemoryCandidate] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ── Task 2.1 — Extract candidates ─────────────────────────────────────────────
|
||||
|
||||
async def extract_candidates(
|
||||
last_turn: str,
|
||||
core_memory: dict[str, str],
|
||||
recent_episodes: list[str],
|
||||
) -> ExtractionResult:
|
||||
"""Call gpt-4o-mini to extract memory candidates from the latest turn.
|
||||
|
||||
Returns an ExtractionResult (may be empty on failure — never raises).
|
||||
"""
|
||||
core_str = "\n".join(f"{k}: {v}" for k, v in core_memory.items()) or "(empty)"
|
||||
episodes_str = "\n---\n".join(recent_episodes[-5:]) or "(none)"
|
||||
|
||||
template, prompt_obj = get_prompt_or_fallback("memory_extraction", _EXTRACTION_FALLBACK)
|
||||
|
||||
# Compile with Langfuse variable syntax ({{var}}) or fallback {var}
|
||||
if prompt_obj is not None:
|
||||
try:
|
||||
system_text = prompt_obj.compile(
|
||||
last_turn=last_turn,
|
||||
core_memory=core_str,
|
||||
recent_episodes=episodes_str,
|
||||
)
|
||||
if isinstance(system_text, list):
|
||||
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: compile failed: %s", exc)
|
||||
system_text = template.format(
|
||||
last_turn=last_turn,
|
||||
core_memory=core_str,
|
||||
recent_episodes=episodes_str,
|
||||
)
|
||||
else:
|
||||
system_text = template.format(
|
||||
last_turn=last_turn,
|
||||
core_memory=core_str,
|
||||
recent_episodes=episodes_str,
|
||||
)
|
||||
|
||||
llm = get_agent_llm("memory-extractor", temperature=0)
|
||||
# Bind JSON mode so the model always returns parseable output.
|
||||
llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined]
|
||||
|
||||
lf = get_langfuse()
|
||||
try:
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Extract memory candidates as JSON."),
|
||||
]
|
||||
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-extraction",
|
||||
model=model_for_agent("memory-extractor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm_json.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm_json.ainvoke(messages)
|
||||
|
||||
raw = json.loads(response.content)
|
||||
result = ExtractionResult.model_validate(raw)
|
||||
logger.info("memory_extraction: extracted %d candidates", len(result.candidates))
|
||||
return result
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: extract_candidates failed: %s", exc)
|
||||
return ExtractionResult(candidates=[])
|
||||
|
||||
|
||||
# ── Task 2.2 — Decide action ──────────────────────────────────────────────────
|
||||
|
||||
async def decide_action(
|
||||
candidate: MemoryCandidate,
|
||||
existing: list[str],
|
||||
) -> Literal["ADD", "UPDATE", "DELETE", "NOOP"]:
|
||||
"""Decide what to do with a candidate given existing memories in the same tier.
|
||||
|
||||
Short-circuits to ADD without an LLM call when existing is empty (cost saving).
|
||||
Never raises.
|
||||
"""
|
||||
if not existing:
|
||||
return "ADD"
|
||||
|
||||
candidate_str = f"[{candidate.type}] {candidate.content}"
|
||||
existing_str = "\n".join(f"- {m}" for m in existing)
|
||||
|
||||
template, prompt_obj = get_prompt_or_fallback("memory_decide_action", _DECIDE_FALLBACK)
|
||||
|
||||
if prompt_obj is not None:
|
||||
try:
|
||||
system_text = prompt_obj.compile(
|
||||
candidate=candidate_str,
|
||||
existing_memories=existing_str,
|
||||
)
|
||||
if isinstance(system_text, list):
|
||||
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: decide compile failed: %s", exc)
|
||||
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
|
||||
else:
|
||||
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
|
||||
|
||||
llm = get_agent_llm("memory-extractor", temperature=0)
|
||||
lf = get_langfuse()
|
||||
|
||||
try:
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Decide action."),
|
||||
]
|
||||
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-decide-action",
|
||||
model=model_for_agent("memory-extractor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
verb = response.content.strip().upper()
|
||||
if verb in ("ADD", "UPDATE", "DELETE", "NOOP"):
|
||||
return verb # type: ignore[return-value]
|
||||
logger.warning("memory_extraction: unexpected decide verb=%r, defaulting ADD", verb)
|
||||
return "ADD"
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: decide_action failed: %s", exc)
|
||||
return "ADD"
|
||||
|
||||
|
||||
# ── Task 2.3 — Pipeline orchestrator ──────────────────────────────────────────
|
||||
|
||||
async def run_extraction(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
last_user_msg: str,
|
||||
last_assistant_msg: str,
|
||||
session_id: str | None,
|
||||
) -> None:
|
||||
"""Full Mem0-style extract/update pipeline for one conversation turn.
|
||||
|
||||
Steps:
|
||||
1. Load core memory + last 5 episodes.
|
||||
2. extract_candidates() → up to 5 MemoryCandidate objects.
|
||||
3. For each candidate: find top-3 neighbours → decide_action() → apply.
|
||||
4. Trace via Langfuse.
|
||||
|
||||
Never raises — wraps everything in try/except.
|
||||
"""
|
||||
try:
|
||||
await _run_extraction_inner(db, user_id, last_user_msg, last_assistant_msg, session_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: run_extraction failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _run_extraction_inner(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
last_user_msg: str,
|
||||
last_assistant_msg: str,
|
||||
session_id: str | None,
|
||||
) -> None:
|
||||
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
||||
|
||||
middleware = MemoryMiddleware(db)
|
||||
fernet = await middleware._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
logger.warning("memory_extraction: no fernet for user=%s, skipping", user_id)
|
||||
return
|
||||
|
||||
# 1. Load context
|
||||
core: dict[str, str] = await middleware._load_core(user_id, fernet)
|
||||
episodes: list[str] = await middleware._load_episodic(user_id, fernet, session_id=session_id)
|
||||
|
||||
last_turn = f"User: {last_user_msg}\nAssistant: {last_assistant_msg}"
|
||||
|
||||
lf = get_langfuse()
|
||||
|
||||
async def _run(trace_id: str | None) -> dict[str, Any]:
|
||||
# 2. Extract candidates
|
||||
result = await extract_candidates(last_turn, core, episodes)
|
||||
if not result.candidates:
|
||||
logger.info("memory_extraction: no candidates user=%s", user_id)
|
||||
return {"candidates": 0, "applied": 0}
|
||||
|
||||
logger.info(
|
||||
"memory_extraction: processing %d candidates user=%s trace=%s",
|
||||
len(result.candidates),
|
||||
user_id,
|
||||
trace_id or "-",
|
||||
)
|
||||
|
||||
# 3. Apply each candidate
|
||||
applied = 0
|
||||
actions: list[str] = []
|
||||
for candidate in result.candidates:
|
||||
try:
|
||||
await _apply_candidate(middleware, db, user_id, fernet, candidate, trace_id)
|
||||
applied += 1
|
||||
actions.append(f"{candidate.type}:{candidate.target_tier}")
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_extraction: apply failed candidate=%r user=%s: %s",
|
||||
candidate.content[:80],
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"memory_extraction: applied %d/%d candidates user=%s",
|
||||
applied,
|
||||
len(result.candidates),
|
||||
user_id,
|
||||
)
|
||||
return {"candidates": len(result.candidates), "applied": applied, "actions": actions}
|
||||
|
||||
with langfuse_context(user_id=user_id, session_id=session_id):
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="span",
|
||||
name="memory-extraction-pipeline",
|
||||
input={"last_turn_preview": last_turn[:200]},
|
||||
) as span:
|
||||
summary = await _run(trace_id=span.id)
|
||||
span.update(output=summary)
|
||||
try:
|
||||
lf.flush()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
await _run(trace_id=None)
|
||||
|
||||
|
||||
async def _apply_candidate(
|
||||
middleware: Any,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
fernet: Any,
|
||||
candidate: MemoryCandidate,
|
||||
trace_id: str | None,
|
||||
) -> None:
|
||||
"""Fetch neighbours, decide action, apply to the appropriate tier."""
|
||||
|
||||
neighbours: list[str] = []
|
||||
|
||||
if candidate.target_tier == "core":
|
||||
# For core tier: neighbours are existing core block values for similar keys.
|
||||
blocks = await middleware.list_core_blocks(user_id)
|
||||
neighbours = [b["value"] for b in blocks[:3]]
|
||||
|
||||
elif candidate.target_tier == "associative":
|
||||
neighbours = await middleware.search_archival(user_id, candidate.content, top_k=3)
|
||||
|
||||
elif candidate.target_tier == "relational":
|
||||
# Relation candidates handled specially — passed to upsert_relation directly.
|
||||
# Neighbours: search by subject label if available.
|
||||
neighbours = []
|
||||
|
||||
elif candidate.target_tier == "proactive":
|
||||
neighbours = await middleware.search_recall(user_id, candidate.content, top_k=3)
|
||||
|
||||
action = await decide_action(candidate, neighbours)
|
||||
logger.info(
|
||||
"memory_extraction: candidate type=%s tier=%s action=%s",
|
||||
candidate.type,
|
||||
candidate.target_tier,
|
||||
action,
|
||||
)
|
||||
|
||||
if action == "NOOP":
|
||||
return
|
||||
|
||||
if candidate.target_tier == "relational":
|
||||
# Always upsert relations — decide_action skipped (no neighbour search).
|
||||
if candidate.subject and candidate.predicate and candidate.object:
|
||||
await _upsert_relation_stub(
|
||||
middleware, db, user_id, candidate, trace_id
|
||||
)
|
||||
return
|
||||
|
||||
if action in ("ADD", "UPDATE"):
|
||||
if candidate.target_tier == "core":
|
||||
# Derive a short key from the content (first 40 chars, snake_cased).
|
||||
key = _content_to_key(candidate.content)
|
||||
await middleware.update_core(user_id, key, candidate.content, trace_id=trace_id)
|
||||
|
||||
elif candidate.target_tier == "associative":
|
||||
await middleware.store_associative(user_id, candidate.content)
|
||||
|
||||
elif candidate.target_tier == "proactive":
|
||||
await _store_proactive_stub(middleware, db, user_id, candidate, fernet)
|
||||
|
||||
elif action == "DELETE":
|
||||
if candidate.target_tier == "core":
|
||||
key = _content_to_key(candidate.content)
|
||||
await middleware.delete_core(user_id, key)
|
||||
|
||||
|
||||
def _content_to_key(content: str) -> str:
|
||||
"""Derive a short snake_case key from a content string (first 40 chars)."""
|
||||
import re # noqa: PLC0415
|
||||
slug = re.sub(r"[^a-z0-9]+", "_", content[:40].lower()).strip("_")
|
||||
return slug or "memory"
|
||||
|
||||
|
||||
async def _upsert_relation_stub(
|
||||
middleware: Any,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
candidate: MemoryCandidate,
|
||||
trace_id: str | None,
|
||||
) -> None:
|
||||
"""Stub: upsert_relation will be fully wired in Phase 3.
|
||||
|
||||
Called here so Phase 2 extraction pipeline already routes relation candidates
|
||||
correctly. Phase 3 replaces this with MemoryMiddleware.upsert_relation().
|
||||
"""
|
||||
if hasattr(middleware, "upsert_relation"):
|
||||
await middleware.upsert_relation(
|
||||
user_id=user_id,
|
||||
subject=candidate.subject,
|
||||
subject_type="unknown",
|
||||
predicate=candidate.predicate,
|
||||
object_=candidate.object,
|
||||
object_type="unknown",
|
||||
confidence=candidate.confidence,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"memory_extraction: relation stub (Phase 3 not yet wired) subject=%s predicate=%s object=%s",
|
||||
candidate.subject,
|
||||
candidate.predicate,
|
||||
candidate.object,
|
||||
)
|
||||
|
||||
|
||||
async def _store_proactive_stub(
|
||||
middleware: Any,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
candidate: MemoryCandidate,
|
||||
fernet: Any,
|
||||
) -> None:
|
||||
"""Store a proactive pattern row directly (MemoryProactive model)."""
|
||||
import uuid # noqa: PLC0415
|
||||
from app.models import MemoryProactive # noqa: PLC0415
|
||||
from app.core.memory_middleware import _encrypt # noqa: PLC0415
|
||||
|
||||
encrypted = _encrypt(fernet, candidate.content)
|
||||
row = MemoryProactive(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
pattern_encrypted=encrypted,
|
||||
confidence=candidate.confidence,
|
||||
source="inferred",
|
||||
)
|
||||
db.add(row)
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info("memory_extraction: stored proactive pattern user=%s", user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: store proactive failed: %s", exc)
|
||||
await db.rollback()
|
||||
@@ -18,6 +18,7 @@ Usage:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
@@ -27,6 +28,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import (
|
||||
ExtractionQueue,
|
||||
MemoryAssociative,
|
||||
MemoryCore,
|
||||
MemoryEpisodic,
|
||||
@@ -106,7 +108,10 @@ class MemoryMiddleware:
|
||||
"""Summarise and store a completed interaction in episodic memory.
|
||||
|
||||
The summary is a simple heuristic concatenation (no LLM call) to keep
|
||||
latency low. Full LLM summarisation can be added in a later step.
|
||||
latency low. After committing the episode row, dispatches the Mem0-style
|
||||
extraction pipeline:
|
||||
- Pro/Power/Team → asyncio.create_task (fire-and-forget, realtime).
|
||||
- Free → enqueue an ExtractionQueue row for the daily cron.
|
||||
"""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
@@ -115,26 +120,95 @@ class MemoryMiddleware:
|
||||
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||
encrypted = _encrypt(fernet, summary)
|
||||
|
||||
row = MemoryEpisodic(
|
||||
episode = MemoryEpisodic(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
summary_encrypted=encrypted,
|
||||
session_id=session_id,
|
||||
)
|
||||
self._db.add(row)
|
||||
self._db.add(episode)
|
||||
episode_id: str = episode.id
|
||||
try:
|
||||
await self._db.commit()
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
tier = user_dbg.get("tier") or "free"
|
||||
logger.info(
|
||||
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
user_dbg.get("tier") or "-",
|
||||
tier,
|
||||
session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
return
|
||||
|
||||
# ── Dispatch extraction pipeline (Phase 2) ────────────────────────────
|
||||
await self._dispatch_extraction(
|
||||
user_id=user_id,
|
||||
episode_id=episode_id,
|
||||
last_user_msg=message,
|
||||
last_assistant_msg=response,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _dispatch_extraction(
|
||||
self,
|
||||
user_id: str,
|
||||
episode_id: str,
|
||||
last_user_msg: str,
|
||||
last_assistant_msg: str,
|
||||
session_id: str | None,
|
||||
) -> None:
|
||||
"""Route extraction to realtime task or batch queue based on user tier."""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
|
||||
tier = await tier_manager.get_tier(user_id, self._db)
|
||||
|
||||
if tier_manager.check_feature(tier, "realtime_extraction"):
|
||||
# Pro/Power/Team: fire-and-forget in the background.
|
||||
# Must open a fresh session — request session closes after handler returns.
|
||||
from app.core.memory_extraction import run_extraction # noqa: PLC0415
|
||||
from app.db import async_session # noqa: PLC0415
|
||||
|
||||
async def _task() -> None:
|
||||
try:
|
||||
async with async_session() as fresh_db:
|
||||
await run_extraction(
|
||||
db=fresh_db,
|
||||
user_id=user_id,
|
||||
last_user_msg=last_user_msg,
|
||||
last_assistant_msg=last_assistant_msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory: extraction task failed user=%s: %s", user_id, exc
|
||||
)
|
||||
|
||||
asyncio.create_task(_task())
|
||||
logger.info("memory: realtime extraction dispatched user=%s", user_id)
|
||||
else:
|
||||
# Free tier: enqueue for daily batch cron.
|
||||
queue_row = ExtractionQueue(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
episode_id=episode_id,
|
||||
)
|
||||
self._db.add(queue_row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info(
|
||||
"memory: extraction enqueued (batch) user=%s episode=%s",
|
||||
user_id,
|
||||
episode_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory: extraction queue insert failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await self._db.rollback()
|
||||
|
||||
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||
"""Upsert a core memory key/value for a user."""
|
||||
|
||||
@@ -351,6 +351,28 @@ class MemoryProactive(Base):
|
||||
)
|
||||
|
||||
|
||||
class ExtractionQueue(Base):
|
||||
"""Batch extraction queue for Free-tier users (Phase 2).
|
||||
|
||||
Pro/Power/Team users get realtime asyncio.create_task() extraction.
|
||||
Free users get a queue row here; a daily cron (Phase 5) drains it.
|
||||
"""
|
||||
|
||||
__tablename__ = "extraction_queue"
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True,
|
||||
)
|
||||
episode_id: Mapped[str | None] = mapped_column(
|
||||
Uuid(as_uuid=False), nullable=True,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class Plugin(Base):
|
||||
"""Plugin marketplace catalog entry."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user