PHASE 2 — Mem0-style Extract/Update pipeline
This commit is contained in:
@@ -18,6 +18,7 @@ Usage:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
@@ -27,6 +28,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import (
|
||||
ExtractionQueue,
|
||||
MemoryAssociative,
|
||||
MemoryCore,
|
||||
MemoryEpisodic,
|
||||
@@ -106,7 +108,10 @@ class MemoryMiddleware:
|
||||
"""Summarise and store a completed interaction in episodic memory.
|
||||
|
||||
The summary is a simple heuristic concatenation (no LLM call) to keep
|
||||
latency low. Full LLM summarisation can be added in a later step.
|
||||
latency low. After committing the episode row, dispatches the Mem0-style
|
||||
extraction pipeline:
|
||||
- Pro/Power/Team → asyncio.create_task (fire-and-forget, realtime).
|
||||
- Free → enqueue an ExtractionQueue row for the daily cron.
|
||||
"""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
@@ -115,26 +120,95 @@ class MemoryMiddleware:
|
||||
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||
encrypted = _encrypt(fernet, summary)
|
||||
|
||||
row = MemoryEpisodic(
|
||||
episode = MemoryEpisodic(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
summary_encrypted=encrypted,
|
||||
session_id=session_id,
|
||||
)
|
||||
self._db.add(row)
|
||||
self._db.add(episode)
|
||||
episode_id: str = episode.id
|
||||
try:
|
||||
await self._db.commit()
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
tier = user_dbg.get("tier") or "free"
|
||||
logger.info(
|
||||
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
user_dbg.get("tier") or "-",
|
||||
tier,
|
||||
session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
return
|
||||
|
||||
# ── Dispatch extraction pipeline (Phase 2) ────────────────────────────
|
||||
await self._dispatch_extraction(
|
||||
user_id=user_id,
|
||||
episode_id=episode_id,
|
||||
last_user_msg=message,
|
||||
last_assistant_msg=response,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _dispatch_extraction(
|
||||
self,
|
||||
user_id: str,
|
||||
episode_id: str,
|
||||
last_user_msg: str,
|
||||
last_assistant_msg: str,
|
||||
session_id: str | None,
|
||||
) -> None:
|
||||
"""Route extraction to realtime task or batch queue based on user tier."""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
|
||||
tier = await tier_manager.get_tier(user_id, self._db)
|
||||
|
||||
if tier_manager.check_feature(tier, "realtime_extraction"):
|
||||
# Pro/Power/Team: fire-and-forget in the background.
|
||||
# Must open a fresh session — request session closes after handler returns.
|
||||
from app.core.memory_extraction import run_extraction # noqa: PLC0415
|
||||
from app.db import async_session # noqa: PLC0415
|
||||
|
||||
async def _task() -> None:
|
||||
try:
|
||||
async with async_session() as fresh_db:
|
||||
await run_extraction(
|
||||
db=fresh_db,
|
||||
user_id=user_id,
|
||||
last_user_msg=last_user_msg,
|
||||
last_assistant_msg=last_assistant_msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory: extraction task failed user=%s: %s", user_id, exc
|
||||
)
|
||||
|
||||
asyncio.create_task(_task())
|
||||
logger.info("memory: realtime extraction dispatched user=%s", user_id)
|
||||
else:
|
||||
# Free tier: enqueue for daily batch cron.
|
||||
queue_row = ExtractionQueue(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
episode_id=episode_id,
|
||||
)
|
||||
self._db.add(queue_row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info(
|
||||
"memory: extraction enqueued (batch) user=%s episode=%s",
|
||||
user_id,
|
||||
episode_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory: extraction queue insert failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await self._db.rollback()
|
||||
|
||||
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||
"""Upsert a core memory key/value for a user."""
|
||||
|
||||
Reference in New Issue
Block a user