314 lines
11 KiB
Python
314 lines
11 KiB
Python
"""Memory maintenance jobs — Phase 3/5.

Three entrypoints called by the scheduler (APScheduler) registered in app/main.py:

  drain_extraction_queue(db)            — Free-tier batch extraction (Phase 2/5).
  mine_proactive_patterns(db, user_id)  — Power+ pattern mining (Phase 5).
  decay_relations(db, user_id)          — confidence decay + pruning for memory_relations (Phase 3).

All are safe to call manually or from tests; they never raise.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from cryptography.fernet import Fernet
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models import MemoryEpisodic, MemoryProactive, MemoryRelation, User
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Decay parameters for relations (consumed by _decay_relations_inner):
# confidence is multiplied by _DECAY_FACTOR once per whole _DECAY_PERIOD_DAYS
# elapsed since last_confirmed_at (falling back to created_at); a row whose
# decayed confidence is below _PRUNE_THRESHOLD is deleted instead of updated.
|
|
_DECAY_FACTOR = 0.95
|
|
_DECAY_PERIOD_DAYS = 30
|
|
_PRUNE_THRESHOLD = 0.2
|
|
|
|
# Proactive pattern decay (consumed by _decay_proactive_patterns): 10 % per
# whole 7-day period, measured from the row's created_at timestamp; rows whose
# decayed confidence is below _PROACTIVE_PRUNE_THRESHOLD are deleted.
|
|
_PROACTIVE_DECAY_FACTOR = 0.9
|
|
_PROACTIVE_DECAY_PERIOD_DAYS = 7
|
|
_PROACTIVE_PRUNE_THRESHOLD = 0.2
|
|
|
|
# Mining: require at least this many episodes (created within the last
# _MINING_LOOKBACK_DAYS days) to attempt pattern extraction
|
|
_MIN_EPISODES_FOR_MINING = 3
|
|
_MINING_LOOKBACK_DAYS = 30
|
|
|
|
|
|
async def decay_relations(db: AsyncSession, user_id: str) -> None:
    """Run the relation confidence-decay pass for one user, swallowing errors.

    Delegates to ``_decay_relations_inner``, which multiplies each relation's
    confidence by ``_DECAY_FACTOR`` per elapsed ``_DECAY_PERIOD_DAYS`` and
    deletes rows that end up below ``_PRUNE_THRESHOLD``.

    Args:
        db: Async session the pass reads from and commits through.
        user_id: Owner of the relation rows to process.

    Any ``Exception`` raised by the pass is logged at WARNING level and
    suppressed. ``BaseException`` subclasses are not caught here.
    """
    try:
        await _decay_relations_inner(db, user_id)
    except Exception as err:
        # Best-effort job: report the failure and return normally.
        logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, err)
|
|
|
|
|
|
async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
    """Decay or prune every ``MemoryRelation`` row belonging to ``user_id``.

    For each row the reference time is ``last_confirmed_at`` (or
    ``created_at`` when that is falsy); naive timestamps are treated as UTC.
    Rows without a reference, or younger than one full
    ``_DECAY_PERIOD_DAYS`` period, are left untouched. Otherwise the stored
    confidence is multiplied by ``_DECAY_FACTOR ** periods``; the row is
    deleted when the result is below ``_PRUNE_THRESHOLD`` and updated
    otherwise. All changes are committed once at the end; a failed commit is
    logged and rolled back.

    NOTE(review): the factor is applied to the *stored* confidence using the
    total age of the row, so running this job repeatedly appears to compound
    the decay — confirm against the scheduler cadence.
    """
    query = select(MemoryRelation).where(MemoryRelation.user_id == user_id)
    relations = (await db.execute(query)).scalars().all()
    now = datetime.now(timezone.utc)
    pruned_count = 0
    decayed_count = 0

    for relation in relations:
        anchor = relation.last_confirmed_at or relation.created_at
        if anchor is None:
            continue
        if anchor.tzinfo is None:
            # Naive timestamps are interpreted as UTC so they can be subtracted.
            anchor = anchor.replace(tzinfo=timezone.utc)

        # Whole decay periods elapsed; anything under one period is skipped.
        periods = (now - anchor).days // _DECAY_PERIOD_DAYS
        if periods < 1:
            continue

        updated = relation.confidence * (_DECAY_FACTOR ** periods)

        if updated < _PRUNE_THRESHOLD:
            await db.delete(relation)
            pruned_count += 1
            logger.info(
                "memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s "
                "confidence=%.3f (below threshold)",
                relation.id, user_id, relation.subject_label, relation.predicate, updated,
            )
            continue

        relation.confidence = updated
        decayed_count += 1

    try:
        await db.commit()
        logger.info(
            "memory_maintenance: decay_relations user=%s decayed=%d deleted=%d",
            user_id, decayed_count, pruned_count,
        )
    except Exception as err:
        logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, err)
        await db.rollback()
|
|
|
|
|
|
async def drain_extraction_queue(db: AsyncSession) -> None:
    """Drain the pending ``ExtractionQueue`` rows, swallowing errors.

    Delegates to ``_drain_extraction_queue_inner``, which runs the extraction
    pipeline once per queued row and deletes each row that was processed
    successfully.

    Args:
        db: Async session used to read, delete and commit queue rows.

    Any ``Exception`` escaping the drain is logged at WARNING level and
    suppressed. ``BaseException`` subclasses are not caught here.
    """
    try:
        await _drain_extraction_queue_inner(db)
    except Exception as err:
        # Best-effort job: report the failure and return normally.
        logger.warning("memory_maintenance: drain_extraction_queue failed: %s", err)
|
|
|
|
|
|
async def _drain_extraction_queue_inner(db: AsyncSession) -> None:
    """Run extraction for every ``ExtractionQueue`` row and delete it on success.

    Each row is processed in its own commit: ``run_extraction`` is awaited for
    the row's user, the row is deleted, and the session is committed. When any
    of those steps raises, the failure is logged, the session is rolled back,
    and the loop moves on to the next row.

    Args:
        db: Async session used for the query, the deletes and the commits.

    Raises:
        Exception: Anything raised outside the per-row ``try`` (the initial
            query, the imports, or a failing ``rollback``) propagates to the
            caller.
    """
    from app.models import ExtractionQueue  # noqa: PLC0415

    result = await db.execute(select(ExtractionQueue))
    rows = result.scalars().all()

    if not rows:
        logger.debug("memory_maintenance: drain_extraction_queue nothing to drain")
        return

    logger.info("memory_maintenance: drain_extraction_queue pending=%d", len(rows))

    from app.core.memory_extraction import run_extraction  # noqa: PLC0415

    # Snapshot the identifiers while the rows are freshly loaded. commit() and
    # rollback() expire loaded attributes under SQLAlchemy's default
    # expire_on_commit=True, and AsyncSession cannot lazily reload them on
    # plain attribute access. Without the snapshot, reading row.user_id for the
    # second row (or row.id inside the except handler) could itself raise and
    # abort the whole drain.
    pending = [(row, row.id, row.user_id) for row in rows]

    processed = 0
    for row, row_id, row_user_id in pending:
        try:
            # NOTE(review): both messages are passed as empty strings and no
            # field of the queue row other than user_id is forwarded — confirm
            # that run_extraction is expected to locate the episode itself.
            await run_extraction(
                db=db,
                user_id=row_user_id,
                last_user_msg="",
                last_assistant_msg="",
                session_id=None,
            )
            await db.delete(row)
            await db.commit()
            processed += 1
        except Exception as exc:
            logger.warning(
                "memory_maintenance: drain failed row=%s user=%s: %s",
                row_id, row_user_id, exc,
            )
            await db.rollback()

    logger.info("memory_maintenance: drain_extraction_queue processed=%d/%d", processed, len(rows))
|
|
|
|
|
|
async def mine_proactive_patterns(db: AsyncSession, user_id: str) -> None:
    """Mine recurring patterns from a user's recent episodes, swallowing errors.

    Delegates to ``_mine_proactive_patterns_inner``, which:
      1. returns early unless the user's tier has the ``proactive_mining`` feature;
      2. loads and decrypts the episodic summaries from the last
         ``_MINING_LOOKBACK_DAYS`` days;
      3. asks the ``memory-miner`` LLM for recurring patterns;
      4. encrypts and stores each pattern as a ``MemoryProactive`` row;
      5. applies confidence decay to the user's proactive rows.

    Args:
        db: Async session used for all reads, writes and commits.
        user_id: User whose episodes are mined.

    Any ``Exception`` raised along the way is logged at WARNING level and
    suppressed. ``BaseException`` subclasses are not caught here.
    """
    try:
        await _mine_proactive_patterns_inner(db, user_id)
    except Exception as err:
        # Best-effort job: report the failure and return normally.
        logger.warning("memory_maintenance: mine_proactive_patterns failed user=%s: %s", user_id, err)
|
|
|
|
|
|
async def _mine_proactive_patterns_inner(db: AsyncSession, user_id: str) -> None:
    """Extract, encrypt and store proactive patterns for one user, then decay.

    Flow:
      1. Return unless the user's tier has the ``proactive_mining`` feature.
      2. Load the user row and build a ``Fernet`` from its ``encryption_key``;
         return (with a warning) when the user or key is missing.
      3. Load ``MemoryEpisodic`` rows created within the last
         ``_MINING_LOOKBACK_DAYS`` days, oldest first; return when fewer than
         ``_MIN_EPISODES_FOR_MINING`` exist.
      4. Decrypt each summary. Episodes that cannot be decrypted are skipped,
         and the number skipped is logged rather than silently ignored.
      5. Ask ``_extract_proactive_patterns`` for patterns; store each one as an
         encrypted ``MemoryProactive`` row with confidence 0.7 and commit.
      6. Run ``_decay_proactive_patterns``.

    NOTE(review): every early return above (too few episodes, nothing
    decryptable, no patterns, failed commit) also skips the decay step, so
    existing patterns only decay on runs that store new ones — confirm that
    this is intended.

    Args:
        db: Async session used for all reads, writes and commits.
        user_id: User whose episodes are mined.
    """
    from app.billing.tier_manager import tier_manager  # noqa: PLC0415

    tier = await tier_manager.get_tier(user_id, db)
    if not tier_manager.check_feature(tier, "proactive_mining"):
        logger.debug("memory_maintenance: mine_proactive_patterns skipped (tier=%s)", tier)
        return

    # Load user Fernet key
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None or not user.encryption_key:
        logger.warning("memory_maintenance: mine_proactive_patterns no encryption_key user=%s", user_id)
        return

    fernet = Fernet(user.encryption_key.encode())
    cutoff = datetime.now(timezone.utc) - timedelta(days=_MINING_LOOKBACK_DAYS)

    episodes_result = await db.execute(
        select(MemoryEpisodic)
        .where(
            MemoryEpisodic.user_id == user_id,
            MemoryEpisodic.created_at >= cutoff,
        )
        .order_by(MemoryEpisodic.created_at.asc())
    )
    episode_rows = episodes_result.scalars().all()

    if len(episode_rows) < _MIN_EPISODES_FOR_MINING:
        logger.info(
            "memory_maintenance: mine_proactive_patterns skipped user=%s episodes=%d (< %d)",
            user_id, len(episode_rows), _MIN_EPISODES_FOR_MINING,
        )
        return

    summaries: list[str] = []
    undecryptable = 0
    for ep in episode_rows:
        try:
            summaries.append(fernet.decrypt(ep.summary_encrypted.encode()).decode())
        except Exception as exc:
            # Still best-effort (one bad episode must not stop mining), but
            # keep a trace. Only the exception type is logged so no episode
            # content can leak into the logs.
            undecryptable += 1
            logger.debug(
                "memory_maintenance: mine_proactive_patterns undecryptable episode user=%s: %s",
                user_id, type(exc).__name__,
            )

    if undecryptable:
        logger.warning(
            "memory_maintenance: mine_proactive_patterns user=%s skipped %d/%d undecryptable episodes",
            user_id, undecryptable, len(episode_rows),
        )

    if not summaries:
        return

    patterns = await _extract_proactive_patterns(summaries)
    if not patterns:
        logger.info("memory_maintenance: mine_proactive_patterns user=%s no patterns extracted", user_id)
        return

    stored = 0
    for pattern_text in patterns:
        try:
            encrypted = fernet.encrypt(pattern_text.encode()).decode()
            row = MemoryProactive(
                id=str(uuid.uuid4()),
                user_id=user_id,
                pattern_encrypted=encrypted,
                confidence=0.7,
                source="inferred",
            )
            db.add(row)
            stored += 1
        except Exception as exc:
            logger.warning("memory_maintenance: failed to store pattern user=%s: %s", user_id, exc)

    try:
        await db.commit()
        logger.info(
            "memory_maintenance: mine_proactive_patterns user=%s stored=%d",
            user_id, stored,
        )
    except Exception as exc:
        logger.warning("memory_maintenance: mine_proactive_patterns commit failed user=%s: %s", user_id, exc)
        await db.rollback()
        return

    await _decay_proactive_patterns(db, user_id, fernet)
|
|
|
|
|
|
async def _extract_proactive_patterns(summaries: list[str]) -> list[str]:
    """Ask the ``memory-miner`` LLM for recurring behavioral/temporal patterns.

    Args:
        summaries: Plaintext episode summaries. Only the last 20 are sent, to
            control token usage.

    Returns:
        Up to five non-empty, stripped lines from the model's reply, one
        pattern per line. Returns ``[]`` when ``summaries`` is empty (no LLM
        call is made) or when creating or invoking the LLM raises.
    """
    # Nothing to analyze: skip the import and the LLM round-trip entirely.
    if not summaries:
        return []

    from app.core.llm import get_agent_llm  # noqa: PLC0415

    combined = "\n---\n".join(summaries[-20:])  # cap at last 20 to control token usage
    prompt = (
        "You are analyzing conversation history for a personal AI secretary. "
        "Identify 3-5 recurring temporal or behavioral patterns (e.g. 'always works late on Thursdays', "
        "'prefers bullet-point summaries', 'frequently asks about Project Acme status'). "
        "Return each pattern as a plain, short English sentence on its own line. "
        "No numbering, no bullet points, no extra text.\n\n"
        f"Conversation history:\n{combined}"
    )
    try:
        # Client construction sits inside the try so a misconfigured agent
        # degrades to "no patterns" exactly like a failed invocation does.
        llm = get_agent_llm("memory-miner", temperature=0)
        response = await llm.ainvoke(prompt)
        # Use the reply's ``content`` attribute when present, otherwise
        # stringify the whole response object.
        text = str(getattr(response, "content", response))
        lines = [line.strip() for line in text.splitlines() if line.strip()]
        return lines[:5]
    except Exception as exc:
        logger.warning("memory_maintenance: _extract_proactive_patterns LLM failed: %s", exc)
        return []
|
|
|
|
|
|
async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fernet) -> None:
    """Decay or prune every ``MemoryProactive`` row belonging to ``user_id``.

    The age of a row is measured from ``created_at`` (naive timestamps are
    treated as UTC). Rows without a timestamp, or younger than one full
    ``_PROACTIVE_DECAY_PERIOD_DAYS`` period, are left untouched. Otherwise the
    stored confidence is multiplied by ``_PROACTIVE_DECAY_FACTOR ** periods``;
    the row is deleted when the result is below
    ``_PROACTIVE_PRUNE_THRESHOLD`` and updated otherwise. Changes are
    committed once at the end; a failed commit is logged and rolled back.

    Args:
        db: Async session used for the query, deletes and commit.
        user_id: Owner of the proactive rows.
        fernet: Accepted for the caller's convenience; not used by this function.

    NOTE(review): the factor is applied to the *stored* confidence using the
    total age of the row, so repeated runs appear to compound the decay —
    confirm against the scheduler cadence.
    """
    query = select(MemoryProactive).where(MemoryProactive.user_id == user_id)
    patterns = (await db.execute(query)).scalars().all()
    now = datetime.now(timezone.utc)
    pruned_count = 0
    decayed_count = 0

    for pattern in patterns:
        anchor = pattern.created_at
        if anchor is None:
            continue
        if anchor.tzinfo is None:
            # Naive timestamps are interpreted as UTC so they can be subtracted.
            anchor = anchor.replace(tzinfo=timezone.utc)

        # Whole decay periods elapsed; anything under one period is skipped.
        periods = (now - anchor).days // _PROACTIVE_DECAY_PERIOD_DAYS
        if periods < 1:
            continue

        updated = pattern.confidence * (_PROACTIVE_DECAY_FACTOR ** periods)

        if updated < _PROACTIVE_PRUNE_THRESHOLD:
            await db.delete(pattern)
            pruned_count += 1
            continue

        pattern.confidence = updated
        decayed_count += 1

    try:
        await db.commit()
        logger.info(
            "memory_maintenance: decay_proactive user=%s decayed=%d deleted=%d",
            user_id, decayed_count, pruned_count,
        )
    except Exception as err:
        logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, err)
        await db.rollback()
|