PHASE 3 — relational tier (Mem0g-light)

This commit is contained in:
Roberto Musso
2026-04-17 17:04:27 +02:00
parent 741b9b87fb
commit 341ee140e5
10 changed files with 850 additions and 33 deletions

View File

@@ -55,6 +55,22 @@ def _language_instruction(context: dict[str, Any]) -> str:
f"All your output text must be written in {lang}."
)
def _relational_memory_injection(context: dict[str, Any]) -> str:
"""Return a system-prompt paragraph listing known people/projects from relational memory.
Returns empty string when no relational rows or tier is Free.
Capped at 800 chars to control token spend.
"""
relations: list[str] = context.get("relational_memory") or []
if not relations:
return ""
body = "\n".join(f"- {r}" for r in relations)
section = f"\n\nKnown people & projects:\n{body}"
if len(section) > 800:
section = section[:797] + "..."
return section
_HOME_SYSTEM_PROMPT = (
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
"Always use tools for factual data retrieval before answering. "
@@ -904,6 +920,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"home_system", _HOME_SYSTEM_PROMPT
)
system_prompt += _relational_memory_injection(context)
system_prompt += _language_instruction(context)
response = await _run_single_agent(
user_id=user_id,
@@ -922,6 +939,7 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"floating_system", _FLOATING_SYSTEM_PROMPT
)
system_prompt += _relational_memory_injection(context)
system_prompt += _language_instruction(context)
response = await _run_single_agent(
user_id=user_id,
@@ -946,6 +964,7 @@ async def run_home_stream(
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"home_system", _HOME_SYSTEM_PROMPT
)
system_prompt += _relational_memory_injection(context)
system_prompt += _language_instruction(context)
text_chunks: list[str] = []
async for event in _run_single_agent_stream(
@@ -979,6 +998,7 @@ async def run_floating_stream(
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"floating_system", _FLOATING_SYSTEM_PROMPT
)
system_prompt += _relational_memory_injection(context)
system_prompt += _language_instruction(context)
sanitizer = _FloatingStreamSanitizer()
emitted_sanitized = False

View File

@@ -366,7 +366,7 @@ async def _apply_candidate(
if candidate.target_tier == "relational":
# Always upsert relations — decide_action skipped (no neighbour search).
if candidate.subject and candidate.predicate and candidate.object:
await _upsert_relation_stub(
await _upsert_relation(
middleware, db, user_id, candidate, trace_id
)
return
@@ -396,35 +396,29 @@ def _content_to_key(content: str) -> str:
return slug or "memory"
async def _upsert_relation_stub(
async def _upsert_relation(
middleware: Any,
db: AsyncSession,
user_id: str,
candidate: MemoryCandidate,
trace_id: str | None,
) -> None:
"""Stub: upsert_relation will be fully wired in Phase 3.
Called here so Phase 2 extraction pipeline already routes relation candidates
correctly. Phase 3 replaces this with MemoryMiddleware.upsert_relation().
"""
if hasattr(middleware, "upsert_relation"):
await middleware.upsert_relation(
user_id=user_id,
subject=candidate.subject,
subject_type="unknown",
predicate=candidate.predicate,
object_=candidate.object,
object_type="unknown",
confidence=candidate.confidence,
)
else:
logger.info(
"memory_extraction: relation stub (Phase 3 not yet wired) subject=%s predicate=%s object=%s",
candidate.subject,
candidate.predicate,
candidate.object,
)
"""Upsert a relation row via MemoryMiddleware.upsert_relation (Phase 3)."""
await middleware.upsert_relation(
user_id=user_id,
subject=candidate.subject or "unknown",
subject_type="unknown",
predicate=candidate.predicate or "related_to",
object_=candidate.object or "unknown",
object_type="unknown",
confidence=candidate.confidence,
)
logger.info(
"memory_extraction: upserted relation subject=%s predicate=%s object=%s",
candidate.subject,
candidate.predicate,
candidate.object,
)
async def _store_proactive_stub(

View File

@@ -0,0 +1,102 @@
"""Memory maintenance jobs — Phase 3/5.
Two entrypoints called by the scheduler (APScheduler) registered in app/main.py:
drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5).
decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3).
Both are safe to call manually or from tests; they never raise.
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import MemoryRelation
logger = logging.getLogger(__name__)
# Decay parameters
_DECAY_FACTOR = 0.95 # multiply confidence by this every _DECAY_PERIOD days
_DECAY_PERIOD_DAYS = 30 # period for one decay step
_PRUNE_THRESHOLD = 0.2 # rows below this confidence are deleted
async def decay_relations(db: AsyncSession, user_id: str) -> None:
"""Apply confidence decay to all relation rows for a user.
Decay rule: confidence *= 0.95 for every 30 days since last_confirmed_at.
Rows whose confidence falls below 0.2 are deleted.
Never raises — wraps in try/except.
"""
try:
await _decay_relations_inner(db, user_id)
except Exception as exc:
logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, exc)
async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
result = await db.execute(
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
)
rows = result.scalars().all()
now = datetime.now(timezone.utc)
deleted = 0
decayed = 0
for row in rows:
reference = row.last_confirmed_at or row.created_at
if reference is None:
continue
# Ensure timezone-aware comparison
if reference.tzinfo is None:
reference = reference.replace(tzinfo=timezone.utc)
days_elapsed = (now - reference).days
if days_elapsed < _DECAY_PERIOD_DAYS:
continue
periods = days_elapsed // _DECAY_PERIOD_DAYS
new_confidence = row.confidence * (_DECAY_FACTOR ** periods)
if new_confidence < _PRUNE_THRESHOLD:
await db.delete(row)
deleted += 1
logger.info(
"memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s "
"confidence=%.3f (below threshold)",
row.id, user_id, row.subject_label, row.predicate, new_confidence,
)
else:
row.confidence = new_confidence
decayed += 1
try:
await db.commit()
logger.info(
"memory_maintenance: decay_relations user=%s decayed=%d deleted=%d",
user_id, decayed, deleted,
)
except Exception as exc:
logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, exc)
await db.rollback()
async def drain_extraction_queue(db: AsyncSession) -> None:
"""Process pending ExtractionQueue rows for Free-tier users (Phase 5 stub).
Full implementation wired in Phase 5 when APScheduler is registered.
Currently logs count and returns.
"""
try:
from app.models import ExtractionQueue # noqa: PLC0415
result = await db.execute(select(ExtractionQueue))
rows = result.scalars().all()
logger.info("memory_maintenance: drain_extraction_queue pending=%d (Phase 5 cron)", len(rows))
except Exception as exc:
logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc)

View File

@@ -21,6 +21,7 @@ from __future__ import annotations
import asyncio
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
from cryptography.fernet import Fernet, InvalidToken
@@ -33,11 +34,17 @@ from app.models import (
MemoryCore,
MemoryEpisodic,
MemoryProactive,
MemoryRelation,
User,
)
logger = logging.getLogger(__name__)
def _now() -> datetime:
return datetime.now(timezone.utc)
# Tuning constants
_ASSOCIATIVE_TOP_K = 5
_EPISODIC_RECENT_N = 10
@@ -66,6 +73,7 @@ class MemoryMiddleware:
associative_memory — [plaintext_content, ...] (top-k by keyword match)
episodic_memory — [plaintext_summary, ...] (most recent N)
proactive_hints — [plaintext_pattern, ...] (above threshold)
relational_memory — ["subject --predicate--> object", ...] (top 10, Pro+)
"""
fernet = await self._get_fernet(user_id)
if fernet is None:
@@ -78,9 +86,10 @@ class MemoryMiddleware:
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
proactive = await self._load_proactive(user_id, fernet)
relational = await self._load_relational(user_id, user_tier=user_tier)
logger.info(
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d",
trace_id or "-",
user_id,
user_tier,
@@ -88,6 +97,7 @@ class MemoryMiddleware:
len(associative),
len(episodic),
len(proactive),
len(relational),
)
return {
@@ -95,6 +105,7 @@ class MemoryMiddleware:
"associative_memory": associative,
"episodic_memory": episodic,
"proactive_hints": proactive,
"relational_memory": relational,
}
async def store_episode(
@@ -375,6 +386,99 @@ class MemoryMiddleware:
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def upsert_relation(
self,
user_id: str,
subject: str,
subject_type: str,
predicate: str,
object_: str,
object_type: str,
*,
confidence: float = 0.7,
source_episode_id: str | None = None,
notes: str | None = None,
) -> None:
"""Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label).
subject_label / object_label are plaintext entity identifiers — not encrypted.
notes is optional; encrypted with user Fernet if provided.
"""
from app.billing.tier_manager import tier_manager # noqa: PLC0415
user_dbg = await self._get_user_debug(user_id)
user_tier = user_dbg.get("tier") or "free"
if not tier_manager.check_feature(user_tier, "relational_memory"):
logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier)
return
notes_encrypted: bytes | None = None
if notes:
fernet = await self._get_fernet(user_id)
if fernet:
notes_encrypted = fernet.encrypt(notes.encode())
result = await self._db.execute(
select(MemoryRelation).where(
MemoryRelation.user_id == user_id,
MemoryRelation.subject_label == subject,
MemoryRelation.predicate == predicate,
MemoryRelation.object_label == object_,
)
)
existing = result.scalar_one_or_none()
if existing is not None:
existing.subject_type = subject_type
existing.object_type = object_type
existing.confidence = confidence
existing.last_confirmed_at = _now()
if notes_encrypted is not None:
existing.notes_encrypted = notes_encrypted
else:
self._db.add(MemoryRelation(
id=str(uuid.uuid4()),
user_id=user_id,
subject_label=subject,
subject_type=subject_type,
predicate=predicate,
object_label=object_,
object_type=object_type,
confidence=confidence,
source_episode_id=source_episode_id,
notes_encrypted=notes_encrypted,
))
try:
await self._db.commit()
logger.info(
"memory: upsert_relation user=%s subject=%s predicate=%s object=%s",
user_id, subject, predicate, object_,
)
except Exception as exc:
logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def query_relations(
self,
user_id: str,
subject: str | None = None,
predicate: str | None = None,
object_: str | None = None,
limit: int = 20,
) -> list[MemoryRelation]:
"""Query relation rows for a user with optional filters."""
q = select(MemoryRelation).where(MemoryRelation.user_id == user_id)
if subject is not None:
q = q.where(MemoryRelation.subject_label == subject)
if predicate is not None:
q = q.where(MemoryRelation.predicate == predicate)
if object_ is not None:
q = q.where(MemoryRelation.object_label == object_)
q = q.order_by(MemoryRelation.confidence.desc()).limit(limit)
result = await self._db.execute(q)
return list(result.scalars().all())
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
"""Insert a long-term archival memory entry."""
fernet = await self._get_fernet(user_id)
@@ -463,13 +567,26 @@ class MemoryMiddleware:
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
"""Load lightweight user debug fields for trace logs."""
from app.config.settings import settings # noqa: PLC0415
from app.models import Subscription # noqa: PLC0415
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return {"tier": None}
return {
"tier": user.tier,
}
sub_result = await self._db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
sub_tier: str | None = sub_result.scalar_one_or_none()
if sub_tier:
tier = sub_tier
elif settings.ENV == "dev":
tier = "power"
else:
tier = user.tier or "free"
return {"tier": tier}
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
result = await self._db.execute(
@@ -563,6 +680,26 @@ class MemoryMiddleware:
out.append(plaintext)
return out
async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]:
"""Return top-10 relation strings for Pro+ users; empty list for Free."""
from app.billing.tier_manager import tier_manager # noqa: PLC0415
if not tier_manager.check_feature(user_tier, "relational_memory"):
return []
result = await self._db.execute(
select(MemoryRelation)
.where(MemoryRelation.user_id == user_id)
.order_by(MemoryRelation.confidence.desc())
.limit(10)
)
rows = result.scalars().all()
out = [
f"{r.subject_label} --{r.predicate}--> {r.object_label}"
for r in rows
]
return out
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
result = await self._db.execute(
select(MemoryProactive)