PHASE 3 — relational tier (Mem0g-light)
This commit is contained in:
74
alembic/versions/006_memory_relations.py
Normal file
74
alembic/versions/006_memory_relations.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Add memory_relations table (Phase 3 — relational tier).
|
||||||
|
|
||||||
|
Revision ID: 006
|
||||||
|
Revises: 1f5975a4f3f4
|
||||||
|
Create Date: 2026-04-16
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# Alembic revision identifiers. This migration is revision "006" and is applied
# on top of revision "1f5975a4f3f4"; it declares no branch labels and no
# cross-branch dependencies.
revision: str = "006"
down_revision: Union[str, None] = "1f5975a4f3f4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Create the ``memory_relations`` table and its two per-user indexes.

    Each row stores one (subject, predicate, object) triple owned by a user.
    Rows are removed together with their user (``ON DELETE CASCADE``); the
    optional link to ``memory_episodic`` is nulled when that episode is
    deleted (``ON DELETE SET NULL``).
    """
    table = "memory_relations"

    def _uuid() -> postgresql.UUID:
        # as_uuid=False: values are exchanged as strings, not uuid.UUID objects.
        return postgresql.UUID(as_uuid=False)

    def _auto_timestamp(name: str) -> sa.Column:
        # Non-null timestamptz column that the database defaults to now().
        return sa.Column(
            name,
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        )

    columns = [
        sa.Column("id", _uuid(), primary_key=True),
        sa.Column(
            "user_id",
            _uuid(),
            sa.ForeignKey("users.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("subject_label", sa.String(128), nullable=False),
        sa.Column("subject_type", sa.String(32), nullable=False),
        sa.Column("predicate", sa.String(64), nullable=False),
        sa.Column("object_label", sa.String(128), nullable=False),
        sa.Column("object_type", sa.String(32), nullable=False),
        sa.Column("confidence", sa.Float, nullable=False, server_default="0.7"),
        sa.Column(
            "source_episode_id",
            _uuid(),
            sa.ForeignKey("memory_episodic.id", ondelete="SET NULL"),
            nullable=True,
        ),
        sa.Column("notes_encrypted", sa.LargeBinary, nullable=True),
        _auto_timestamp("created_at"),
        _auto_timestamp("updated_at"),
        sa.Column("last_confirmed_at", sa.DateTime(timezone=True), nullable=True),
    ]
    op.create_table(table, *columns)

    # Composite indexes for the two per-user lookups: by subject and by predicate.
    index_specs = (
        ("memory_relations_user_subject_idx", ["user_id", "subject_label"]),
        ("memory_relations_user_predicate_idx", ["user_id", "predicate"]),
    )
    for index_name, indexed_columns in index_specs:
        op.create_index(index_name, table, indexed_columns)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Undo :func:`upgrade`: drop both indexes, then the table itself."""
    table = "memory_relations"
    # Indexes are dropped in reverse order of their creation.
    for index_name in (
        "memory_relations_user_predicate_idx",
        "memory_relations_user_subject_idx",
    ):
        op.drop_index(index_name, table)
    op.drop_table(table)
|
||||||
225
app/api/routes/memory.py
Normal file
225
app/api/routes/memory.py
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
"""Memory management routes — view/edit/delete user memory tiers.
|
||||||
|
|
||||||
|
All routes require authentication. Data is always user-scoped.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import (
|
||||||
|
ExtractionQueue,
|
||||||
|
MemoryAssociative,
|
||||||
|
MemoryCore,
|
||||||
|
MemoryEpisodic,
|
||||||
|
MemoryProactive,
|
||||||
|
MemoryRelation,
|
||||||
|
)
|
||||||
|
from app.schemas import UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/memory", tags=["memory"])
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Closed vocabulary of predicates a client may set through PATCH
# /memory/relational/{relation_id}; any other value is rejected with HTTP 422.
_ALLOWED_PREDICATES = {
    "works_at",
    "reports_to",
    "stakeholder_of",
    "last_contacted_on",
    "owes_followup",
    "manages",
    "collaborates_with",
    "owns",
    "member_of",
    "custom",
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Response schemas ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class RelationOut(BaseModel):
    """API representation of one relational-memory row."""

    id: str
    subject_label: str
    subject_type: str
    predicate: str
    object_label: str
    object_type: str
    confidence: float
    last_confirmed_at: int | None = None  # epoch ms
|
||||||
|
|
||||||
|
|
||||||
|
class RelationPatch(BaseModel):
    """Partial update for a relation row; omitted (``None``) fields are left unchanged.

    Label lengths are bounded to match the ``String(128)`` columns created by
    the ``memory_relations`` migration, so an over-long value is rejected as a
    422 validation error instead of failing at the database. Empty labels are
    rejected as well. ``predicate`` is checked against the allowed vocabulary by
    the route handler, and ``confidence`` must lie in [0.0, 1.0].
    """

    subject_label: str | None = Field(None, min_length=1, max_length=128)
    object_label: str | None = Field(None, min_length=1, max_length=128)
    predicate: str | None = None
    confidence: float | None = Field(None, ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class CoreAddBody(BaseModel):
    """Request body for POST /memory/core: one key/value pair to store."""

    key: str = Field(..., min_length=1, max_length=255)  # non-empty, at most 255 chars
    value: str = Field(..., min_length=1)  # non-empty; no upper bound enforced here
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _relation_to_out(row: MemoryRelation) -> RelationOut:
    """Convert a ``MemoryRelation`` row into its API schema.

    ``last_confirmed_at`` is emitted as integer epoch milliseconds, or ``None``
    when the row has no confirmation timestamp.
    """
    confirmed_at = row.last_confirmed_at
    confirmed_ms: int | None = (
        None if confirmed_at is None else int(confirmed_at.timestamp() * 1000)
    )
    return RelationOut(
        id=row.id,
        subject_label=row.subject_label,
        subject_type=row.subject_type,
        predicate=row.predicate,
        object_label=row.object_label,
        object_type=row.object_type,
        confidence=row.confidence,
        last_confirmed_at=confirmed_ms,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/core", response_model=dict[str, str])
async def get_core_memory(
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> dict[str, str]:
    """Return the current user's core memory as a ``{label: value}`` mapping.

    The blocks come from ``MemoryMiddleware.list_core_blocks``; if two blocks
    share a label, the later one wins.
    """
    middleware = MemoryMiddleware(db)
    core: dict[str, str] = {}
    for block in await middleware.list_core_blocks(current_user.id):
        core[block["label"]] = block["value"]
    return core
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/core/{key}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_core_key(
    key: str,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> None:
    """Delete one core memory key for the current user (GDPR Art. 17).

    Raises:
        HTTPException: 404 when ``MemoryMiddleware.delete_core`` reports that
            nothing was deleted.
    """
    middleware = MemoryMiddleware(db)
    removed = await middleware.delete_core(current_user.id, key)
    if removed:
        return
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Key not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/core", status_code=status.HTTP_201_CREATED, response_model=dict[str, str])
async def add_core_key(
    body: CoreAddBody,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> dict[str, str]:
    """Store a core memory key/value pair and echo it back as ``{key: value}``.

    The write is delegated to ``MemoryMiddleware.update_core``.
    """
    key, value = body.key, body.value
    middleware = MemoryMiddleware(db)
    await middleware.update_core(current_user.id, key, value)
    return {key: value}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/relational", response_model=list[RelationOut])
async def get_relational_memory(
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> list[RelationOut]:
    """Return the current user's relational memory rows (at most 200)."""
    middleware = MemoryMiddleware(db)
    relations = await middleware.query_relations(current_user.id, limit=200)
    return list(map(_relation_to_out, relations))
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/relational/{relation_id}", response_model=RelationOut)
async def patch_relation(
    relation_id: str,
    body: RelationPatch,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> RelationOut:
    """Edit a relation row's labels, predicate, or confidence.

    Only fields present in the body (not ``None``) are written. Every
    successful call also stamps ``last_confirmed_at`` with the current UTC time.

    Raises:
        HTTPException: 422 when ``predicate`` is outside the allowed set;
            404 when no row with this id belongs to the current user.
    """
    requested_predicate = body.predicate
    if requested_predicate is not None and requested_predicate not in _ALLOWED_PREDICATES:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"predicate must be one of: {sorted(_ALLOWED_PREDICATES)}",
        )

    # The lookup is scoped to the caller so another user's row reads as "not found".
    lookup = select(MemoryRelation).where(
        MemoryRelation.id == relation_id,
        MemoryRelation.user_id == current_user.id,
    )
    row = (await db.execute(lookup)).scalar_one_or_none()
    if row is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")

    for field_name in ("subject_label", "object_label", "predicate", "confidence"):
        new_value = getattr(body, field_name)
        if new_value is not None:
            setattr(row, field_name, new_value)
    row.last_confirmed_at = datetime.now(timezone.utc)

    await db.commit()
    await db.refresh(row)
    logger.info("memory: patch_relation user=%s relation=%s", current_user.id, relation_id)
    return _relation_to_out(row)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/relational/{relation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_relation(
    relation_id: str,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> None:
    """Hard-delete one relation row owned by the current user (GDPR Art. 17).

    Raises:
        HTTPException: 404 when no row with this id belongs to the current user.
    """
    owned_row_query = select(MemoryRelation).where(
        MemoryRelation.id == relation_id,
        MemoryRelation.user_id == current_user.id,
    )
    target = (await db.execute(owned_row_query)).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")

    await db.delete(target)
    await db.commit()
    logger.info("memory: delete_relation user=%s relation=%s", current_user.id, relation_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/forget-all", status_code=status.HTTP_204_NO_CONTENT)
async def forget_all(
    x_confirm: Annotated[str | None, Header(alias="X-Confirm")] = None,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> None:
    """Wipe all memory tiers for the current user (GDPR Art. 17).

    Requires the header ``X-Confirm: true`` (exact, case-sensitive match).
    Does NOT delete the user account.

    Raises:
        HTTPException: 400 when the confirmation header is missing or not ``true``.
    """
    if x_confirm != "true":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing or invalid X-Confirm header. Send X-Confirm: true to confirm.",
        )

    uid = current_user.id
    # One DELETE per table, issued in this order and committed together.
    user_scoped_models = (
        MemoryCore,
        MemoryAssociative,
        MemoryEpisodic,
        MemoryProactive,
        MemoryRelation,
        ExtractionQueue,
    )
    for model in user_scoped_models:
        await db.execute(delete(model).where(model.user_id == uid))
    await db.commit()
    logger.warning("memory: forget_all GDPR wipe user=%s", uid)
|
||||||
@@ -25,8 +25,9 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"providers": 1,
|
"providers": 1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
"real_embeddings": False, # keyword fallback only
|
"real_embeddings": False, # keyword fallback only
|
||||||
"realtime_extraction": False, # batch queue (Phase 2)
|
"realtime_extraction": False, # batch queue (Phase 2)
|
||||||
|
"relational_memory": False, # relational tier (Phase 3) — Pro+
|
||||||
},
|
},
|
||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
@@ -35,8 +36,9 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
"real_embeddings": True, # pgvector cosine search
|
"real_embeddings": True, # pgvector cosine search
|
||||||
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
||||||
|
"relational_memory": True, # person/project predicates
|
||||||
},
|
},
|
||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
@@ -47,6 +49,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"sso": False,
|
"sso": False,
|
||||||
"real_embeddings": True,
|
"real_embeddings": True,
|
||||||
"realtime_extraction": True,
|
"realtime_extraction": True,
|
||||||
|
"relational_memory": True, # all predicates incl. custom
|
||||||
},
|
},
|
||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
@@ -57,6 +60,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"sso": True,
|
"sso": True,
|
||||||
"real_embeddings": True,
|
"real_embeddings": True,
|
||||||
"realtime_extraction": True,
|
"realtime_extraction": True,
|
||||||
|
"relational_memory": True, # all predicates incl. custom
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -55,6 +55,22 @@ def _language_instruction(context: dict[str, Any]) -> str:
|
|||||||
f"All your output text must be written in {lang}."
|
f"All your output text must be written in {lang}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _relational_memory_injection(context: dict[str, Any]) -> str:
|
||||||
|
"""Return a system-prompt paragraph listing known people/projects from relational memory.
|
||||||
|
|
||||||
|
Returns empty string when no relational rows or tier is Free.
|
||||||
|
Capped at 800 chars to control token spend.
|
||||||
|
"""
|
||||||
|
relations: list[str] = context.get("relational_memory") or []
|
||||||
|
if not relations:
|
||||||
|
return ""
|
||||||
|
body = "\n".join(f"- {r}" for r in relations)
|
||||||
|
section = f"\n\nKnown people & projects:\n{body}"
|
||||||
|
if len(section) > 800:
|
||||||
|
section = section[:797] + "..."
|
||||||
|
return section
|
||||||
|
|
||||||
|
|
||||||
_HOME_SYSTEM_PROMPT = (
|
_HOME_SYSTEM_PROMPT = (
|
||||||
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
"Always use tools for factual data retrieval before answering. "
|
"Always use tools for factual data retrieval before answering. "
|
||||||
@@ -904,6 +920,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
|||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
"home_system", _HOME_SYSTEM_PROMPT
|
"home_system", _HOME_SYSTEM_PROMPT
|
||||||
)
|
)
|
||||||
|
system_prompt += _relational_memory_injection(context)
|
||||||
system_prompt += _language_instruction(context)
|
system_prompt += _language_instruction(context)
|
||||||
response = await _run_single_agent(
|
response = await _run_single_agent(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -922,6 +939,7 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
|
|||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||||
)
|
)
|
||||||
|
system_prompt += _relational_memory_injection(context)
|
||||||
system_prompt += _language_instruction(context)
|
system_prompt += _language_instruction(context)
|
||||||
response = await _run_single_agent(
|
response = await _run_single_agent(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -946,6 +964,7 @@ async def run_home_stream(
|
|||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
"home_system", _HOME_SYSTEM_PROMPT
|
"home_system", _HOME_SYSTEM_PROMPT
|
||||||
)
|
)
|
||||||
|
system_prompt += _relational_memory_injection(context)
|
||||||
system_prompt += _language_instruction(context)
|
system_prompt += _language_instruction(context)
|
||||||
text_chunks: list[str] = []
|
text_chunks: list[str] = []
|
||||||
async for event in _run_single_agent_stream(
|
async for event in _run_single_agent_stream(
|
||||||
@@ -979,6 +998,7 @@ async def run_floating_stream(
|
|||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||||
)
|
)
|
||||||
|
system_prompt += _relational_memory_injection(context)
|
||||||
system_prompt += _language_instruction(context)
|
system_prompt += _language_instruction(context)
|
||||||
sanitizer = _FloatingStreamSanitizer()
|
sanitizer = _FloatingStreamSanitizer()
|
||||||
emitted_sanitized = False
|
emitted_sanitized = False
|
||||||
|
|||||||
@@ -366,7 +366,7 @@ async def _apply_candidate(
|
|||||||
if candidate.target_tier == "relational":
|
if candidate.target_tier == "relational":
|
||||||
# Always upsert relations — decide_action skipped (no neighbour search).
|
# Always upsert relations — decide_action skipped (no neighbour search).
|
||||||
if candidate.subject and candidate.predicate and candidate.object:
|
if candidate.subject and candidate.predicate and candidate.object:
|
||||||
await _upsert_relation_stub(
|
await _upsert_relation(
|
||||||
middleware, db, user_id, candidate, trace_id
|
middleware, db, user_id, candidate, trace_id
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -396,35 +396,29 @@ def _content_to_key(content: str) -> str:
|
|||||||
return slug or "memory"
|
return slug or "memory"
|
||||||
|
|
||||||
|
|
||||||
async def _upsert_relation_stub(
|
async def _upsert_relation(
|
||||||
middleware: Any,
|
middleware: Any,
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
candidate: MemoryCandidate,
|
candidate: MemoryCandidate,
|
||||||
trace_id: str | None,
|
trace_id: str | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Stub: upsert_relation will be fully wired in Phase 3.
|
"""Upsert a relation row via MemoryMiddleware.upsert_relation (Phase 3)."""
|
||||||
|
await middleware.upsert_relation(
|
||||||
Called here so Phase 2 extraction pipeline already routes relation candidates
|
user_id=user_id,
|
||||||
correctly. Phase 3 replaces this with MemoryMiddleware.upsert_relation().
|
subject=candidate.subject or "unknown",
|
||||||
"""
|
subject_type="unknown",
|
||||||
if hasattr(middleware, "upsert_relation"):
|
predicate=candidate.predicate or "related_to",
|
||||||
await middleware.upsert_relation(
|
object_=candidate.object or "unknown",
|
||||||
user_id=user_id,
|
object_type="unknown",
|
||||||
subject=candidate.subject,
|
confidence=candidate.confidence,
|
||||||
subject_type="unknown",
|
)
|
||||||
predicate=candidate.predicate,
|
logger.info(
|
||||||
object_=candidate.object,
|
"memory_extraction: upserted relation subject=%s predicate=%s object=%s",
|
||||||
object_type="unknown",
|
candidate.subject,
|
||||||
confidence=candidate.confidence,
|
candidate.predicate,
|
||||||
)
|
candidate.object,
|
||||||
else:
|
)
|
||||||
logger.info(
|
|
||||||
"memory_extraction: relation stub (Phase 3 not yet wired) subject=%s predicate=%s object=%s",
|
|
||||||
candidate.subject,
|
|
||||||
candidate.predicate,
|
|
||||||
candidate.object,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _store_proactive_stub(
|
async def _store_proactive_stub(
|
||||||
|
|||||||
102
app/core/memory_maintenance.py
Normal file
102
app/core/memory_maintenance.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
"""Memory maintenance jobs — Phase 3/5.
|
||||||
|
|
||||||
|
Two entrypoints called by the scheduler (APScheduler) registered in app/main.py:
|
||||||
|
|
||||||
|
drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5).
|
||||||
|
decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3).
|
||||||
|
|
||||||
|
Both are safe to call manually or from tests; they never raise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models import MemoryRelation
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Decay parameters
_DECAY_FACTOR = 0.95  # multiply confidence by this once per _DECAY_PERIOD_DAYS days
_DECAY_PERIOD_DAYS = 30  # length of one decay step, in days
_PRUNE_THRESHOLD = 0.2  # rows whose decayed confidence falls below this are deleted
|
||||||
|
|
||||||
|
|
||||||
|
def _as_utc(moment: datetime) -> datetime:
    """Return *moment* as timezone-aware, treating a naive value as UTC."""
    if moment.tzinfo is None:
        return moment.replace(tzinfo=timezone.utc)
    return moment


async def decay_relations(db: AsyncSession, user_id: str) -> None:
    """Apply confidence decay to all relation rows for a user.

    Decay rule: confidence is multiplied by ``_DECAY_FACTOR`` once for every
    full ``_DECAY_PERIOD_DAYS`` days since ``last_confirmed_at`` (or
    ``created_at`` when the row was never confirmed). Rows whose confidence
    falls below ``_PRUNE_THRESHOLD`` are deleted.

    Never raises: any failure is logged and the session is rolled back so the
    caller can keep using it.
    """
    try:
        await _decay_relations_inner(db, user_id)
    except Exception as exc:
        logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, exc)
        # A failed statement leaves the session in an aborted transaction;
        # roll back so the next user processed on this session is unaffected.
        try:
            await db.rollback()
        except Exception as rollback_exc:
            logger.warning(
                "memory_maintenance: decay_relations rollback failed user=%s: %s",
                user_id, rollback_exc,
            )


async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
    """Decay or prune each relation row of *user_id*, then commit once.

    Each decay period is applied at most once per row. The stored confidence
    already includes the decay from earlier runs, so only the periods that have
    elapsed since the previous decay are applied. ``updated_at`` records when
    decay was last applied.

    NOTE(review): this relies on ``MemoryRelation`` mapping the ``updated_at``
    column created by migration 006 -- confirm against the model. If the
    attribute is absent, the row is treated as never decayed.
    """
    result = await db.execute(
        select(MemoryRelation).where(MemoryRelation.user_id == user_id)
    )
    rows = result.scalars().all()
    now = datetime.now(timezone.utc)
    deleted = 0
    decayed = 0

    for row in rows:
        reference = row.last_confirmed_at or row.created_at
        if reference is None:
            continue
        # Ensure timezone-aware comparison
        reference = _as_utc(reference)

        periods_elapsed = (now - reference).days // _DECAY_PERIOD_DAYS
        if periods_elapsed < 1:
            continue

        # Periods already folded into row.confidence by an earlier run. A
        # marker at or before the reference means the row was (re)confirmed
        # after the last decay, so the clock restarted.
        periods_applied = 0
        last_touched = getattr(row, "updated_at", None)
        if last_touched is not None:
            last_touched = _as_utc(last_touched)
            if last_touched > reference:
                periods_applied = (last_touched - reference).days // _DECAY_PERIOD_DAYS

        periods_due = periods_elapsed - periods_applied
        if periods_due < 1:
            continue

        new_confidence = row.confidence * (_DECAY_FACTOR ** periods_due)

        if new_confidence < _PRUNE_THRESHOLD:
            await db.delete(row)
            deleted += 1
            logger.info(
                "memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s "
                "confidence=%.3f (below threshold)",
                row.id, user_id, row.subject_label, row.predicate, new_confidence,
            )
        else:
            row.confidence = new_confidence
            # Record when decay was applied so the next run does not re-apply it.
            row.updated_at = now
            decayed += 1

    try:
        await db.commit()
        logger.info(
            "memory_maintenance: decay_relations user=%s decayed=%d deleted=%d",
            user_id, decayed, deleted,
        )
    except Exception as exc:
        logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, exc)
        await db.rollback()
|
||||||
|
|
||||||
|
|
||||||
|
async def drain_extraction_queue(db: AsyncSession) -> None:
    """Report how many ExtractionQueue rows exist (Phase 5 stub).

    Full implementation wired in Phase 5 when APScheduler is registered.
    Currently logs the row count and returns. The count is computed by the
    database (``SELECT count(*)``) rather than by loading every row.

    Never raises: failures are logged as warnings.
    """
    try:
        from sqlalchemy import func  # noqa: PLC0415

        from app.models import ExtractionQueue  # noqa: PLC0415

        result = await db.execute(select(func.count()).select_from(ExtractionQueue))
        pending = result.scalar_one()
        logger.info("memory_maintenance: drain_extraction_queue pending=%d (Phase 5 cron)", pending)
    except Exception as exc:
        logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc)
|
||||||
@@ -21,6 +21,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from cryptography.fernet import Fernet, InvalidToken
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
@@ -33,11 +34,17 @@ from app.models import (
|
|||||||
MemoryCore,
|
MemoryCore,
|
||||||
MemoryEpisodic,
|
MemoryEpisodic,
|
||||||
MemoryProactive,
|
MemoryProactive,
|
||||||
|
MemoryRelation,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
    """Return the current time as a timezone-aware UTC ``datetime``."""
    return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
# Tuning constants
|
# Tuning constants
|
||||||
_ASSOCIATIVE_TOP_K = 5
|
_ASSOCIATIVE_TOP_K = 5
|
||||||
_EPISODIC_RECENT_N = 10
|
_EPISODIC_RECENT_N = 10
|
||||||
@@ -66,6 +73,7 @@ class MemoryMiddleware:
|
|||||||
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
||||||
episodic_memory — [plaintext_summary, ...] (most recent N)
|
episodic_memory — [plaintext_summary, ...] (most recent N)
|
||||||
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
||||||
|
relational_memory — ["subject --predicate--> object", ...] (top 10, Pro+)
|
||||||
"""
|
"""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
@@ -78,9 +86,10 @@ class MemoryMiddleware:
|
|||||||
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
|
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
|
||||||
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
relational = await self._load_relational(user_id, user_tier=user_tier)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d",
|
||||||
trace_id or "-",
|
trace_id or "-",
|
||||||
user_id,
|
user_id,
|
||||||
user_tier,
|
user_tier,
|
||||||
@@ -88,6 +97,7 @@ class MemoryMiddleware:
|
|||||||
len(associative),
|
len(associative),
|
||||||
len(episodic),
|
len(episodic),
|
||||||
len(proactive),
|
len(proactive),
|
||||||
|
len(relational),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -95,6 +105,7 @@ class MemoryMiddleware:
|
|||||||
"associative_memory": associative,
|
"associative_memory": associative,
|
||||||
"episodic_memory": episodic,
|
"episodic_memory": episodic,
|
||||||
"proactive_hints": proactive,
|
"proactive_hints": proactive,
|
||||||
|
"relational_memory": relational,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def store_episode(
|
async def store_episode(
|
||||||
@@ -375,6 +386,99 @@ class MemoryMiddleware:
|
|||||||
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
|
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def upsert_relation(
    self,
    user_id: str,
    subject: str,
    subject_type: str,
    predicate: str,
    object_: str,
    object_type: str,
    *,
    confidence: float = 0.7,
    source_episode_id: str | None = None,
    notes: str | None = None,
) -> None:
    """Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label).

    subject_label / object_label are plaintext entity identifiers — not encrypted.
    notes is optional; encrypted with the user's Fernet key if provided.

    Does nothing when the user's tier lacks the ``relational_memory`` feature.
    On a match, the types and confidence are overwritten, ``last_confirmed_at``
    is set to now, and the notes are replaced only when new encrypted notes
    are available. ``confidence`` is clamped to [0.0, 1.0], the range the API
    schema enforces for edits. A failed commit is logged and rolled back
    rather than raised.
    """
    from app.billing.tier_manager import tier_manager  # noqa: PLC0415

    user_dbg = await self._get_user_debug(user_id)
    user_tier = user_dbg.get("tier") or "free"
    if not tier_manager.check_feature(user_tier, "relational_memory"):
        logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier)
        return

    confidence = min(1.0, max(0.0, confidence))

    notes_encrypted: bytes | None = None
    if notes:
        fernet = await self._get_fernet(user_id)
        if fernet:
            notes_encrypted = fernet.encrypt(notes.encode())
        else:
            # Notes are never stored unencrypted; make the drop visible.
            logger.warning(
                "memory: upsert_relation dropping notes (no encryption key) user=%s", user_id
            )

    # The table has no unique constraint on the triple, so duplicates can
    # exist (e.g. two concurrent inserts). Take the first match instead of
    # scalar_one_or_none(), which would raise MultipleResultsFound.
    result = await self._db.execute(
        select(MemoryRelation)
        .where(
            MemoryRelation.user_id == user_id,
            MemoryRelation.subject_label == subject,
            MemoryRelation.predicate == predicate,
            MemoryRelation.object_label == object_,
        )
        .limit(1)
    )
    existing = result.scalars().first()

    if existing is not None:
        existing.subject_type = subject_type
        existing.object_type = object_type
        existing.confidence = confidence
        existing.last_confirmed_at = _now()
        if notes_encrypted is not None:
            existing.notes_encrypted = notes_encrypted
    else:
        self._db.add(MemoryRelation(
            id=str(uuid.uuid4()),
            user_id=user_id,
            subject_label=subject,
            subject_type=subject_type,
            predicate=predicate,
            object_label=object_,
            object_type=object_type,
            confidence=confidence,
            source_episode_id=source_episode_id,
            notes_encrypted=notes_encrypted,
        ))

    try:
        await self._db.commit()
        logger.info(
            "memory: upsert_relation user=%s subject=%s predicate=%s object=%s",
            user_id, subject, predicate, object_,
        )
    except Exception as exc:
        logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc)
        await self._db.rollback()
|
||||||
|
|
||||||
|
async def query_relations(
    self,
    user_id: str,
    subject: str | None = None,
    predicate: str | None = None,
    object_: str | None = None,
    limit: int = 20,
) -> list[MemoryRelation]:
    """Fetch a user's relation rows, optionally narrowed by exact-match filters.

    No tier check is performed here; the query is scoped by ``user_id`` only.

    Args:
        user_id: Owner of the rows; always applied as a filter.
        subject: When not None, keep only rows whose ``subject_label`` equals it.
        predicate: When not None, keep only rows whose ``predicate`` equals it.
        object_: When not None, keep only rows whose ``object_label`` equals it.
        limit: Maximum number of rows to return.

    Returns:
        Matching ``MemoryRelation`` rows ordered by descending confidence.
    """
    optional_filters = (
        (MemoryRelation.subject_label, subject),
        (MemoryRelation.predicate, predicate),
        (MemoryRelation.object_label, object_),
    )
    # The user scope is mandatory; every other condition is added only when
    # the caller supplied a value (None means "do not filter on this column").
    conditions = [MemoryRelation.user_id == user_id]
    conditions.extend(
        column == wanted for column, wanted in optional_filters if wanted is not None
    )

    stmt = (
        select(MemoryRelation)
        .where(*conditions)
        .order_by(MemoryRelation.confidence.desc())
        .limit(limit)
    )
    fetched = await self._db.execute(stmt)
    return list(fetched.scalars().all())
|
||||||
|
|
||||||
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||||
"""Insert a long-term archival memory entry."""
|
"""Insert a long-term archival memory entry."""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
@@ -463,13 +567,26 @@ class MemoryMiddleware:
|
|||||||
|
|
||||||
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
    """Load lightweight user debug fields for trace logs.

    The returned dict has a single ``"tier"`` key, resolved in this order:

    1. ``None`` when no ``User`` row exists for ``user_id``.
    2. The tier of the user's ``Subscription`` row, when one is found and truthy.
    3. ``"power"`` when ``settings.ENV == "dev"``.
    4. The user's own ``tier`` column, falling back to ``"free"`` when empty.
    """
    from app.config.settings import settings  # noqa: PLC0415
    from app.models import Subscription  # noqa: PLC0415

    user_lookup = await self._db.execute(select(User).where(User.id == user_id))
    user_row = user_lookup.scalar_one_or_none()
    if user_row is None:
        return {"tier": None}

    subscription_lookup = await self._db.execute(
        select(Subscription.tier).where(Subscription.user_id == user_id)
    )
    subscription_tier: str | None = subscription_lookup.scalar_one_or_none()

    # A subscription tier wins; otherwise dev environments get "power".
    if subscription_tier:
        return {"tier": subscription_tier}
    if settings.ENV == "dev":
        return {"tier": "power"}
    return {"tier": user_row.tier or "free"}
|
||||||
|
|
||||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
@@ -563,6 +680,26 @@ class MemoryMiddleware:
|
|||||||
out.append(plaintext)
|
out.append(plaintext)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]:
    """Render the user's highest-confidence relations as short text lines.

    Args:
        user_id: Owner of the relation rows.
        user_tier: Tier name passed to ``tier_manager.check_feature`` to decide
            whether the ``"relational_memory"`` feature is available.

    Returns:
        Up to 10 strings of the form ``"subject --predicate--> object"``,
        highest confidence first; an empty list when the tier lacks the
        ``"relational_memory"`` feature.
    """
    from app.billing.tier_manager import tier_manager  # noqa: PLC0415

    # Gate before touching the database so ungated tiers cost no query.
    if not tier_manager.check_feature(user_tier, "relational_memory"):
        return []

    top_relations_stmt = (
        select(MemoryRelation)
        .where(MemoryRelation.user_id == user_id)
        .order_by(MemoryRelation.confidence.desc())
        .limit(10)
    )
    fetched = await self._db.execute(top_relations_stmt)
    return [
        f"{relation.subject_label} --{relation.predicate}--> {relation.object_label}"
        for relation in fetched.scalars().all()
    ]
|
||||||
|
|
||||||
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryProactive)
|
select(MemoryProactive)
|
||||||
|
|||||||
@@ -50,13 +50,14 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import agents, auth, billing, chat, device_ws
|
from app.api.routes import agents, auth, billing, chat, device_ws, memory
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
app.include_router(agents.router, prefix="/api/v1")
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
app.include_router(device_ws.router, prefix="/api/v1")
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
|
app.include_router(memory.router, prefix="/api/v1")
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
async def health() -> dict:
|
async def health() -> dict:
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ Table inventory:
|
|||||||
memory_associative — per-user semantic memory with embeddings (encrypted)
|
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||||
memory_episodic — per-user session summaries (encrypted)
|
memory_episodic — per-user session summaries (encrypted)
|
||||||
memory_proactive — per-user behavioral patterns (encrypted)
|
memory_proactive — per-user behavioral patterns (encrypted)
|
||||||
|
memory_relations — per-user entity/relation graph (Mem0g-light, Phase 3)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -30,6 +31,7 @@ from sqlalchemy import (
|
|||||||
ForeignKey,
|
ForeignKey,
|
||||||
Integer,
|
Integer,
|
||||||
JSON,
|
JSON,
|
||||||
|
LargeBinary,
|
||||||
String,
|
String,
|
||||||
Text,
|
Text,
|
||||||
Uuid,
|
Uuid,
|
||||||
@@ -373,6 +375,44 @@ class ExtractionQueue(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRelation(Base):
    """Per-user entity/relation graph row (Mem0g-light, Phase 3).

    Each row is one ``subject --predicate--> object`` triple owned by a user.

    subject_label/object_label are plaintext entity identifiers (not user content).
    notes_encrypted is optional Fernet-encrypted per-user commentary.
    confidence in [0.0, 1.0] — decays 5 % per 30 days since last_confirmed_at.
    """

    __tablename__ = "memory_relations"

    id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
    # Owning user; rows are removed together with the user (ON DELETE CASCADE).
    user_id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # The triple itself: typed subject, predicate, typed object.
    subject_label: Mapped[str] = mapped_column(String(128), nullable=False)
    subject_type: Mapped[str] = mapped_column(String(32), nullable=False)
    predicate: Mapped[str] = mapped_column(String(64), nullable=False)
    object_label: Mapped[str] = mapped_column(String(128), nullable=False)
    object_type: Mapped[str] = mapped_column(String(32), nullable=False)

    # server_default mirrors the Alembic migration for this table, so tables
    # built from metadata and tables built by the migration have the same DDL;
    # the Python-side default still covers ordinary ORM inserts.
    confidence: Mapped[float] = mapped_column(
        Float, nullable=False, default=0.7, server_default="0.7"
    )

    # Episode the relation came from; cleared (SET NULL) if that episode is deleted.
    source_episode_id: Mapped[str | None] = mapped_column(
        Uuid(as_uuid=False),
        ForeignKey("memory_episodic.id", ondelete="SET NULL"),
        nullable=True,
    )
    notes_encrypted: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)

    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.now()
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
    )
    # Nullable: no default is defined on this column.
    last_confirmed_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
|
||||||
|
|
||||||
|
|
||||||
class Plugin(Base):
|
class Plugin(Base):
|
||||||
"""Plugin marketplace catalog entry."""
|
"""Plugin marketplace catalog entry."""
|
||||||
|
|
||||||
|
|||||||
220
tests/test_memory_relations.py
Normal file
220
tests/test_memory_relations.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""Tests for Phase 3 — relational tier (Mem0g-light).
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
1. upsert_relation inserts a row and query_relations returns it
|
||||||
|
2. upsert_relation updates existing row on duplicate (subject/predicate/object)
|
||||||
|
3. tier gating: Free user gets empty list from query_relations + enrich_context
|
||||||
|
4. enrich_context includes relational_memory key for Pro user
|
||||||
|
5. decay_relations decays confidence and prunes rows below threshold
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.core.memory_maintenance import decay_relations
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.models import MemoryRelation, User
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
# Ids of the "pro" and "free" entries in tests.conftest.TEST_USER_IDS.
PRO_USER_ID = TEST_USER_IDS["pro"]
FREE_USER_ID = TEST_USER_IDS["free"]
# One Fernet key generated at import time; the *_with_key fixtures below
# store it on the test users' encryption_key field.
_FERNET_KEY = Fernet.generate_key().decode()
|
||||||
|
|
||||||
|
|
||||||
|
# ── DB override ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Make the app's ``get_session`` dependency yield the test ``db_session``.

    The override is installed before each test and removed afterwards.
    """

    async def _yield_test_session():
        yield db_session

    app.dependency_overrides[get_session] = _yield_test_session
    yield
    # pop with a default so teardown is safe even if a test removed the override.
    app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def pro_user_with_key(db_session):
    """Store the module's Fernet key on the pro test user, commit, and return the user."""
    lookup = select(User).where(User.id == PRO_USER_ID)
    pro_user = (await db_session.execute(lookup)).scalar_one()
    pro_user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return pro_user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def free_user_with_key(db_session):
    """Store the module's Fernet key on the free test user, commit, and return the user."""
    lookup = select(User).where(User.id == FREE_USER_ID)
    free_user = (await db_session.execute(lookup)).scalar_one()
    free_user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return free_user
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_upsert_relation_inserts_and_queries(db_session, pro_user_with_key):
    """A relation written via upsert_relation is returned by query_relations."""
    middleware = MemoryMiddleware(db_session)
    relation_fields = {
        "subject": "Giulia",
        "subject_type": "person",
        "predicate": "works_at",
        "object_": "Acme Corp",
        "object_type": "company",
        "confidence": 0.9,
    }
    await middleware.upsert_relation(PRO_USER_ID, **relation_fields)

    found = await middleware.query_relations(PRO_USER_ID, subject="Giulia")

    assert len(found) == 1
    stored = found[0]
    assert stored.subject_label == "Giulia"
    assert stored.predicate == "works_at"
    assert stored.object_label == "Acme Corp"
    # Float column: compare with a tolerance rather than exact equality.
    assert abs(stored.confidence - 0.9) < 0.001
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_upsert_relation_updates_on_duplicate(db_session, pro_user_with_key):
    """Upserting the same triple twice keeps one row carrying the latest confidence."""
    middleware = MemoryMiddleware(db_session)
    triple = {
        "subject": "Marco",
        "subject_type": "person",
        "predicate": "stakeholder_of",
        "object_": "Project Nexus",
        "object_type": "project",
    }
    # Same subject/predicate/object both times; only the confidence changes.
    for confidence in (0.7, 0.95):
        await middleware.upsert_relation(PRO_USER_ID, confidence=confidence, **triple)

    found = await middleware.query_relations(PRO_USER_ID, subject="Marco")

    # Only one row despite two upserts
    assert len(found) == 1
    assert abs(found[0].confidence - 0.95) < 0.001
    assert found[0].last_confirmed_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_free_tier_relation_skipped(db_session, free_user_with_key):
    """For the free user, upsert_relation creates no row (query returns nothing)."""
    middleware = MemoryMiddleware(db_session)
    relation_fields = {
        "subject": "Alice",
        "subject_type": "person",
        "predicate": "reports_to",
        "object_": "Bob",
        "object_type": "person",
        "confidence": 0.8,
    }
    await middleware.upsert_relation(FREE_USER_ID, **relation_fields)

    found = await middleware.query_relations(FREE_USER_ID, subject="Alice")
    assert found == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_enrich_context_includes_relational_memory(db_session, pro_user_with_key):
    """For the pro user, enrich_context exposes stored relations under relational_memory."""
    middleware = MemoryMiddleware(db_session)
    relation_fields = {
        "subject": "Elena",
        "subject_type": "person",
        "predicate": "cfo_of",
        "object_": "StartupXYZ",
        "object_type": "company",
        "confidence": 0.85,
    }
    await middleware.upsert_relation(PRO_USER_ID, **relation_fields)

    # Stub the associative loader so only the relational path is exercised.
    associative_loader = "app.core.memory_middleware.MemoryMiddleware._load_associative"
    with patch(associative_loader, return_value=[]):
        context = await middleware.enrich_context(PRO_USER_ID, "who is Elena?")

    assert "relational_memory" in context
    assert any("Elena" in entry for entry in context["relational_memory"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_enrich_context_relational_empty_for_free(db_session, free_user_with_key):
    """For the free user, enrich_context reports relational_memory as an empty list."""
    middleware = MemoryMiddleware(db_session)

    # Stub the associative loader so only the relational path is exercised.
    associative_loader = "app.core.memory_middleware.MemoryMiddleware._load_associative"
    with patch(associative_loader, return_value=[]):
        context = await middleware.enrich_context(FREE_USER_ID, "test message")

    assert context.get("relational_memory") == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_decay_relations_reduces_confidence(db_session, pro_user_with_key):
    """A relation last confirmed 35 days ago ends up below its starting confidence."""
    last_confirmed = datetime.now(timezone.utc) - timedelta(days=35)
    stale_relation = MemoryRelation(
        id=str(uuid.uuid4()),
        user_id=PRO_USER_ID,
        subject_label="OldContact",
        subject_type="person",
        predicate="knows",
        object_label="SomeProject",
        object_type="project",
        confidence=0.8,
        last_confirmed_at=last_confirmed,
    )
    db_session.add(stale_relation)
    await db_session.commit()

    await decay_relations(db_session, PRO_USER_ID)

    lookup = select(MemoryRelation).where(MemoryRelation.subject_label == "OldContact")
    survivor = (await db_session.execute(lookup)).scalar_one_or_none()
    # The row must still exist, just with a lower confidence than the 0.8 it started at.
    assert survivor is not None
    assert survivor.confidence < 0.8
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_decay_relations_prunes_low_confidence(db_session, pro_user_with_key):
    """decay_relations deletes rows whose confidence drops below 0.2 threshold."""
    # Start at 0.21 with a 65-day-old last_confirmed_at → two full 30-day decay
    # periods → 0.21 * 0.95^2 ≈ 0.19, which is under the threshold → pruned.
    last_confirmed = datetime.now(timezone.utc) - timedelta(days=65)
    expiring_relation = MemoryRelation(
        id=str(uuid.uuid4()),
        user_id=PRO_USER_ID,
        subject_label="ExpiredContact",
        subject_type="person",
        predicate="used_to_work_with",
        object_label="OldCorp",
        object_type="company",
        confidence=0.21,
        last_confirmed_at=last_confirmed,
    )
    db_session.add(expiring_relation)
    await db_session.commit()

    await decay_relations(db_session, PRO_USER_ID)

    lookup = select(MemoryRelation).where(MemoryRelation.subject_label == "ExpiredContact")
    remaining = (await db_session.execute(lookup)).scalar_one_or_none()
    assert remaining is None
|
||||||
Reference in New Issue
Block a user