PHASE 3 — relational tier (Mem0g-light)

This commit is contained in:
Roberto Musso
2026-04-17 17:04:27 +02:00
parent 741b9b87fb
commit 341ee140e5
10 changed files with 850 additions and 33 deletions

View File

@@ -0,0 +1,74 @@
"""Add memory_relations table (Phase 3 — relational tier).
Revision ID: 006
Revises: 1f5975a4f3f4
Create Date: 2026-04-16
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "006"
down_revision: Union[str, None] = "1f5975a4f3f4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"memory_relations",
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=False),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("subject_label", sa.String(128), nullable=False),
sa.Column("subject_type", sa.String(32), nullable=False),
sa.Column("predicate", sa.String(64), nullable=False),
sa.Column("object_label", sa.String(128), nullable=False),
sa.Column("object_type", sa.String(32), nullable=False),
sa.Column("confidence", sa.Float, nullable=False, server_default="0.7"),
sa.Column(
"source_episode_id",
postgresql.UUID(as_uuid=False),
sa.ForeignKey("memory_episodic.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column("notes_encrypted", sa.LargeBinary, nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.Column("last_confirmed_at", sa.DateTime(timezone=True), nullable=True),
)
op.create_index(
"memory_relations_user_subject_idx",
"memory_relations",
["user_id", "subject_label"],
)
op.create_index(
"memory_relations_user_predicate_idx",
"memory_relations",
["user_id", "predicate"],
)
def downgrade() -> None:
op.drop_index("memory_relations_user_predicate_idx", "memory_relations")
op.drop_index("memory_relations_user_subject_idx", "memory_relations")
op.drop_table("memory_relations")

225
app/api/routes/memory.py Normal file
View File

@@ -0,0 +1,225 @@
"""Memory management routes — view/edit/delete user memory tiers.
All routes require authentication. Data is always user-scoped.
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Annotated
from fastapi import APIRouter, Depends, Header, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session
from app.models import (
ExtractionQueue,
MemoryAssociative,
MemoryCore,
MemoryEpisodic,
MemoryProactive,
MemoryRelation,
)
from app.schemas import UserProfile
router = APIRouter(prefix="/memory", tags=["memory"])
logger = logging.getLogger(__name__)
_ALLOWED_PREDICATES = {
"works_at",
"reports_to",
"stakeholder_of",
"last_contacted_on",
"owes_followup",
"manages",
"collaborates_with",
"owns",
"member_of",
"custom",
}
# ── Response schemas ─────────────────────────────────────────────────────────
class RelationOut(BaseModel):
id: str
subject_label: str
subject_type: str
predicate: str
object_label: str
object_type: str
confidence: float
last_confirmed_at: int | None = None # epoch ms
class RelationPatch(BaseModel):
subject_label: str | None = None
object_label: str | None = None
predicate: str | None = None
confidence: float | None = Field(None, ge=0.0, le=1.0)
class CoreAddBody(BaseModel):
key: str = Field(..., min_length=1, max_length=255)
value: str = Field(..., min_length=1)
# ── Helpers ──────────────────────────────────────────────────────────────────
def _relation_to_out(row: MemoryRelation) -> RelationOut:
last_ms: int | None = None
if row.last_confirmed_at is not None:
last_ms = int(row.last_confirmed_at.timestamp() * 1000)
return RelationOut(
id=row.id,
subject_label=row.subject_label,
subject_type=row.subject_type,
predicate=row.predicate,
object_label=row.object_label,
object_type=row.object_type,
confidence=row.confidence,
last_confirmed_at=last_ms,
)
# ── Routes ───────────────────────────────────────────────────────────────────
@router.get("/core", response_model=dict[str, str])
async def get_core_memory(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, str]:
"""Return all core memory k/v pairs (plaintext) for the current user."""
mw = MemoryMiddleware(db)
blocks = await mw.list_core_blocks(current_user.id)
return {b["label"]: b["value"] for b in blocks}
@router.delete("/core/{key}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_core_key(
key: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> None:
"""Delete a single core memory key (GDPR Art. 17)."""
mw = MemoryMiddleware(db)
deleted = await mw.delete_core(current_user.id, key)
if not deleted:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Key not found")
@router.post("/core", status_code=status.HTTP_201_CREATED, response_model=dict[str, str])
async def add_core_key(
body: CoreAddBody,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, str]:
"""Add or overwrite a core memory key/value pair."""
mw = MemoryMiddleware(db)
await mw.update_core(current_user.id, body.key, body.value)
return {body.key: body.value}
@router.get("/relational", response_model=list[RelationOut])
async def get_relational_memory(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[RelationOut]:
"""Return all relational memory rows for the current user."""
mw = MemoryMiddleware(db)
rows = await mw.query_relations(current_user.id, limit=200)
return [_relation_to_out(r) for r in rows]
@router.patch("/relational/{relation_id}", response_model=RelationOut)
async def patch_relation(
relation_id: str,
body: RelationPatch,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> RelationOut:
"""Edit a relation row's labels, predicate, or confidence."""
if body.predicate is not None and body.predicate not in _ALLOWED_PREDICATES:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"predicate must be one of: {sorted(_ALLOWED_PREDICATES)}",
)
result = await db.execute(
select(MemoryRelation).where(
MemoryRelation.id == relation_id,
MemoryRelation.user_id == current_user.id,
)
)
row = result.scalar_one_or_none()
if row is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
if body.subject_label is not None:
row.subject_label = body.subject_label
if body.object_label is not None:
row.object_label = body.object_label
if body.predicate is not None:
row.predicate = body.predicate
if body.confidence is not None:
row.confidence = body.confidence
row.last_confirmed_at = datetime.now(timezone.utc)
await db.commit()
await db.refresh(row)
logger.info("memory: patch_relation user=%s relation=%s", current_user.id, relation_id)
return _relation_to_out(row)
@router.delete("/relational/{relation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_relation(
relation_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> None:
"""Hard-delete a relation row (GDPR Art. 17)."""
result = await db.execute(
select(MemoryRelation).where(
MemoryRelation.id == relation_id,
MemoryRelation.user_id == current_user.id,
)
)
row = result.scalar_one_or_none()
if row is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
await db.delete(row)
await db.commit()
logger.info("memory: delete_relation user=%s relation=%s", current_user.id, relation_id)
@router.post("/forget-all", status_code=status.HTTP_204_NO_CONTENT)
async def forget_all(
x_confirm: Annotated[str | None, Header(alias="X-Confirm")] = None,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> None:
"""Wipe all memory tiers for the current user (GDPR Art. 17).
Requires ``X-Confirm: true`` header. Does NOT delete the user account.
"""
if x_confirm != "true":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing or invalid X-Confirm header. Send X-Confirm: true to confirm.",
)
uid = current_user.id
await db.execute(delete(MemoryCore).where(MemoryCore.user_id == uid))
await db.execute(delete(MemoryAssociative).where(MemoryAssociative.user_id == uid))
await db.execute(delete(MemoryEpisodic).where(MemoryEpisodic.user_id == uid))
await db.execute(delete(MemoryProactive).where(MemoryProactive.user_id == uid))
await db.execute(delete(MemoryRelation).where(MemoryRelation.user_id == uid))
await db.execute(delete(ExtractionQueue).where(ExtractionQueue.user_id == uid))
await db.commit()
logger.warning("memory: forget_all GDPR wipe user=%s", uid)

View File

@@ -25,8 +25,9 @@ FEATURES: dict[str, dict[str, Any]] = {
"providers": 1,
"batch_builder": False,
"sso": False,
"real_embeddings": False, # keyword fallback only
"realtime_extraction": False, # batch queue (Phase 2)
"real_embeddings": False, # keyword fallback only
"realtime_extraction": False, # batch queue (Phase 2)
"relational_memory": False, # relational tier (Phase 3) — Pro+
},
"pro": {
"agents": -1, # unlimited
@@ -35,8 +36,9 @@ FEATURES: dict[str, dict[str, Any]] = {
"providers": -1,
"batch_builder": False,
"sso": False,
"real_embeddings": True, # pgvector cosine search
"realtime_extraction": True, # fire-and-forget asyncio.create_task
"real_embeddings": True, # pgvector cosine search
"realtime_extraction": True, # fire-and-forget asyncio.create_task
"relational_memory": True, # person/project predicates
},
"power": {
"agents": -1,
@@ -47,6 +49,7 @@ FEATURES: dict[str, dict[str, Any]] = {
"sso": False,
"real_embeddings": True,
"realtime_extraction": True,
"relational_memory": True, # all predicates incl. custom
},
"team": {
"agents": -1,
@@ -57,6 +60,7 @@ FEATURES: dict[str, dict[str, Any]] = {
"sso": True,
"real_embeddings": True,
"realtime_extraction": True,
"relational_memory": True, # all predicates incl. custom
},
}

View File

@@ -55,6 +55,22 @@ def _language_instruction(context: dict[str, Any]) -> str:
f"All your output text must be written in {lang}."
)
def _relational_memory_injection(context: dict[str, Any]) -> str:
"""Return a system-prompt paragraph listing known people/projects from relational memory.
Returns empty string when no relational rows or tier is Free.
Capped at 800 chars to control token spend.
"""
relations: list[str] = context.get("relational_memory") or []
if not relations:
return ""
body = "\n".join(f"- {r}" for r in relations)
section = f"\n\nKnown people & projects:\n{body}"
if len(section) > 800:
section = section[:797] + "..."
return section
_HOME_SYSTEM_PROMPT = (
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
"Always use tools for factual data retrieval before answering. "
@@ -904,6 +920,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"home_system", _HOME_SYSTEM_PROMPT
)
system_prompt += _relational_memory_injection(context)
system_prompt += _language_instruction(context)
response = await _run_single_agent(
user_id=user_id,
@@ -922,6 +939,7 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"floating_system", _FLOATING_SYSTEM_PROMPT
)
system_prompt += _relational_memory_injection(context)
system_prompt += _language_instruction(context)
response = await _run_single_agent(
user_id=user_id,
@@ -946,6 +964,7 @@ async def run_home_stream(
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"home_system", _HOME_SYSTEM_PROMPT
)
system_prompt += _relational_memory_injection(context)
system_prompt += _language_instruction(context)
text_chunks: list[str] = []
async for event in _run_single_agent_stream(
@@ -979,6 +998,7 @@ async def run_floating_stream(
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"floating_system", _FLOATING_SYSTEM_PROMPT
)
system_prompt += _relational_memory_injection(context)
system_prompt += _language_instruction(context)
sanitizer = _FloatingStreamSanitizer()
emitted_sanitized = False

View File

@@ -366,7 +366,7 @@ async def _apply_candidate(
if candidate.target_tier == "relational":
# Always upsert relations — decide_action skipped (no neighbour search).
if candidate.subject and candidate.predicate and candidate.object:
await _upsert_relation_stub(
await _upsert_relation(
middleware, db, user_id, candidate, trace_id
)
return
@@ -396,35 +396,29 @@ def _content_to_key(content: str) -> str:
return slug or "memory"
async def _upsert_relation_stub(
async def _upsert_relation(
middleware: Any,
db: AsyncSession,
user_id: str,
candidate: MemoryCandidate,
trace_id: str | None,
) -> None:
"""Stub: upsert_relation will be fully wired in Phase 3.
Called here so Phase 2 extraction pipeline already routes relation candidates
correctly. Phase 3 replaces this with MemoryMiddleware.upsert_relation().
"""
if hasattr(middleware, "upsert_relation"):
await middleware.upsert_relation(
user_id=user_id,
subject=candidate.subject,
subject_type="unknown",
predicate=candidate.predicate,
object_=candidate.object,
object_type="unknown",
confidence=candidate.confidence,
)
else:
logger.info(
"memory_extraction: relation stub (Phase 3 not yet wired) subject=%s predicate=%s object=%s",
candidate.subject,
candidate.predicate,
candidate.object,
)
"""Upsert a relation row via MemoryMiddleware.upsert_relation (Phase 3)."""
await middleware.upsert_relation(
user_id=user_id,
subject=candidate.subject or "unknown",
subject_type="unknown",
predicate=candidate.predicate or "related_to",
object_=candidate.object or "unknown",
object_type="unknown",
confidence=candidate.confidence,
)
logger.info(
"memory_extraction: upserted relation subject=%s predicate=%s object=%s",
candidate.subject,
candidate.predicate,
candidate.object,
)
async def _store_proactive_stub(

View File

@@ -0,0 +1,102 @@
"""Memory maintenance jobs — Phase 3/5.
Two entrypoints called by the scheduler (APScheduler) registered in app/main.py:
drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5).
decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3).
Both are safe to call manually or from tests; they never raise.
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import MemoryRelation
logger = logging.getLogger(__name__)
# Decay parameters
_DECAY_FACTOR = 0.95 # multiply confidence by this every _DECAY_PERIOD days
_DECAY_PERIOD_DAYS = 30 # period for one decay step
_PRUNE_THRESHOLD = 0.2 # rows below this confidence are deleted
async def decay_relations(db: AsyncSession, user_id: str) -> None:
"""Apply confidence decay to all relation rows for a user.
Decay rule: confidence *= 0.95 for every 30 days since last_confirmed_at.
Rows whose confidence falls below 0.2 are deleted.
Never raises — wraps in try/except.
"""
try:
await _decay_relations_inner(db, user_id)
except Exception as exc:
logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, exc)
async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
result = await db.execute(
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
)
rows = result.scalars().all()
now = datetime.now(timezone.utc)
deleted = 0
decayed = 0
for row in rows:
reference = row.last_confirmed_at or row.created_at
if reference is None:
continue
# Ensure timezone-aware comparison
if reference.tzinfo is None:
reference = reference.replace(tzinfo=timezone.utc)
days_elapsed = (now - reference).days
if days_elapsed < _DECAY_PERIOD_DAYS:
continue
periods = days_elapsed // _DECAY_PERIOD_DAYS
new_confidence = row.confidence * (_DECAY_FACTOR ** periods)
if new_confidence < _PRUNE_THRESHOLD:
await db.delete(row)
deleted += 1
logger.info(
"memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s "
"confidence=%.3f (below threshold)",
row.id, user_id, row.subject_label, row.predicate, new_confidence,
)
else:
row.confidence = new_confidence
decayed += 1
try:
await db.commit()
logger.info(
"memory_maintenance: decay_relations user=%s decayed=%d deleted=%d",
user_id, decayed, deleted,
)
except Exception as exc:
logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, exc)
await db.rollback()
async def drain_extraction_queue(db: AsyncSession) -> None:
"""Process pending ExtractionQueue rows for Free-tier users (Phase 5 stub).
Full implementation wired in Phase 5 when APScheduler is registered.
Currently logs count and returns.
"""
try:
from app.models import ExtractionQueue # noqa: PLC0415
result = await db.execute(select(ExtractionQueue))
rows = result.scalars().all()
logger.info("memory_maintenance: drain_extraction_queue pending=%d (Phase 5 cron)", len(rows))
except Exception as exc:
logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc)

View File

@@ -21,6 +21,7 @@ from __future__ import annotations
import asyncio
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
from cryptography.fernet import Fernet, InvalidToken
@@ -33,11 +34,17 @@ from app.models import (
MemoryCore,
MemoryEpisodic,
MemoryProactive,
MemoryRelation,
User,
)
logger = logging.getLogger(__name__)
def _now() -> datetime:
return datetime.now(timezone.utc)
# Tuning constants
_ASSOCIATIVE_TOP_K = 5
_EPISODIC_RECENT_N = 10
@@ -66,6 +73,7 @@ class MemoryMiddleware:
associative_memory — [plaintext_content, ...] (top-k by keyword match)
episodic_memory — [plaintext_summary, ...] (most recent N)
proactive_hints — [plaintext_pattern, ...] (above threshold)
relational_memory — ["subject --predicate--> object", ...] (top 10, Pro+)
"""
fernet = await self._get_fernet(user_id)
if fernet is None:
@@ -78,9 +86,10 @@ class MemoryMiddleware:
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
proactive = await self._load_proactive(user_id, fernet)
relational = await self._load_relational(user_id, user_tier=user_tier)
logger.info(
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d",
trace_id or "-",
user_id,
user_tier,
@@ -88,6 +97,7 @@ class MemoryMiddleware:
len(associative),
len(episodic),
len(proactive),
len(relational),
)
return {
@@ -95,6 +105,7 @@ class MemoryMiddleware:
"associative_memory": associative,
"episodic_memory": episodic,
"proactive_hints": proactive,
"relational_memory": relational,
}
async def store_episode(
@@ -375,6 +386,99 @@ class MemoryMiddleware:
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def upsert_relation(
self,
user_id: str,
subject: str,
subject_type: str,
predicate: str,
object_: str,
object_type: str,
*,
confidence: float = 0.7,
source_episode_id: str | None = None,
notes: str | None = None,
) -> None:
"""Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label).
subject_label / object_label are plaintext entity identifiers — not encrypted.
notes is optional; encrypted with user Fernet if provided.
"""
from app.billing.tier_manager import tier_manager # noqa: PLC0415
user_dbg = await self._get_user_debug(user_id)
user_tier = user_dbg.get("tier") or "free"
if not tier_manager.check_feature(user_tier, "relational_memory"):
logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier)
return
notes_encrypted: bytes | None = None
if notes:
fernet = await self._get_fernet(user_id)
if fernet:
notes_encrypted = fernet.encrypt(notes.encode())
result = await self._db.execute(
select(MemoryRelation).where(
MemoryRelation.user_id == user_id,
MemoryRelation.subject_label == subject,
MemoryRelation.predicate == predicate,
MemoryRelation.object_label == object_,
)
)
existing = result.scalar_one_or_none()
if existing is not None:
existing.subject_type = subject_type
existing.object_type = object_type
existing.confidence = confidence
existing.last_confirmed_at = _now()
if notes_encrypted is not None:
existing.notes_encrypted = notes_encrypted
else:
self._db.add(MemoryRelation(
id=str(uuid.uuid4()),
user_id=user_id,
subject_label=subject,
subject_type=subject_type,
predicate=predicate,
object_label=object_,
object_type=object_type,
confidence=confidence,
source_episode_id=source_episode_id,
notes_encrypted=notes_encrypted,
))
try:
await self._db.commit()
logger.info(
"memory: upsert_relation user=%s subject=%s predicate=%s object=%s",
user_id, subject, predicate, object_,
)
except Exception as exc:
logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def query_relations(
self,
user_id: str,
subject: str | None = None,
predicate: str | None = None,
object_: str | None = None,
limit: int = 20,
) -> list[MemoryRelation]:
"""Query relation rows for a user with optional filters."""
q = select(MemoryRelation).where(MemoryRelation.user_id == user_id)
if subject is not None:
q = q.where(MemoryRelation.subject_label == subject)
if predicate is not None:
q = q.where(MemoryRelation.predicate == predicate)
if object_ is not None:
q = q.where(MemoryRelation.object_label == object_)
q = q.order_by(MemoryRelation.confidence.desc()).limit(limit)
result = await self._db.execute(q)
return list(result.scalars().all())
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
"""Insert a long-term archival memory entry."""
fernet = await self._get_fernet(user_id)
@@ -463,13 +567,26 @@ class MemoryMiddleware:
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
"""Load lightweight user debug fields for trace logs."""
from app.config.settings import settings # noqa: PLC0415
from app.models import Subscription # noqa: PLC0415
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return {"tier": None}
return {
"tier": user.tier,
}
sub_result = await self._db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
sub_tier: str | None = sub_result.scalar_one_or_none()
if sub_tier:
tier = sub_tier
elif settings.ENV == "dev":
tier = "power"
else:
tier = user.tier or "free"
return {"tier": tier}
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
result = await self._db.execute(
@@ -563,6 +680,26 @@ class MemoryMiddleware:
out.append(plaintext)
return out
async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]:
"""Return top-10 relation strings for Pro+ users; empty list for Free."""
from app.billing.tier_manager import tier_manager # noqa: PLC0415
if not tier_manager.check_feature(user_tier, "relational_memory"):
return []
result = await self._db.execute(
select(MemoryRelation)
.where(MemoryRelation.user_id == user_id)
.order_by(MemoryRelation.confidence.desc())
.limit(10)
)
rows = result.scalars().all()
out = [
f"{r.subject_label} --{r.predicate}--> {r.object_label}"
for r in rows
]
return out
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
result = await self._db.execute(
select(MemoryProactive)

View File

@@ -50,13 +50,14 @@ def create_app() -> FastAPI:
app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import agents, auth, billing, chat, device_ws
from app.api.routes import agents, auth, billing, chat, device_ws, memory
app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1")
app.include_router(agents.router, prefix="/api/v1")
app.include_router(device_ws.router, prefix="/api/v1")
app.include_router(memory.router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:

View File

@@ -14,6 +14,7 @@ Table inventory:
memory_associative — per-user semantic memory with embeddings (encrypted)
memory_episodic — per-user session summaries (encrypted)
memory_proactive — per-user behavioral patterns (encrypted)
memory_relations — per-user entity/relation graph (Mem0g-light, Phase 3)
"""
from __future__ import annotations
@@ -30,6 +31,7 @@ from sqlalchemy import (
ForeignKey,
Integer,
JSON,
LargeBinary,
String,
Text,
Uuid,
@@ -373,6 +375,44 @@ class ExtractionQueue(Base):
)
class MemoryRelation(Base):
"""Per-user entity/relation graph row (Mem0g-light, Phase 3).
subject_label/object_label are plaintext entity identifiers (not user content).
notes_encrypted is optional Fernet-encrypted per-user commentary.
confidence in [0.0, 1.0] — decays 5 % per 30 days since last_confirmed_at.
"""
__tablename__ = "memory_relations"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
subject_label: Mapped[str] = mapped_column(String(128), nullable=False)
subject_type: Mapped[str] = mapped_column(String(32), nullable=False)
predicate: Mapped[str] = mapped_column(String(64), nullable=False)
object_label: Mapped[str] = mapped_column(String(128), nullable=False)
object_type: Mapped[str] = mapped_column(String(32), nullable=False)
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.7)
source_episode_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False),
ForeignKey("memory_episodic.id", ondelete="SET NULL"),
nullable=True,
)
notes_encrypted: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
last_confirmed_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
class Plugin(Base):
"""Plugin marketplace catalog entry."""

View File

@@ -0,0 +1,220 @@
"""Tests for Phase 3 — relational tier (Mem0g-light).
Coverage:
1. upsert_relation inserts a row and query_relations returns it
2. upsert_relation updates existing row on duplicate (subject/predicate/object)
3. tier gating: Free user gets empty list from query_relations + enrich_context
4. enrich_context includes relational_memory key for Pro user
5. decay_relations decays confidence and prunes rows below threshold
"""
from __future__ import annotations
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import patch
import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select
from app.core.memory_maintenance import decay_relations
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session
from app.main import app
from app.models import MemoryRelation, User
from tests.conftest import TEST_USER_IDS
PRO_USER_ID = TEST_USER_IDS["pro"]
FREE_USER_ID = TEST_USER_IDS["free"]
_FERNET_KEY = Fernet.generate_key().decode()
# ── DB override ───────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _override_db(db_session):
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
@pytest_asyncio.fixture
async def pro_user_with_key(db_session):
"""Set encryption_key on the pro test user so Fernet works."""
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
user = result.scalar_one()
user.encryption_key = _FERNET_KEY
await db_session.commit()
return user
@pytest_asyncio.fixture
async def free_user_with_key(db_session):
"""Set encryption_key on the free test user."""
result = await db_session.execute(select(User).where(User.id == FREE_USER_ID))
user = result.scalar_one()
user.encryption_key = _FERNET_KEY
await db_session.commit()
return user
# ── Tests ─────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_upsert_relation_inserts_and_queries(db_session, pro_user_with_key):
"""upsert_relation inserts a row; query_relations returns it."""
mm = MemoryMiddleware(db_session)
await mm.upsert_relation(
PRO_USER_ID,
subject="Giulia",
subject_type="person",
predicate="works_at",
object_="Acme Corp",
object_type="company",
confidence=0.9,
)
rows = await mm.query_relations(PRO_USER_ID, subject="Giulia")
assert len(rows) == 1
assert rows[0].subject_label == "Giulia"
assert rows[0].predicate == "works_at"
assert rows[0].object_label == "Acme Corp"
assert abs(rows[0].confidence - 0.9) < 0.001
@pytest.mark.asyncio
async def test_upsert_relation_updates_on_duplicate(db_session, pro_user_with_key):
"""Second upsert on same triple updates confidence and last_confirmed_at."""
mm = MemoryMiddleware(db_session)
await mm.upsert_relation(
PRO_USER_ID,
subject="Marco",
subject_type="person",
predicate="stakeholder_of",
object_="Project Nexus",
object_type="project",
confidence=0.7,
)
await mm.upsert_relation(
PRO_USER_ID,
subject="Marco",
subject_type="person",
predicate="stakeholder_of",
object_="Project Nexus",
object_type="project",
confidence=0.95,
)
rows = await mm.query_relations(PRO_USER_ID, subject="Marco")
# Only one row despite two upserts
assert len(rows) == 1
assert abs(rows[0].confidence - 0.95) < 0.001
assert rows[0].last_confirmed_at is not None
@pytest.mark.asyncio
async def test_free_tier_relation_skipped(db_session, free_user_with_key):
"""Free user: upsert_relation is silently skipped (no row created)."""
mm = MemoryMiddleware(db_session)
await mm.upsert_relation(
FREE_USER_ID,
subject="Alice",
subject_type="person",
predicate="reports_to",
object_="Bob",
object_type="person",
confidence=0.8,
)
rows = await mm.query_relations(FREE_USER_ID, subject="Alice")
assert rows == []
@pytest.mark.asyncio
async def test_enrich_context_includes_relational_memory(db_session, pro_user_with_key):
"""enrich_context includes relational_memory key for Pro user."""
mm = MemoryMiddleware(db_session)
await mm.upsert_relation(
PRO_USER_ID,
subject="Elena",
subject_type="person",
predicate="cfo_of",
object_="StartupXYZ",
object_type="company",
confidence=0.85,
)
with patch("app.core.memory_middleware.MemoryMiddleware._load_associative", return_value=[]):
ctx = await mm.enrich_context(PRO_USER_ID, "who is Elena?")
assert "relational_memory" in ctx
assert any("Elena" in r for r in ctx["relational_memory"])
@pytest.mark.asyncio
async def test_enrich_context_relational_empty_for_free(db_session, free_user_with_key):
"""Free user: relational_memory is empty list in enrich_context."""
mm = MemoryMiddleware(db_session)
with patch("app.core.memory_middleware.MemoryMiddleware._load_associative", return_value=[]):
ctx = await mm.enrich_context(FREE_USER_ID, "test message")
assert ctx.get("relational_memory") == []
@pytest.mark.asyncio
async def test_decay_relations_reduces_confidence(db_session, pro_user_with_key):
"""decay_relations reduces confidence on stale rows."""
old_date = datetime.now(timezone.utc) - timedelta(days=35)
row = MemoryRelation(
id=str(uuid.uuid4()),
user_id=PRO_USER_ID,
subject_label="OldContact",
subject_type="person",
predicate="knows",
object_label="SomeProject",
object_type="project",
confidence=0.8,
last_confirmed_at=old_date,
)
db_session.add(row)
await db_session.commit()
await decay_relations(db_session, PRO_USER_ID)
result = await db_session.execute(
select(MemoryRelation).where(MemoryRelation.subject_label == "OldContact")
)
updated = result.scalar_one_or_none()
assert updated is not None
assert updated.confidence < 0.8
@pytest.mark.asyncio
async def test_decay_relations_prunes_low_confidence(db_session, pro_user_with_key):
"""decay_relations deletes rows whose confidence drops below 0.2 threshold."""
# Start at 0.21 with 60-day-old last_confirmed_at → two decay periods → 0.21 * 0.95^2 ≈ 0.19 → pruned
old_date = datetime.now(timezone.utc) - timedelta(days=65)
row = MemoryRelation(
id=str(uuid.uuid4()),
user_id=PRO_USER_ID,
subject_label="ExpiredContact",
subject_type="person",
predicate="used_to_work_with",
object_label="OldCorp",
object_type="company",
confidence=0.21,
last_confirmed_at=old_date,
)
db_session.add(row)
await db_session.commit()
await decay_relations(db_session, PRO_USER_ID)
result = await db_session.execute(
select(MemoryRelation).where(MemoryRelation.subject_label == "ExpiredContact")
)
pruned = result.scalar_one_or_none()
assert pruned is None