PHASE 3 — relational tier (Mem0g-light)
This commit is contained in:
74
alembic/versions/006_memory_relations.py
Normal file
74
alembic/versions/006_memory_relations.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Add memory_relations table (Phase 3 — relational tier).
|
||||
|
||||
Revision ID: 006
|
||||
Revises: 1f5975a4f3f4
|
||||
Create Date: 2026-04-16
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# Alembic revision identifiers: this migration is "006" and applies on top of "1f5975a4f3f4".
revision: str = "006"
down_revision: Union[str, None] = "1f5975a4f3f4"
# No branch labels and no dependencies on other revisions.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``memory_relations`` table and its two composite lookup indexes.

    The table holds one row per (subject, predicate, object) relation owned by a
    user. Rows are removed with the owning user (``ON DELETE CASCADE``) and keep
    a nullable back-reference to the episodic row they came from
    (``ON DELETE SET NULL``).
    """
    table_name = "memory_relations"

    def uuid_column_type() -> postgresql.UUID:
        # String-valued UUID (as_uuid=False); a fresh type object per column.
        return postgresql.UUID(as_uuid=False)

    def server_timestamp(column_name: str) -> sa.Column:
        # NOT NULL timestamptz column whose default is the database's now().
        return sa.Column(
            column_name,
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        )

    op.create_table(
        table_name,
        sa.Column("id", uuid_column_type(), primary_key=True),
        sa.Column(
            "user_id",
            uuid_column_type(),
            sa.ForeignKey("users.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("subject_label", sa.String(128), nullable=False),
        sa.Column("subject_type", sa.String(32), nullable=False),
        sa.Column("predicate", sa.String(64), nullable=False),
        sa.Column("object_label", sa.String(128), nullable=False),
        sa.Column("object_type", sa.String(32), nullable=False),
        sa.Column("confidence", sa.Float, nullable=False, server_default="0.7"),
        sa.Column(
            "source_episode_id",
            uuid_column_type(),
            sa.ForeignKey("memory_episodic.id", ondelete="SET NULL"),
            nullable=True,
        ),
        sa.Column("notes_encrypted", sa.LargeBinary, nullable=True),
        server_timestamp("created_at"),
        server_timestamp("updated_at"),
        sa.Column("last_confirmed_at", sa.DateTime(timezone=True), nullable=True),
    )

    # Both indexes lead with user_id because every query is user-scoped.
    index_specs = (
        ("memory_relations_user_subject_idx", ["user_id", "subject_label"]),
        ("memory_relations_user_predicate_idx", ["user_id", "predicate"]),
    )
    for index_name, index_columns in index_specs:
        op.create_index(index_name, table_name, index_columns)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Undo ``upgrade``: drop both indexes (reverse creation order), then the table."""
    table_name = "memory_relations"
    for index_name in (
        "memory_relations_user_predicate_idx",
        "memory_relations_user_subject_idx",
    ):
        op.drop_index(index_name, table_name)
    op.drop_table(table_name)
|
||||
225
app/api/routes/memory.py
Normal file
225
app/api/routes/memory.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Memory management routes — view/edit/delete user memory tiers.
|
||||
|
||||
All routes require authentication. Data is always user-scoped.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.models import (
|
||||
ExtractionQueue,
|
||||
MemoryAssociative,
|
||||
MemoryCore,
|
||||
MemoryEpisodic,
|
||||
MemoryProactive,
|
||||
MemoryRelation,
|
||||
)
|
||||
from app.schemas import UserProfile
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["memory"])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Predicates a client may set through PATCH /memory/relational/{relation_id};
# any other value is rejected with HTTP 422 by patch_relation.
_ALLOWED_PREDICATES = {
    "works_at",
    "reports_to",
    "stakeholder_of",
    "last_contacted_on",
    "owes_followup",
    "manages",
    "collaborates_with",
    "owns",
    "member_of",
    "custom",
}
|
||||
|
||||
|
||||
# ── Response schemas ─────────────────────────────────────────────────────────
|
||||
|
||||
class RelationOut(BaseModel):
    """API representation of one relational-memory row."""

    id: str
    subject_label: str
    subject_type: str
    predicate: str
    object_label: str
    object_type: str
    confidence: float
    last_confirmed_at: int | None = None  # epoch ms
|
||||
|
||||
|
||||
class RelationPatch(BaseModel):
    """Partial-update payload for a relation row; omitted fields are left untouched.

    Label bounds mirror the ``String(128)`` width of the ``subject_label`` /
    ``object_label`` columns so that an over-long value is rejected at request
    validation instead of failing at the database. Empty labels are rejected
    as well.
    """

    subject_label: str | None = Field(None, min_length=1, max_length=128)
    object_label: str | None = Field(None, min_length=1, max_length=128)
    # Checked against _ALLOWED_PREDICATES by the route handler.
    predicate: str | None = None
    confidence: float | None = Field(None, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class CoreAddBody(BaseModel):
    """Request body for POST /memory/core: one non-empty key/value pair."""

    key: str = Field(..., min_length=1, max_length=255)
    value: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _relation_to_out(row: MemoryRelation) -> RelationOut:
    """Convert a ``MemoryRelation`` row into its ``RelationOut`` response model.

    ``last_confirmed_at`` is converted to integer epoch milliseconds; it stays
    ``None`` when the row carries no confirmation timestamp.
    """
    confirmed_at = row.last_confirmed_at
    confirmed_ms: int | None = (
        None if confirmed_at is None else int(confirmed_at.timestamp() * 1000)
    )
    return RelationOut(
        id=row.id,
        subject_label=row.subject_label,
        subject_type=row.subject_type,
        predicate=row.predicate,
        object_label=row.object_label,
        object_type=row.object_type,
        confidence=row.confidence,
        last_confirmed_at=confirmed_ms,
    )
|
||||
|
||||
|
||||
# ── Routes ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/core", response_model=dict[str, str])
async def get_core_memory(
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> dict[str, str]:
    """Return the current user's core memory as a ``label -> value`` mapping.

    Values are whatever ``MemoryMiddleware.list_core_blocks`` yields for the
    user (plaintext k/v pairs).
    """
    core_blocks = await MemoryMiddleware(db).list_core_blocks(current_user.id)
    mapping: dict[str, str] = {}
    for block in core_blocks:
        mapping[block["label"]] = block["value"]
    return mapping
|
||||
|
||||
|
||||
@router.delete("/core/{key}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_core_key(
    key: str,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> None:
    """Remove one core memory key for the current user (GDPR Art. 17).

    Raises:
        HTTPException: 404 when the middleware reports that nothing was deleted.
    """
    was_deleted = await MemoryMiddleware(db).delete_core(current_user.id, key)
    if was_deleted:
        return None
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Key not found")
|
||||
|
||||
|
||||
@router.post("/core", status_code=status.HTTP_201_CREATED, response_model=dict[str, str])
async def add_core_key(
    body: CoreAddBody,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> dict[str, str]:
    """Create or overwrite one core memory entry and echo it back as ``{key: value}``."""
    entry_key, entry_value = body.key, body.value
    await MemoryMiddleware(db).update_core(current_user.id, entry_key, entry_value)
    return {entry_key: entry_value}
|
||||
|
||||
|
||||
@router.get("/relational", response_model=list[RelationOut])
async def get_relational_memory(
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> list[RelationOut]:
    """List the current user's relation rows, capped at 200 by the query limit."""
    relations = await MemoryMiddleware(db).query_relations(current_user.id, limit=200)
    return list(map(_relation_to_out, relations))
|
||||
|
||||
|
||||
@router.patch("/relational/{relation_id}", response_model=RelationOut)
async def patch_relation(
    relation_id: str,
    body: RelationPatch,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> RelationOut:
    """Edit a relation row's labels, predicate, or confidence.

    Only fields present in ``body`` are changed. Every successful patch also
    stamps ``last_confirmed_at`` with the current UTC time.

    Raises:
        HTTPException: 422 when ``body.predicate`` is not in ``_ALLOWED_PREDICATES``;
            404 when ``relation_id`` is not a well-formed UUID or no row with
            that id belongs to the current user.
    """
    import uuid  # noqa: PLC0415 - local import keeps this handler self-contained

    if body.predicate is not None and body.predicate not in _ALLOWED_PREDICATES:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"predicate must be one of: {sorted(_ALLOWED_PREDICATES)}",
        )

    # The id column is a UUID: a malformed path parameter can never match a row,
    # and handing it to the driver may raise instead of returning "no rows".
    # Answer 404 up front and query with the canonical form.
    try:
        relation_uuid = str(uuid.UUID(relation_id))
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found"
        ) from None

    result = await db.execute(
        select(MemoryRelation).where(
            MemoryRelation.id == relation_uuid,
            # Ownership filter: another user's row looks exactly like a missing one.
            MemoryRelation.user_id == current_user.id,
        )
    )
    row = result.scalar_one_or_none()
    if row is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")

    if body.subject_label is not None:
        row.subject_label = body.subject_label
    if body.object_label is not None:
        row.object_label = body.object_label
    if body.predicate is not None:
        row.predicate = body.predicate
    if body.confidence is not None:
        row.confidence = body.confidence
    # A manual edit counts as a confirmation of the relation.
    row.last_confirmed_at = datetime.now(timezone.utc)

    await db.commit()
    await db.refresh(row)
    logger.info("memory: patch_relation user=%s relation=%s", current_user.id, relation_id)
    return _relation_to_out(row)
|
||||
|
||||
|
||||
@router.delete("/relational/{relation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_relation(
    relation_id: str,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> None:
    """Hard-delete a relation row (GDPR Art. 17).

    Raises:
        HTTPException: 404 when ``relation_id`` is not a well-formed UUID or no
            row with that id belongs to the current user.
    """
    import uuid  # noqa: PLC0415 - local import keeps this handler self-contained

    # A malformed id cannot match a UUID column; report 404 instead of letting
    # the database driver reject the bind value.
    try:
        relation_uuid = str(uuid.UUID(relation_id))
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found"
        ) from None

    result = await db.execute(
        select(MemoryRelation).where(
            MemoryRelation.id == relation_uuid,
            # Ownership filter: users can only delete their own rows.
            MemoryRelation.user_id == current_user.id,
        )
    )
    row = result.scalar_one_or_none()
    if row is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
    await db.delete(row)
    await db.commit()
    logger.info("memory: delete_relation user=%s relation=%s", current_user.id, relation_id)
|
||||
|
||||
|
||||
@router.post("/forget-all", status_code=status.HTTP_204_NO_CONTENT)
async def forget_all(
    x_confirm: Annotated[str | None, Header(alias="X-Confirm")] = None,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> None:
    """Erase every memory tier belonging to the current user (GDPR Art. 17).

    The request must carry ``X-Confirm: true``. The user account itself is left
    in place; only memory rows and queued extractions are removed.

    Raises:
        HTTPException: 400 when the confirmation header is absent or not ``"true"``.
    """
    if x_confirm != "true":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing or invalid X-Confirm header. Send X-Confirm: true to confirm.",
        )

    uid = current_user.id
    # Same deletion order as before; all deletes share one transaction.
    user_scoped_models = (
        MemoryCore,
        MemoryAssociative,
        MemoryEpisodic,
        MemoryProactive,
        MemoryRelation,
        ExtractionQueue,
    )
    for model in user_scoped_models:
        await db.execute(delete(model).where(model.user_id == uid))
    await db.commit()
    logger.warning("memory: forget_all GDPR wipe user=%s", uid)
|
||||
@@ -25,8 +25,9 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"providers": 1,
|
||||
"batch_builder": False,
|
||||
"sso": False,
|
||||
"real_embeddings": False, # keyword fallback only
|
||||
"realtime_extraction": False, # batch queue (Phase 2)
|
||||
"real_embeddings": False, # keyword fallback only
|
||||
"realtime_extraction": False, # batch queue (Phase 2)
|
||||
"relational_memory": False, # relational tier (Phase 3) — Pro+
|
||||
},
|
||||
"pro": {
|
||||
"agents": -1, # unlimited
|
||||
@@ -35,8 +36,9 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"providers": -1,
|
||||
"batch_builder": False,
|
||||
"sso": False,
|
||||
"real_embeddings": True, # pgvector cosine search
|
||||
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
||||
"real_embeddings": True, # pgvector cosine search
|
||||
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
||||
"relational_memory": True, # person/project predicates
|
||||
},
|
||||
"power": {
|
||||
"agents": -1,
|
||||
@@ -47,6 +49,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"sso": False,
|
||||
"real_embeddings": True,
|
||||
"realtime_extraction": True,
|
||||
"relational_memory": True, # all predicates incl. custom
|
||||
},
|
||||
"team": {
|
||||
"agents": -1,
|
||||
@@ -57,6 +60,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"sso": True,
|
||||
"real_embeddings": True,
|
||||
"realtime_extraction": True,
|
||||
"relational_memory": True, # all predicates incl. custom
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -55,6 +55,22 @@ def _language_instruction(context: dict[str, Any]) -> str:
|
||||
f"All your output text must be written in {lang}."
|
||||
)
|
||||
|
||||
def _relational_memory_injection(context: dict[str, Any]) -> str:
    """Return a system-prompt paragraph listing known people/projects from relational memory.

    Reads ``context["relational_memory"]`` (a list of relation strings) and
    renders one ``- relation`` bullet per entry under a fixed heading.

    Returns an empty string when the key is missing or holds no usable entries.
    The result never exceeds 800 characters, to control token spend. When the
    list does not fit, only whole bullets are kept and a final ``...`` line
    marks the omission, so the model never sees a relation cut in half. Only
    when the very first bullet cannot fit is the text cut mid-entry.
    """
    max_chars = 800
    omission_marker = "\n..."

    raw_relations = context.get("relational_memory") or []
    # Collapse embedded newlines/tabs so each relation stays on exactly one
    # bullet line and cannot alter the structure of the prompt.
    entries = [" ".join(str(item).split()) for item in raw_relations]
    entries = [entry for entry in entries if entry]
    if not entries:
        return ""

    section = "\n\nKnown people & projects:"
    last_index = len(entries) - 1
    for index, entry in enumerate(entries):
        bullet = f"\n- {entry}"
        # Keep room for the omission marker unless this is the final bullet.
        reserve = len(omission_marker) if index < last_index else 0
        if len(section) + len(bullet) + reserve > max_chars:
            if index == 0:
                # Not even one bullet fits: fall back to a hard cut.
                return (section + bullet)[: max_chars - 3] + "..."
            return section + omission_marker
        section += bullet
    return section
|
||||
|
||||
|
||||
_HOME_SYSTEM_PROMPT = (
|
||||
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||
"Always use tools for factual data retrieval before answering. "
|
||||
@@ -904,6 +920,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||
"home_system", _HOME_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _relational_memory_injection(context)
|
||||
system_prompt += _language_instruction(context)
|
||||
response = await _run_single_agent(
|
||||
user_id=user_id,
|
||||
@@ -922,6 +939,7 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
|
||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _relational_memory_injection(context)
|
||||
system_prompt += _language_instruction(context)
|
||||
response = await _run_single_agent(
|
||||
user_id=user_id,
|
||||
@@ -946,6 +964,7 @@ async def run_home_stream(
|
||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||
"home_system", _HOME_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _relational_memory_injection(context)
|
||||
system_prompt += _language_instruction(context)
|
||||
text_chunks: list[str] = []
|
||||
async for event in _run_single_agent_stream(
|
||||
@@ -979,6 +998,7 @@ async def run_floating_stream(
|
||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _relational_memory_injection(context)
|
||||
system_prompt += _language_instruction(context)
|
||||
sanitizer = _FloatingStreamSanitizer()
|
||||
emitted_sanitized = False
|
||||
|
||||
@@ -366,7 +366,7 @@ async def _apply_candidate(
|
||||
if candidate.target_tier == "relational":
|
||||
# Always upsert relations — decide_action skipped (no neighbour search).
|
||||
if candidate.subject and candidate.predicate and candidate.object:
|
||||
await _upsert_relation_stub(
|
||||
await _upsert_relation(
|
||||
middleware, db, user_id, candidate, trace_id
|
||||
)
|
||||
return
|
||||
@@ -396,35 +396,29 @@ def _content_to_key(content: str) -> str:
|
||||
return slug or "memory"
|
||||
|
||||
|
||||
async def _upsert_relation_stub(
|
||||
async def _upsert_relation(
|
||||
middleware: Any,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
candidate: MemoryCandidate,
|
||||
trace_id: str | None,
|
||||
) -> None:
|
||||
"""Stub: upsert_relation will be fully wired in Phase 3.
|
||||
|
||||
Called here so Phase 2 extraction pipeline already routes relation candidates
|
||||
correctly. Phase 3 replaces this with MemoryMiddleware.upsert_relation().
|
||||
"""
|
||||
if hasattr(middleware, "upsert_relation"):
|
||||
await middleware.upsert_relation(
|
||||
user_id=user_id,
|
||||
subject=candidate.subject,
|
||||
subject_type="unknown",
|
||||
predicate=candidate.predicate,
|
||||
object_=candidate.object,
|
||||
object_type="unknown",
|
||||
confidence=candidate.confidence,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"memory_extraction: relation stub (Phase 3 not yet wired) subject=%s predicate=%s object=%s",
|
||||
candidate.subject,
|
||||
candidate.predicate,
|
||||
candidate.object,
|
||||
)
|
||||
"""Upsert a relation row via MemoryMiddleware.upsert_relation (Phase 3)."""
|
||||
await middleware.upsert_relation(
|
||||
user_id=user_id,
|
||||
subject=candidate.subject or "unknown",
|
||||
subject_type="unknown",
|
||||
predicate=candidate.predicate or "related_to",
|
||||
object_=candidate.object or "unknown",
|
||||
object_type="unknown",
|
||||
confidence=candidate.confidence,
|
||||
)
|
||||
logger.info(
|
||||
"memory_extraction: upserted relation subject=%s predicate=%s object=%s",
|
||||
candidate.subject,
|
||||
candidate.predicate,
|
||||
candidate.object,
|
||||
)
|
||||
|
||||
|
||||
async def _store_proactive_stub(
|
||||
|
||||
102
app/core/memory_maintenance.py
Normal file
102
app/core/memory_maintenance.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Memory maintenance jobs — Phase 3/5.
|
||||
|
||||
Two entrypoints called by the scheduler (APScheduler) registered in app/main.py:
|
||||
|
||||
drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5).
|
||||
decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3).
|
||||
|
||||
Both are safe to call manually or from tests; they never raise.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import MemoryRelation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Decay parameters
_DECAY_FACTOR = 0.95  # confidence is multiplied by this once per elapsed decay period
_DECAY_PERIOD_DAYS = 30  # length of one decay period, in days
_PRUNE_THRESHOLD = 0.2  # rows whose decayed confidence falls below this are deleted
|
||||
|
||||
|
||||
async def decay_relations(db: AsyncSession, user_id: str) -> None:
    """Apply confidence decay to all relation rows for a user.

    Decay rule: confidence *= 0.95 for every 30 days since last_confirmed_at.
    Rows whose confidence falls below 0.2 are deleted.

    Never raises: any failure is logged and the session is rolled back, so a
    failed statement for one user does not leave the shared session in an
    aborted transaction for whatever the caller does next.
    """
    try:
        await _decay_relations_inner(db, user_id)
    except Exception as exc:
        logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, exc)
        try:
            await db.rollback()
        except Exception as rollback_exc:
            # Keep the never-raises contract even if the connection is gone.
            logger.warning(
                "memory_maintenance: decay_relations rollback failed user=%s: %s",
                user_id,
                rollback_exc,
            )
|
||||
|
||||
|
||||
async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
    """Decay or prune every relation row of ``user_id`` and commit once.

    For each row the anchor time is ``last_confirmed_at`` (or ``created_at``
    when never confirmed). One decay period is ``_DECAY_PERIOD_DAYS`` whole
    days; rows younger than one period are skipped. Otherwise confidence is
    multiplied by ``_DECAY_FACTOR`` once per elapsed period, and the row is
    deleted if the result is below ``_PRUNE_THRESHOLD``.

    NOTE(review): the anchor timestamp is never advanced here, so each call
    re-applies the full elapsed-period decay to the already-decayed confidence.
    Confirm the scheduler cadence accounts for that.
    """
    query = select(MemoryRelation).where(MemoryRelation.user_id == user_id)
    relations = (await db.execute(query)).scalars().all()
    current_time = datetime.now(timezone.utc)
    pruned_count = 0
    decayed_count = 0

    for relation in relations:
        anchor = relation.last_confirmed_at or relation.created_at
        if anchor is None:
            continue
        if anchor.tzinfo is None:
            # Naive timestamps are treated as UTC so the subtraction below is valid.
            anchor = anchor.replace(tzinfo=timezone.utc)

        elapsed_periods = (current_time - anchor).days // _DECAY_PERIOD_DAYS
        if elapsed_periods < 1:
            continue

        decayed_value = relation.confidence * (_DECAY_FACTOR ** elapsed_periods)
        if decayed_value < _PRUNE_THRESHOLD:
            await db.delete(relation)
            pruned_count += 1
            logger.info(
                "memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s "
                "confidence=%.3f (below threshold)",
                relation.id, user_id, relation.subject_label, relation.predicate, decayed_value,
            )
        else:
            relation.confidence = decayed_value
            decayed_count += 1

    try:
        await db.commit()
        logger.info(
            "memory_maintenance: decay_relations user=%s decayed=%d deleted=%d",
            user_id, decayed_count, pruned_count,
        )
    except Exception as exc:
        logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, exc)
        await db.rollback()
|
||||
|
||||
|
||||
async def drain_extraction_queue(db: AsyncSession) -> None:
    """Process pending ExtractionQueue rows for Free-tier users (Phase 5 stub).

    Full implementation wired in Phase 5 when APScheduler is registered.
    Currently logs the number of queue rows and returns. The count is computed
    in the database (``COUNT(*)``) rather than by loading every row into memory.

    Never raises: failures are logged at warning level.
    """
    try:
        from sqlalchemy import func  # noqa: PLC0415

        from app.models import ExtractionQueue  # noqa: PLC0415

        result = await db.execute(select(func.count()).select_from(ExtractionQueue))
        pending = result.scalar_one()
        logger.info("memory_maintenance: drain_extraction_queue pending=%d (Phase 5 cron)", pending)
    except Exception as exc:
        logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc)
|
||||
@@ -21,6 +21,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
@@ -33,11 +34,17 @@ from app.models import (
|
||||
MemoryCore,
|
||||
MemoryEpisodic,
|
||||
MemoryProactive,
|
||||
MemoryRelation,
|
||||
User,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now() -> datetime:
    """Return the current time as a timezone-aware UTC ``datetime``."""
    return datetime.now(tz=timezone.utc)
|
||||
|
||||
|
||||
# Tuning constants
|
||||
_ASSOCIATIVE_TOP_K = 5
|
||||
_EPISODIC_RECENT_N = 10
|
||||
@@ -66,6 +73,7 @@ class MemoryMiddleware:
|
||||
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
||||
episodic_memory — [plaintext_summary, ...] (most recent N)
|
||||
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
||||
relational_memory — ["subject --predicate--> object", ...] (top 10, Pro+)
|
||||
"""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
@@ -78,9 +86,10 @@ class MemoryMiddleware:
|
||||
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
|
||||
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||
proactive = await self._load_proactive(user_id, fernet)
|
||||
relational = await self._load_relational(user_id, user_tier=user_tier)
|
||||
|
||||
logger.info(
|
||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
user_tier,
|
||||
@@ -88,6 +97,7 @@ class MemoryMiddleware:
|
||||
len(associative),
|
||||
len(episodic),
|
||||
len(proactive),
|
||||
len(relational),
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -95,6 +105,7 @@ class MemoryMiddleware:
|
||||
"associative_memory": associative,
|
||||
"episodic_memory": episodic,
|
||||
"proactive_hints": proactive,
|
||||
"relational_memory": relational,
|
||||
}
|
||||
|
||||
async def store_episode(
|
||||
@@ -375,6 +386,99 @@ class MemoryMiddleware:
|
||||
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def upsert_relation(
    self,
    user_id: str,
    subject: str,
    subject_type: str,
    predicate: str,
    object_: str,
    object_type: str,
    *,
    confidence: float = 0.7,
    source_episode_id: str | None = None,
    notes: str | None = None,
) -> None:
    """Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label).

    subject_label / object_label are plaintext entity identifiers, not encrypted.
    notes is optional; it is encrypted with the user's Fernet key when one is
    available and silently dropped otherwise.

    Does nothing when the user's tier lacks the ``relational_memory`` feature.
    On a match the existing row's types, confidence and ``last_confirmed_at``
    are refreshed (notes only when new encrypted notes exist; the stored
    ``source_episode_id`` is left as is). Commit failures are logged and rolled
    back rather than raised.
    """
    from app.billing.tier_manager import tier_manager  # noqa: PLC0415

    user_dbg = await self._get_user_debug(user_id)
    user_tier = user_dbg.get("tier") or "free"
    if not tier_manager.check_feature(user_tier, "relational_memory"):
        logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier)
        return

    notes_encrypted: bytes | None = None
    if notes:
        fernet = await self._get_fernet(user_id)
        if fernet:
            notes_encrypted = fernet.encrypt(notes.encode())

    # Nothing guarantees the triple is unique per user, so duplicates can
    # exist. Take the oldest match instead of scalar_one_or_none(), which would
    # raise MultipleResultsFound here, outside the commit's try/except.
    result = await self._db.execute(
        select(MemoryRelation)
        .where(
            MemoryRelation.user_id == user_id,
            MemoryRelation.subject_label == subject,
            MemoryRelation.predicate == predicate,
            MemoryRelation.object_label == object_,
        )
        .order_by(MemoryRelation.created_at)
        .limit(1)
    )
    existing = result.scalars().first()

    if existing is not None:
        existing.subject_type = subject_type
        existing.object_type = object_type
        existing.confidence = confidence
        existing.last_confirmed_at = _now()
        if notes_encrypted is not None:
            existing.notes_encrypted = notes_encrypted
    else:
        self._db.add(MemoryRelation(
            id=str(uuid.uuid4()),
            user_id=user_id,
            subject_label=subject,
            subject_type=subject_type,
            predicate=predicate,
            object_label=object_,
            object_type=object_type,
            confidence=confidence,
            source_episode_id=source_episode_id,
            notes_encrypted=notes_encrypted,
        ))

    try:
        await self._db.commit()
        logger.info(
            "memory: upsert_relation user=%s subject=%s predicate=%s object=%s",
            user_id, subject, predicate, object_,
        )
    except Exception as exc:
        logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc)
        await self._db.rollback()
|
||||
|
||||
async def query_relations(
    self,
    user_id: str,
    subject: str | None = None,
    predicate: str | None = None,
    object_: str | None = None,
    limit: int = 20,
) -> list[MemoryRelation]:
    """Fetch a user's relation rows, highest confidence first.

    ``subject``, ``predicate`` and ``object_`` are optional exact-match
    filters; a filter left as ``None`` is not applied. At most ``limit`` rows
    are returned.
    """
    optional_filters = (
        (MemoryRelation.subject_label, subject),
        (MemoryRelation.predicate, predicate),
        (MemoryRelation.object_label, object_),
    )
    conditions = [MemoryRelation.user_id == user_id]
    conditions.extend(
        column == wanted for column, wanted in optional_filters if wanted is not None
    )
    stmt = (
        select(MemoryRelation)
        .where(*conditions)
        .order_by(MemoryRelation.confidence.desc())
        .limit(limit)
    )
    rows = (await self._db.execute(stmt)).scalars().all()
    return list(rows)
|
||||
|
||||
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||
"""Insert a long-term archival memory entry."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
@@ -463,13 +567,26 @@ class MemoryMiddleware:
|
||||
|
||||
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
||||
"""Load lightweight user debug fields for trace logs."""
|
||||
from app.config.settings import settings # noqa: PLC0415
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
return {"tier": None}
|
||||
return {
|
||||
"tier": user.tier,
|
||||
}
|
||||
|
||||
sub_result = await self._db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub_tier: str | None = sub_result.scalar_one_or_none()
|
||||
if sub_tier:
|
||||
tier = sub_tier
|
||||
elif settings.ENV == "dev":
|
||||
tier = "power"
|
||||
else:
|
||||
tier = user.tier or "free"
|
||||
|
||||
return {"tier": tier}
|
||||
|
||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||
result = await self._db.execute(
|
||||
@@ -563,6 +680,26 @@ class MemoryMiddleware:
|
||||
out.append(plaintext)
|
||||
return out
|
||||
|
||||
async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]:
    """Render the user's ten highest-confidence relations as prompt strings.

    Each entry has the form ``"subject --predicate--> object"``. Tiers without
    the ``relational_memory`` feature get an empty list and no query is run.
    """
    from app.billing.tier_manager import tier_manager  # noqa: PLC0415

    if tier_manager.check_feature(user_tier, "relational_memory"):
        stmt = (
            select(MemoryRelation)
            .where(MemoryRelation.user_id == user_id)
            .order_by(MemoryRelation.confidence.desc())
            .limit(10)
        )
        relations = (await self._db.execute(stmt)).scalars().all()
        rendered: list[str] = []
        for rel in relations:
            rendered.append(f"{rel.subject_label} --{rel.predicate}--> {rel.object_label}")
        return rendered
    return []
|
||||
|
||||
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||
result = await self._db.execute(
|
||||
select(MemoryProactive)
|
||||
|
||||
@@ -50,13 +50,14 @@ def create_app() -> FastAPI:
|
||||
app.add_middleware(SanitizerMiddleware)
|
||||
app.add_middleware(TierRateLimitMiddleware)
|
||||
|
||||
from app.api.routes import agents, auth, billing, chat, device_ws
|
||||
from app.api.routes import agents, auth, billing, chat, device_ws, memory
|
||||
|
||||
app.include_router(auth.router, prefix="/api/v1")
|
||||
app.include_router(chat.router, prefix="/api/v1")
|
||||
app.include_router(billing.router, prefix="/api/v1")
|
||||
app.include_router(agents.router, prefix="/api/v1")
|
||||
app.include_router(device_ws.router, prefix="/api/v1")
|
||||
app.include_router(memory.router, prefix="/api/v1")
|
||||
|
||||
@app.get("/api/v1/health", tags=["health"])
|
||||
async def health() -> dict:
|
||||
|
||||
@@ -14,6 +14,7 @@ Table inventory:
|
||||
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||
memory_episodic — per-user session summaries (encrypted)
|
||||
memory_proactive — per-user behavioral patterns (encrypted)
|
||||
memory_relations — per-user entity/relation graph (Mem0g-light, Phase 3)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -30,6 +31,7 @@ from sqlalchemy import (
|
||||
ForeignKey,
|
||||
Integer,
|
||||
JSON,
|
||||
LargeBinary,
|
||||
String,
|
||||
Text,
|
||||
Uuid,
|
||||
@@ -373,6 +375,44 @@ class ExtractionQueue(Base):
|
||||
)
|
||||
|
||||
|
||||
class MemoryRelation(Base):
    """Per-user entity/relation graph row (Mem0g-light, Phase 3).

    subject_label/object_label are plaintext entity identifiers (not user content).
    notes_encrypted is optional Fernet-encrypted per-user commentary.
    confidence in [0.0, 1.0]; decays 5 % per 30 days since last_confirmed_at.

    NOTE(review): ``index=True`` on ``user_id`` declares a single-column index
    that migration 006 does not create; the migration instead creates two
    composite indexes, (user_id, subject_label) and (user_id, predicate), that
    are not declared on this model. Confirm which side is intended.
    """

    __tablename__ = "memory_relations"

    id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
    # Owning user; rows are removed together with the user (ON DELETE CASCADE).
    user_id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False, index=True,
    )
    subject_label: Mapped[str] = mapped_column(String(128), nullable=False)
    subject_type: Mapped[str] = mapped_column(String(32), nullable=False)
    predicate: Mapped[str] = mapped_column(String(64), nullable=False)
    object_label: Mapped[str] = mapped_column(String(128), nullable=False)
    object_type: Mapped[str] = mapped_column(String(32), nullable=False)
    confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.7)
    # Optional link to the episodic row; set to NULL if that episode is deleted.
    source_episode_id: Mapped[str | None] = mapped_column(
        Uuid(as_uuid=False),
        ForeignKey("memory_episodic.id", ondelete="SET NULL"),
        nullable=True,
    )
    notes_encrypted: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.now()
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
    )
    # NULL until the relation is confirmed by an upsert or a manual patch.
    last_confirmed_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
|
||||
|
||||
|
||||
class Plugin(Base):
|
||||
"""Plugin marketplace catalog entry."""
|
||||
|
||||
|
||||
220
tests/test_memory_relations.py
Normal file
220
tests/test_memory_relations.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Tests for Phase 3 — relational tier (Mem0g-light).
|
||||
|
||||
Coverage:
|
||||
1. upsert_relation inserts a row and query_relations returns it
|
||||
2. upsert_relation updates existing row on duplicate (subject/predicate/object)
|
||||
3. tier gating: Free user gets empty list from query_relations + enrich_context
|
||||
4. enrich_context includes relational_memory key for Pro user
|
||||
5. decay_relations decays confidence and prunes rows below threshold
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.memory_maintenance import decay_relations
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import MemoryRelation, User
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
PRO_USER_ID = TEST_USER_IDS["pro"]
|
||||
FREE_USER_ID = TEST_USER_IDS["free"]
|
||||
_FERNET_KEY = Fernet.generate_key().decode()
|
||||
|
||||
|
||||
# ── DB override ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Point the app's ``get_session`` dependency at the test ``db_session``.

    Applied automatically to every test in this module; the override is
    removed again once the test has finished.
    """

    async def _yield_test_session():
        yield db_session

    overrides = app.dependency_overrides
    overrides[get_session] = _yield_test_session
    yield
    overrides.pop(get_session, None)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def pro_user_with_key(db_session):
    """Return the pro test user after giving it a Fernet encryption key."""
    lookup = select(User).where(User.id == PRO_USER_ID)
    pro_user = (await db_session.execute(lookup)).scalar_one()
    pro_user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return pro_user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def free_user_with_key(db_session):
    """Return the free test user after giving it a Fernet encryption key."""
    lookup = select(User).where(User.id == FREE_USER_ID)
    free_user = (await db_session.execute(lookup)).scalar_one()
    free_user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return free_user
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
async def test_upsert_relation_inserts_and_queries(db_session, pro_user_with_key):
    """A triple stored via upsert_relation comes back from query_relations."""
    middleware = MemoryMiddleware(db_session)
    triple = {
        "subject": "Giulia",
        "subject_type": "person",
        "predicate": "works_at",
        "object_": "Acme Corp",
        "object_type": "company",
        "confidence": 0.9,
    }
    await middleware.upsert_relation(PRO_USER_ID, **triple)

    rows = await middleware.query_relations(PRO_USER_ID, subject="Giulia")

    assert len(rows) == 1
    stored = rows[0]
    assert stored.subject_label == "Giulia"
    assert stored.predicate == "works_at"
    assert stored.object_label == "Acme Corp"
    assert abs(stored.confidence - 0.9) < 0.001
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_upsert_relation_updates_on_duplicate(db_session, pro_user_with_key):
    """Upserting the same triple twice keeps one row with the newer confidence."""
    middleware = MemoryMiddleware(db_session)
    triple = {
        "subject": "Marco",
        "subject_type": "person",
        "predicate": "stakeholder_of",
        "object_": "Project Nexus",
        "object_type": "project",
    }
    # Same triple both times; only the confidence differs between the calls.
    for confidence in (0.7, 0.95):
        await middleware.upsert_relation(PRO_USER_ID, confidence=confidence, **triple)

    rows = await middleware.query_relations(PRO_USER_ID, subject="Marco")

    # Only one row despite two upserts
    assert len(rows) == 1
    merged = rows[0]
    assert abs(merged.confidence - 0.95) < 0.001
    assert merged.last_confirmed_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_free_tier_relation_skipped(db_session, free_user_with_key):
    """Free user: after an upsert_relation call, query_relations returns []."""
    middleware = MemoryMiddleware(db_session)
    triple = {
        "subject": "Alice",
        "subject_type": "person",
        "predicate": "reports_to",
        "object_": "Bob",
        "object_type": "person",
        "confidence": 0.8,
    }
    # Expected to be a silent no-op for the free tier (must not raise).
    await middleware.upsert_relation(FREE_USER_ID, **triple)

    rows = await middleware.query_relations(FREE_USER_ID, subject="Alice")
    assert rows == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_enrich_context_includes_relational_memory(db_session, pro_user_with_key):
    """Pro user: enrich_context exposes stored relations under relational_memory."""
    middleware = MemoryMiddleware(db_session)
    triple = {
        "subject": "Elena",
        "subject_type": "person",
        "predicate": "cfo_of",
        "object_": "StartupXYZ",
        "object_type": "company",
        "confidence": 0.85,
    }
    await middleware.upsert_relation(PRO_USER_ID, **triple)

    # Stub out the associative loader so it contributes nothing to the context.
    associative_target = "app.core.memory_middleware.MemoryMiddleware._load_associative"
    with patch(associative_target, return_value=[]):
        ctx = await middleware.enrich_context(PRO_USER_ID, "who is Elena?")

    assert "relational_memory" in ctx
    assert any("Elena" in entry for entry in ctx["relational_memory"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_enrich_context_relational_empty_for_free(db_session, free_user_with_key):
    """Free user: enrich_context reports relational_memory as an empty list."""
    middleware = MemoryMiddleware(db_session)

    # Stub out the associative loader so it contributes nothing to the context.
    associative_target = "app.core.memory_middleware.MemoryMiddleware._load_associative"
    with patch(associative_target, return_value=[]):
        ctx = await middleware.enrich_context(FREE_USER_ID, "test message")

    assert ctx.get("relational_memory") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_decay_relations_reduces_confidence(db_session, pro_user_with_key):
    """A relation last confirmed 35 days ago loses confidence but is kept."""
    stale_since = datetime.now(timezone.utc) - timedelta(days=35)
    db_session.add(
        MemoryRelation(
            id=str(uuid.uuid4()),
            user_id=PRO_USER_ID,
            subject_label="OldContact",
            subject_type="person",
            predicate="knows",
            object_label="SomeProject",
            object_type="project",
            confidence=0.8,
            last_confirmed_at=stale_since,
        )
    )
    await db_session.commit()

    await decay_relations(db_session, PRO_USER_ID)

    lookup = select(MemoryRelation).where(MemoryRelation.subject_label == "OldContact")
    updated = (await db_session.execute(lookup)).scalar_one_or_none()
    assert updated is not None
    assert updated.confidence < 0.8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_decay_relations_prunes_low_confidence(db_session, pro_user_with_key):
    """decay_relations deletes rows whose confidence drops below 0.2 threshold."""
    # Start at 0.21 with a 65-day-old last_confirmed_at → two decay periods
    # → 0.21 * 0.95^2 ≈ 0.19 → pruned
    stale_since = datetime.now(timezone.utc) - timedelta(days=65)
    db_session.add(
        MemoryRelation(
            id=str(uuid.uuid4()),
            user_id=PRO_USER_ID,
            subject_label="ExpiredContact",
            subject_type="person",
            predicate="used_to_work_with",
            object_label="OldCorp",
            object_type="company",
            confidence=0.21,
            last_confirmed_at=stale_since,
        )
    )
    await db_session.commit()

    await decay_relations(db_session, PRO_USER_ID)

    lookup = select(MemoryRelation).where(MemoryRelation.subject_label == "ExpiredContact")
    pruned = (await db_session.execute(lookup)).scalar_one_or_none()
    assert pruned is None
|
||||
Reference in New Issue
Block a user