# Listing metadata: 734 lines · 28 KiB · Python
"""Memory Middleware — enrich requests with memory context and store interactions.

Four-tier memory model (MemGPT-style):
  core        — persistent key/value user preferences, always injected
  associative — semantic similarity search via pgvector (top-k) when the
                user's tier has real embeddings; otherwise the most recently
                updated entries
  episodic    — recent session summaries (last N)
  proactive   — behavioral patterns above confidence threshold

In addition, relational memory (subject --predicate--> object triples) is
injected for tiers that have the relational_memory feature.

Memory content is encrypted at rest using the per-user Fernet key
stored in User.encryption_key. Decryption happens in-memory only.
Relation subject/object labels are stored as plaintext identifiers.

Usage:
    memory = MemoryMiddleware(db_session)
    context = await memory.enrich_context(user_id, message)
    # ... run agent ...
    await memory.store_episode(user_id, session_id, message, response)
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from cryptography.fernet import Fernet, InvalidToken
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models import (
|
|
ExtractionQueue,
|
|
MemoryAssociative,
|
|
MemoryCore,
|
|
MemoryEpisodic,
|
|
MemoryProactive,
|
|
MemoryRelation,
|
|
User,
|
|
)
|
|
|
|
# Module-level logger, named after this module's import path; every log line
# emitted in this file goes through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _now() -> datetime:
    """Return the current moment as a timezone-aware datetime in UTC."""
    return datetime.now(tz=timezone.utc)
|
|
|
|
|
|
# Tuning constants

# Maximum number of associative rows fetched per context build
# (used as the LIMIT in MemoryMiddleware._load_associative).
_ASSOCIATIVE_TOP_K = 5

# Maximum number of most-recent episodic summaries fetched per context build
# (used as the LIMIT in MemoryMiddleware._load_episodic).
_EPISODIC_RECENT_N = 10

# Minimum confidence a proactive pattern needs to be returned as a hint
# (inclusive lower bound in MemoryMiddleware._load_proactive).
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
|
|
|
|
|
class MemoryMiddleware:
    """Enrich orchestrator context with memory and persist interactions after.

    Every method works on the single ``AsyncSession`` passed to the
    constructor. Write methods commit that session themselves; when a commit
    fails they log the error and roll back instead of raising.
    """

    # Strong references to in-flight fire-and-forget extraction tasks. The
    # event loop only keeps weak references to tasks, so a task nobody else
    # references can be garbage-collected before it finishes. Shared by all
    # instances on purpose: an instance lives for one request, the tasks may
    # outlive it.
    _background_tasks: set[asyncio.Task[None]] = set()

    def __init__(self, db: AsyncSession) -> None:
        """Bind the middleware to the database session used for all queries."""
        self._db = db

    # ── Public API ────────────────────────────────────────────────────────────

    async def enrich_context(
        self,
        user_id: str,
        message: str,
        trace_id: str | None = None,
        session_id: str | None = None,
    ) -> dict[str, Any]:
        """Build memory context dict to inject into the orchestrator before LLM call.

        Args:
            user_id: Owner of the memories to load.
            message: Incoming user message; embedded for vector search when
                the user's tier has the ``real_embeddings`` feature.
            trace_id: Optional correlation id, used only in the log line.
            session_id: When given, episodic memory is restricted to this session.

        Returns:
            ``{}`` when the user has no encryption key. Otherwise a dict with keys:
              core_memory        — {key: plaintext_value, ...}
              associative_memory — [plaintext_content, ...] (top-k by vector
                                   similarity when embeddings are available,
                                   otherwise the most recently updated rows)
              episodic_memory    — [plaintext_summary, ...] (most recent N)
              proactive_hints    — [plaintext_pattern, ...] (above threshold)
              relational_memory  — ["subject --predicate--> object", ...]
                                   (top 10 by confidence; empty without the
                                   ``relational_memory`` feature)
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return {}

        user_dbg = await self._get_user_debug(user_id)
        user_tier: str = user_dbg.get("tier") or "free"

        # The loaders run one after another because they all share self._db.
        core = await self._load_core(user_id, fernet)
        associative = await self._load_associative(
            user_id, message, fernet, user_tier=user_tier
        )
        episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
        proactive = await self._load_proactive(user_id, fernet)
        relational = await self._load_relational(user_id, user_tier=user_tier)

        logger.info(
            "memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d",
            trace_id or "-",
            user_id,
            user_tier,
            len(core),
            len(associative),
            len(episodic),
            len(proactive),
            len(relational),
        )

        return {
            "core_memory": core,
            "associative_memory": associative,
            "episodic_memory": episodic,
            "proactive_hints": proactive,
            "relational_memory": relational,
        }

    async def store_episode(
        self,
        user_id: str,
        session_id: str,
        message: str,
        response: str,
        trace_id: str | None = None,
    ) -> None:
        """Summarise and store a completed interaction in episodic memory.

        The summary is a simple heuristic concatenation (no LLM call) of the
        first 200 characters of the message and of the response, to keep
        latency low. After the episode row is committed, the extraction
        pipeline is dispatched via ``_dispatch_extraction``.

        Nothing is stored when the user has no encryption key, and nothing is
        dispatched when the commit fails (the failure is logged and the
        session rolled back).
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return

        summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
        episode = MemoryEpisodic(
            id=str(uuid.uuid4()),
            user_id=user_id,
            summary_encrypted=_encrypt(fernet, summary),
            session_id=session_id,
        )
        self._db.add(episode)
        # Read the id before committing.
        # NOTE(review): presumably because ORM instances can be expired by the
        # commit — confirm against the session configuration.
        episode_id: str = episode.id

        # Only the commit is guarded here: a failure of the log-only tier
        # lookup below must not be reported as a failed store, and must not
        # prevent extraction of an episode that was actually persisted.
        try:
            await self._db.commit()
        except Exception as exc:
            logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
            await self._db.rollback()
            return

        tier = await self._tier_for_log(user_id, default="free")
        logger.info(
            "memory: store_episode trace=%s user=%s tier=%s session=%s",
            trace_id or "-",
            user_id,
            tier,
            session_id,
        )

        # ── Dispatch extraction pipeline (Phase 2) ────────────────────────────
        await self._dispatch_extraction(
            user_id=user_id,
            episode_id=episode_id,
            last_user_msg=message,
            last_assistant_msg=response,
            session_id=session_id,
        )

    async def _dispatch_extraction(
        self,
        user_id: str,
        episode_id: str,
        last_user_msg: str,
        last_assistant_msg: str,
        session_id: str | None,
    ) -> None:
        """Route extraction to realtime task or batch queue based on user tier.

        Tiers with the ``realtime_extraction`` feature get a background task
        that runs ``run_extraction`` on a fresh session; all other tiers get
        an ``ExtractionQueue`` row committed on the request session.
        """
        from app.billing.tier_manager import tier_manager  # noqa: PLC0415

        tier = await tier_manager.get_tier(user_id, self._db)

        if tier_manager.check_feature(tier, "realtime_extraction"):
            # Pro/Power/Team: fire-and-forget in the background.
            # Must open a fresh session — request session closes after handler returns.
            from app.core.memory_extraction import run_extraction  # noqa: PLC0415
            from app.db import async_session  # noqa: PLC0415

            async def _run_in_background() -> None:
                # Failures are logged, never raised: nobody awaits this task.
                try:
                    async with async_session() as fresh_db:
                        await run_extraction(
                            db=fresh_db,
                            user_id=user_id,
                            last_user_msg=last_user_msg,
                            last_assistant_msg=last_assistant_msg,
                            session_id=session_id,
                        )
                except Exception as exc:
                    logger.warning(
                        "memory: extraction task failed user=%s: %s", user_id, exc
                    )

            task = asyncio.create_task(_run_in_background())
            # Hold a strong reference until the task completes so it cannot be
            # garbage-collected mid-run; drop it again when the task is done.
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
            logger.info("memory: realtime extraction dispatched user=%s", user_id)
            return

        # Free tier: enqueue for daily batch cron.
        self._db.add(
            ExtractionQueue(
                id=str(uuid.uuid4()),
                user_id=user_id,
                episode_id=episode_id,
            )
        )
        try:
            await self._db.commit()
            logger.info(
                "memory: extraction enqueued (batch) user=%s episode=%s",
                user_id,
                episode_id,
            )
        except Exception as exc:
            logger.warning(
                "memory: extraction queue insert failed user=%s: %s", user_id, exc
            )
            await self._db.rollback()

    async def update_core(
        self, user_id: str, key: str, value: str, trace_id: str | None = None
    ) -> None:
        """Upsert a core memory key/value for a user.

        The value is encrypted with the user's Fernet key. Does nothing when
        the user has no key; a failed commit is logged and rolled back.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return

        encrypted = _encrypt(fernet, value)

        result = await self._db.execute(
            select(MemoryCore).where(
                MemoryCore.user_id == user_id,
                MemoryCore.key == key,
            )
        )
        existing = result.scalar_one_or_none()
        if existing is None:
            self._db.add(
                MemoryCore(
                    id=str(uuid.uuid4()),
                    user_id=user_id,
                    key=key,
                    value_encrypted=encrypted,
                )
            )
        else:
            existing.value_encrypted = encrypted

        # Guard only the commit, so a failing log-only tier lookup is not
        # reported as a failed update.
        try:
            await self._db.commit()
        except Exception as exc:
            logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
            await self._db.rollback()
            return

        tier = await self._tier_for_log(user_id, default="-")
        logger.info(
            "memory: update_core trace=%s user=%s tier=%s key=%s",
            trace_id or "-",
            user_id,
            tier,
            key,
        )

    async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
        """Return core memory as editable blocks (label/value), ordered by key.

        Rows that cannot be decrypted are omitted. Returns ``[]`` when the
        user has no encryption key.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return []

        result = await self._db.execute(
            select(MemoryCore)
            .where(MemoryCore.user_id == user_id)
            .order_by(MemoryCore.key.asc())
        )
        blocks: list[dict[str, str]] = []
        for row in result.scalars().all():
            plaintext = _safe_decrypt(fernet, row.value_encrypted)
            if plaintext is not None:
                blocks.append({"label": row.key, "value": plaintext})
        logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(blocks))
        return blocks

    async def get_core_block(self, user_id: str, label: str) -> str | None:
        """Return a single core memory block value by label.

        Returns ``None`` when the user has no encryption key, when no row
        exists for *label*, or when the stored value cannot be decrypted.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return None

        result = await self._db.execute(
            select(MemoryCore).where(
                MemoryCore.user_id == user_id,
                MemoryCore.key == label,
            )
        )
        row = result.scalar_one_or_none()
        if row is None:
            logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
            return None

        value = _safe_decrypt(fernet, row.value_encrypted)
        logger.debug(
            "memory: get_core_block user=%s label=%s found=%d",
            user_id,
            label,
            1 if value is not None else 0,
        )
        return value

    async def delete_core(self, user_id: str, label: str) -> bool:
        """Delete a core memory block by label. Returns True if deleted.

        Returns ``False`` when no such block exists or when the commit fails
        (the failure is logged and the session rolled back).
        """
        result = await self._db.execute(
            select(MemoryCore).where(
                MemoryCore.user_id == user_id,
                MemoryCore.key == label,
            )
        )
        row = result.scalar_one_or_none()
        if row is None:
            logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
            return False

        await self._db.delete(row)
        try:
            await self._db.commit()
        except Exception as exc:
            logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
            await self._db.rollback()
            return False
        logger.info("memory: delete_core user=%s label=%s", user_id, label)
        return True

    async def append_core(self, user_id: str, label: str, content: str) -> None:
        """Append content to a core block, creating it if missing.

        Existing content and the new content are joined with a newline.
        """
        current = await self.get_core_block(user_id, label)
        if current is None:
            await self.update_core(user_id, label, content)
            logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
            return
        await self.update_core(user_id, label, f"{current}\n{content}")
        logger.info("memory: append_core user=%s label=%s created=0", user_id, label)

    async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
        """Replace one exact string inside a core block. Returns False if not found.

        Only the first occurrence of *old* is replaced.
        """
        current = await self.get_core_block(user_id, label)
        if current is None or old not in current:
            logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
            return False
        await self.update_core(user_id, label, current.replace(old, new, 1))
        logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
        return True

    async def store_associative(
        self,
        user_id: str,
        content: str,
        entity_type: str | None = None,
        entity_id: str | None = None,
    ) -> None:
        """Store associative memory; embed if user tier has real_embeddings.

        The content is stored encrypted. Without the ``real_embeddings``
        feature the row is stored with ``embedding=None``. Does nothing when
        the user has no encryption key.
        """
        from app.billing.tier_manager import tier_manager  # noqa: PLC0415
        from app.core.embeddings import embed_text  # noqa: PLC0415

        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return

        encrypted = _encrypt(fernet, content)

        user_dbg = await self._get_user_debug(user_id)
        user_tier = user_dbg.get("tier") or "free"

        embedding: list[float] | None = None
        if tier_manager.check_feature(user_tier, "real_embeddings"):
            embedding = await embed_text(content)

        self._db.add(
            MemoryAssociative(
                id=str(uuid.uuid4()),
                user_id=user_id,
                content_encrypted=encrypted,
                embedding=embedding,
                entity_type=entity_type,
                entity_id=entity_id,
            )
        )
        try:
            await self._db.commit()
            logger.info(
                "memory: store_associative user=%s embedded=%s",
                user_id,
                embedding is not None,
            )
        except Exception as exc:
            logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
            await self._db.rollback()

    async def upsert_relation(
        self,
        user_id: str,
        subject: str,
        subject_type: str,
        predicate: str,
        object_: str,
        object_type: str,
        *,
        confidence: float = 0.7,
        source_episode_id: str | None = None,
        notes: str | None = None,
    ) -> None:
        """Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label).

        subject_label / object_label are plaintext entity identifiers — not encrypted.
        notes is optional; encrypted with user Fernet if provided. If notes
        are given but the user has no encryption key, the notes are not stored.

        Skipped entirely (debug log only) when the user's tier lacks the
        ``relational_memory`` feature. On update, ``source_episode_id`` is
        left untouched and existing notes are kept unless new ones are given.
        """
        from app.billing.tier_manager import tier_manager  # noqa: PLC0415

        user_dbg = await self._get_user_debug(user_id)
        user_tier = user_dbg.get("tier") or "free"
        if not tier_manager.check_feature(user_tier, "relational_memory"):
            logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier)
            return

        # Unlike the other *_encrypted values in this class, notes are kept as
        # the raw bytes token returned by Fernet rather than decoded text.
        notes_encrypted: bytes | None = None
        if notes:
            fernet = await self._get_fernet(user_id)
            if fernet is not None:
                notes_encrypted = fernet.encrypt(notes.encode())

        result = await self._db.execute(
            select(MemoryRelation).where(
                MemoryRelation.user_id == user_id,
                MemoryRelation.subject_label == subject,
                MemoryRelation.predicate == predicate,
                MemoryRelation.object_label == object_,
            )
        )
        existing = result.scalar_one_or_none()

        if existing is None:
            self._db.add(
                MemoryRelation(
                    id=str(uuid.uuid4()),
                    user_id=user_id,
                    subject_label=subject,
                    subject_type=subject_type,
                    predicate=predicate,
                    object_label=object_,
                    object_type=object_type,
                    confidence=confidence,
                    source_episode_id=source_episode_id,
                    notes_encrypted=notes_encrypted,
                )
            )
        else:
            existing.subject_type = subject_type
            existing.object_type = object_type
            existing.confidence = confidence
            existing.last_confirmed_at = _now()
            if notes_encrypted is not None:
                existing.notes_encrypted = notes_encrypted

        try:
            await self._db.commit()
            logger.info(
                "memory: upsert_relation user=%s subject=%s predicate=%s object=%s",
                user_id, subject, predicate, object_,
            )
        except Exception as exc:
            logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc)
            await self._db.rollback()

    async def query_relations(
        self,
        user_id: str,
        subject: str | None = None,
        predicate: str | None = None,
        object_: str | None = None,
        limit: int = 20,
    ) -> list[MemoryRelation]:
        """Query relation rows for a user with optional filters.

        Each of *subject*, *predicate* and *object_* adds an exact-match
        filter when it is not ``None``. Results are ordered by confidence,
        highest first, and capped at *limit* rows.
        """
        stmt = select(MemoryRelation).where(MemoryRelation.user_id == user_id)
        if subject is not None:
            stmt = stmt.where(MemoryRelation.subject_label == subject)
        if predicate is not None:
            stmt = stmt.where(MemoryRelation.predicate == predicate)
        if object_ is not None:
            stmt = stmt.where(MemoryRelation.object_label == object_)
        stmt = stmt.order_by(MemoryRelation.confidence.desc()).limit(limit)
        result = await self._db.execute(stmt)
        return list(result.scalars().all())

    async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
        """Insert a long-term archival memory entry.

        Archival entries are stored as ``MemoryAssociative`` rows without an
        embedding; *source* is written to the row's ``entity_type``.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return

        self._db.add(
            MemoryAssociative(
                id=str(uuid.uuid4()),
                user_id=user_id,
                content_encrypted=_encrypt(fernet, content),
                embedding=None,
                entity_type=source,
                entity_id=None,
            )
        )
        try:
            await self._db.commit()
            logger.info("memory: insert_archival user=%s source=%s", user_id, source)
        except Exception as exc:
            logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
            await self._db.rollback()

    async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
        """Search archival memory (keyword fallback; semantic ranking can replace this).

        Scans the user's 100 most recently updated associative rows and
        returns up to ``max(top_k, 1)`` decrypted entries that contain *query*
        case-insensitively. A blank query matches every entry.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return []

        result = await self._db.execute(
            select(MemoryAssociative)
            .where(MemoryAssociative.user_id == user_id)
            .order_by(MemoryAssociative.updated_at.desc())
            .limit(100)
        )
        hits = self._keyword_hits(
            fernet, result.scalars().all(), "content_encrypted", query, top_k
        )
        logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(hits))
        return hits

    async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
        """Search recall memory (episodic summaries) by keyword.

        Scans the user's 100 newest episodic rows and returns up to
        ``max(top_k, 1)`` decrypted summaries that contain *query*
        case-insensitively. A blank query matches every summary.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return []

        result = await self._db.execute(
            select(MemoryEpisodic)
            .where(MemoryEpisodic.user_id == user_id)
            .order_by(MemoryEpisodic.created_at.desc())
            .limit(100)
        )
        hits = self._keyword_hits(
            fernet, result.scalars().all(), "summary_encrypted", query, top_k
        )
        logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(hits))
        return hits

    # ── Private helpers ───────────────────────────────────────────────────────

    @staticmethod
    def _decrypt_column(fernet: Fernet, rows: Any, column: str) -> list[str]:
        """Decrypt attribute *column* of every row, dropping rows that fail to decrypt."""
        plaintexts: list[str] = []
        for row in rows:
            plaintext = _safe_decrypt(fernet, getattr(row, column))
            if plaintext is not None:
                plaintexts.append(plaintext)
        return plaintexts

    @staticmethod
    def _keyword_hits(
        fernet: Fernet, rows: Any, column: str, query: str, top_k: int
    ) -> list[str]:
        """Return up to ``max(top_k, 1)`` decrypted values containing *query*.

        Matching is a case-insensitive substring test on the stripped query;
        a blank query matches everything. Rows are decrypted lazily, so the
        scan stops as soon as enough hits have been collected.
        """
        needle = query.strip().lower()
        wanted = max(top_k, 1)
        hits: list[str] = []
        for row in rows:
            plaintext = _safe_decrypt(fernet, getattr(row, column))
            if plaintext is None:
                continue
            if not needle or needle in plaintext.lower():
                hits.append(plaintext)
                if len(hits) >= wanted:
                    break
        return hits

    async def _tier_for_log(self, user_id: str, default: str) -> str:
        """Best-effort tier lookup for trace logs; lookup errors are logged, not raised.

        Call only when the session holds no uncommitted work: on failure the
        session is rolled back so that later statements do not run inside a
        failed transaction. Returns *default* when the lookup fails or yields
        no tier.
        """
        try:
            user_dbg = await self._get_user_debug(user_id)
        except Exception as exc:
            logger.warning(
                "memory: tier lookup for logging failed user=%s: %s", user_id, exc
            )
            await self._db.rollback()
            return default
        return user_dbg.get("tier") or default

    async def _get_fernet(self, user_id: str) -> Fernet | None:
        """Load the user's Fernet key from DB. Returns None if missing.

        "Missing" covers both an unknown user and an empty key; either case
        is logged as a warning.
        NOTE(review): a malformed key is not handled here — the ``Fernet``
        constructor is expected to raise for it; confirm that is intended.
        """
        result = await self._db.execute(select(User).where(User.id == user_id))
        user = result.scalar_one_or_none()
        if user is None or not user.encryption_key:
            logger.warning("memory: no encryption_key for user=%s", user_id)
            return None
        return Fernet(user.encryption_key.encode())

    async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
        """Load lightweight user debug fields for trace logs.

        Returns ``{"tier": None}`` for an unknown user. Otherwise the tier is
        resolved in this order: the user's subscription tier, ``"power"``
        when ``settings.ENV`` is ``"dev"``, then ``user.tier`` (``"free"``
        when unset).
        """
        from app.config.settings import settings  # noqa: PLC0415
        from app.models import Subscription  # noqa: PLC0415

        result = await self._db.execute(select(User).where(User.id == user_id))
        user = result.scalar_one_or_none()
        if user is None:
            return {"tier": None}

        # NOTE(review): scalar_one_or_none() raises if several subscription
        # rows match — presumably at most one exists per user; confirm.
        sub_result = await self._db.execute(
            select(Subscription.tier).where(Subscription.user_id == user_id)
        )
        sub_tier: str | None = sub_result.scalar_one_or_none()

        if sub_tier:
            tier = sub_tier
        elif settings.ENV == "dev":
            tier = "power"
        else:
            tier = user.tier or "free"

        return {"tier": tier}

    async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
        """Return all decryptable core memory entries of the user as ``{key: value}``."""
        result = await self._db.execute(
            select(MemoryCore).where(MemoryCore.user_id == user_id)
        )
        core: dict[str, str] = {}
        for row in result.scalars().all():
            plaintext = _safe_decrypt(fernet, row.value_encrypted)
            if plaintext is not None:
                core[row.key] = plaintext
        return core

    async def _load_associative(
        self, user_id: str, message: str, fernet: Fernet, *, user_tier: str = "free"
    ) -> list[str]:
        """Load top-k associative memories.

        With the ``real_embeddings`` feature: pgvector cosine-distance
        ordering against the embedding of *message*, over rows that have an
        embedding.
        Without the feature, when no embedding is produced, or when the
        vector query fails: the most recently updated rows. Despite the
        "keyword" wording in the log message, no keyword matching happens.
        """
        from app.billing.tier_manager import tier_manager  # noqa: PLC0415
        from app.core.embeddings import embed_text  # noqa: PLC0415

        if tier_manager.check_feature(user_tier, "real_embeddings"):
            vec = await embed_text(message)
            if vec is not None:
                try:
                    result = await self._db.execute(
                        select(MemoryAssociative)
                        .where(
                            MemoryAssociative.user_id == user_id,
                            MemoryAssociative.embedding.isnot(None),
                        )
                        .order_by(MemoryAssociative.embedding.cosine_distance(vec))
                        .limit(_ASSOCIATIVE_TOP_K)
                    )
                    hits = self._decrypt_column(
                        fernet, result.scalars().all(), "content_encrypted"
                    )
                    logger.info(
                        "memory: _load_associative user=%s mode=vector hits=%d",
                        user_id,
                        len(hits),
                    )
                    return hits
                except Exception as exc:
                    # NOTE(review): on PostgreSQL a failed statement aborts the
                    # enclosing transaction, so the fallback query below may
                    # fail as well unless a savepoint/rollback is used — confirm.
                    logger.warning(
                        "memory: vector search failed user=%s, falling back to keyword: %s",
                        user_id,
                        exc,
                    )

        # Keyword fallback: most recent rows
        result = await self._db.execute(
            select(MemoryAssociative)
            .where(MemoryAssociative.user_id == user_id)
            .order_by(MemoryAssociative.updated_at.desc())
            .limit(_ASSOCIATIVE_TOP_K)
        )
        return self._decrypt_column(fernet, result.scalars().all(), "content_encrypted")

    async def _load_episodic(
        self,
        user_id: str,
        fernet: Fernet,
        session_id: str | None = None,
    ) -> list[str]:
        """Return the newest decryptable episodic summaries, newest first.

        At most ``_EPISODIC_RECENT_N`` rows are read; a truthy *session_id*
        restricts them to that session.
        """
        stmt = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
        if session_id:
            stmt = stmt.where(MemoryEpisodic.session_id == session_id)
        result = await self._db.execute(
            stmt.order_by(MemoryEpisodic.created_at.desc()).limit(_EPISODIC_RECENT_N)
        )
        return self._decrypt_column(fernet, result.scalars().all(), "summary_encrypted")

    async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]:
        """Return the top-10 relations by confidence as display strings.

        Returns an empty list when *user_tier* lacks the
        ``relational_memory`` feature.
        """
        from app.billing.tier_manager import tier_manager  # noqa: PLC0415

        if not tier_manager.check_feature(user_tier, "relational_memory"):
            return []

        result = await self._db.execute(
            select(MemoryRelation)
            .where(MemoryRelation.user_id == user_id)
            .order_by(MemoryRelation.confidence.desc())
            .limit(10)
        )
        return [
            f"{rel.subject_label} --{rel.predicate}--> {rel.object_label}"
            for rel in result.scalars().all()
        ]

    async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
        """Return decryptable proactive patterns at or above the confidence threshold.

        Ordered by confidence, highest first. The query has no row limit.
        """
        result = await self._db.execute(
            select(MemoryProactive)
            .where(
                MemoryProactive.user_id == user_id,
                MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
            )
            .order_by(MemoryProactive.confidence.desc())
        )
        return self._decrypt_column(fernet, result.scalars().all(), "pattern_encrypted")
|
|
|
|
|
|
# ── Encryption helpers ────────────────────────────────────────────────────────
|
|
|
|
def _encrypt(fernet: Fernet, plaintext: str) -> str:
    """Encrypt *plaintext* with *fernet* and return the resulting token as text."""
    token = fernet.encrypt(plaintext.encode())
    return token.decode()
|
|
|
|
|
|
def _safe_decrypt(fernet: Fernet, ciphertext: str | bytes) -> str | None:
    """Decrypt and return plaintext, or None on error (corrupted/wrong key).

    Args:
        fernet: Cipher built from the owning user's key.
        ciphertext: Fernet token, either as text (as produced by ``_encrypt``)
            or as raw bytes (as produced by ``fernet.encrypt``).

    Returns:
        The decrypted text, or ``None`` when decryption fails for any reason.
        This is deliberately best-effort: callers skip unreadable rows
        instead of failing the whole request, so no exception escapes.
    """
    try:
        token = ciphertext.encode() if isinstance(ciphertext, str) else ciphertext
        return fernet.decrypt(token).decode()
    except InvalidToken:
        # str(InvalidToken()) is empty, so name the failure explicitly.
        reason = "InvalidToken"
    except Exception as exc:
        reason = repr(exc)
    logger.warning("memory: decrypt failed: %s", reason)
    return None
|