diff --git a/alembic/versions/006_memory_relations.py b/alembic/versions/006_memory_relations.py new file mode 100644 index 0000000..1d9ce84 --- /dev/null +++ b/alembic/versions/006_memory_relations.py @@ -0,0 +1,74 @@ +"""Add memory_relations table (Phase 3 — relational tier). + +Revision ID: 006 +Revises: 1f5975a4f3f4 +Create Date: 2026-04-16 +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision: str = "006" +down_revision: Union[str, None] = "1f5975a4f3f4" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "memory_relations", + sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=False), + sa.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("subject_label", sa.String(128), nullable=False), + sa.Column("subject_type", sa.String(32), nullable=False), + sa.Column("predicate", sa.String(64), nullable=False), + sa.Column("object_label", sa.String(128), nullable=False), + sa.Column("object_type", sa.String(32), nullable=False), + sa.Column("confidence", sa.Float, nullable=False, server_default="0.7"), + sa.Column( + "source_episode_id", + postgresql.UUID(as_uuid=False), + sa.ForeignKey("memory_episodic.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("notes_encrypted", sa.LargeBinary, nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column("last_confirmed_at", sa.DateTime(timezone=True), nullable=True), + ) + op.create_index( + "memory_relations_user_subject_idx", + "memory_relations", + ["user_id", "subject_label"], + ) + op.create_index( + "memory_relations_user_predicate_idx", + "memory_relations", + ["user_id", "predicate"], + ) + + +def downgrade() -> None: + op.drop_index("memory_relations_user_predicate_idx", "memory_relations") + op.drop_index("memory_relations_user_subject_idx", "memory_relations") + op.drop_table("memory_relations") diff --git a/app/api/routes/memory.py b/app/api/routes/memory.py new file mode 100644 index 0000000..ffc5cfe --- /dev/null +++ b/app/api/routes/memory.py @@ -0,0 +1,225 @@ +"""Memory management routes — view/edit/delete user memory tiers. + +All routes require authentication. Data is always user-scoped. +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import Annotated + +from fastapi import APIRouter, Depends, Header, HTTPException, status +from pydantic import BaseModel, Field +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_user +from app.core.memory_middleware import MemoryMiddleware +from app.db import get_session +from app.models import ( + ExtractionQueue, + MemoryAssociative, + MemoryCore, + MemoryEpisodic, + MemoryProactive, + MemoryRelation, +) +from app.schemas import UserProfile + +router = APIRouter(prefix="/memory", tags=["memory"]) + +logger = logging.getLogger(__name__) + +_ALLOWED_PREDICATES = { + "works_at", + "reports_to", + "stakeholder_of", + "last_contacted_on", + "owes_followup", + "manages", + "collaborates_with", + "owns", + "member_of", + "custom", +} + + +# ── Response schemas ───────────────────────────────────────────────────────── + +class RelationOut(BaseModel): + id: str + subject_label: str + subject_type: str + predicate: str + object_label: str + object_type: str + confidence: float + last_confirmed_at: int | None = None # epoch ms + + +class RelationPatch(BaseModel): + subject_label: str | None = None + object_label: str | None = None + predicate: str | None = None + confidence: float | None = Field(None, ge=0.0, le=1.0) + + +class CoreAddBody(BaseModel): + key: str = Field(..., min_length=1, max_length=255) + value: str = Field(..., min_length=1) + + +# ── Helpers ────────────────────────────────────────────────────────────────── + +def _relation_to_out(row: MemoryRelation) -> RelationOut: + last_ms: int | None = None + if row.last_confirmed_at is not None: + last_ms = int(row.last_confirmed_at.timestamp() * 1000) + return RelationOut( + id=row.id, + subject_label=row.subject_label, + subject_type=row.subject_type, + predicate=row.predicate, + object_label=row.object_label, + object_type=row.object_type, + confidence=row.confidence, + last_confirmed_at=last_ms, + ) + + +# ── Routes ─────────────────────────────────────────────────────────────────── + +@router.get("/core", response_model=dict[str, str]) +async def get_core_memory( + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> dict[str, str]: + """Return all core memory k/v pairs (plaintext) for the current user.""" + mw = MemoryMiddleware(db) + blocks = await mw.list_core_blocks(current_user.id) + return {b["label"]: b["value"] for b in blocks} + + +@router.delete("/core/{key}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_core_key( + key: str, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> None: + """Delete a single core memory key (GDPR Art. 17).""" + mw = MemoryMiddleware(db) + deleted = await mw.delete_core(current_user.id, key) + if not deleted: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Key not found") + + +@router.post("/core", status_code=status.HTTP_201_CREATED, response_model=dict[str, str]) +async def add_core_key( + body: CoreAddBody, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> dict[str, str]: + """Add or overwrite a core memory key/value pair.""" + mw = MemoryMiddleware(db) + await mw.update_core(current_user.id, body.key, body.value) + return {body.key: body.value} + + +@router.get("/relational", response_model=list[RelationOut]) +async def get_relational_memory( + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> list[RelationOut]: + """Return all relational memory rows for the current user.""" + mw = MemoryMiddleware(db) + rows = await mw.query_relations(current_user.id, limit=200) + return [_relation_to_out(r) for r in rows] + + +@router.patch("/relational/{relation_id}", response_model=RelationOut) +async def patch_relation( + relation_id: str, + body: RelationPatch, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> RelationOut: + """Edit a relation row's labels, predicate, or confidence.""" + if body.predicate is not None and body.predicate not in _ALLOWED_PREDICATES: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"predicate must be one of: {sorted(_ALLOWED_PREDICATES)}", + ) + + result = await db.execute( + select(MemoryRelation).where( + MemoryRelation.id == relation_id, + MemoryRelation.user_id == current_user.id, + ) + ) + row = result.scalar_one_or_none() + if row is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found") + + if body.subject_label is not None: + row.subject_label = body.subject_label + if body.object_label is not None: + row.object_label = body.object_label + if body.predicate is not None: + row.predicate = body.predicate + if body.confidence is not None: + row.confidence = body.confidence + row.last_confirmed_at = datetime.now(timezone.utc) + + await db.commit() + await db.refresh(row) + logger.info("memory: patch_relation user=%s relation=%s", current_user.id, relation_id) + return _relation_to_out(row) + + +@router.delete("/relational/{relation_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_relation( + relation_id: str, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> None: + """Hard-delete a relation row (GDPR Art. 17).""" + result = await db.execute( + select(MemoryRelation).where( + MemoryRelation.id == relation_id, + MemoryRelation.user_id == current_user.id, + ) + ) + row = result.scalar_one_or_none() + if row is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found") + await db.delete(row) + await db.commit() + logger.info("memory: delete_relation user=%s relation=%s", current_user.id, relation_id) + + +@router.post("/forget-all", status_code=status.HTTP_204_NO_CONTENT) +async def forget_all( + x_confirm: Annotated[str | None, Header(alias="X-Confirm")] = None, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> None: + """Wipe all memory tiers for the current user (GDPR Art. 17). + + Requires ``X-Confirm: true`` header. Does NOT delete the user account. + """ + if x_confirm != "true": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Missing or invalid X-Confirm header. Send X-Confirm: true to confirm.", + ) + + uid = current_user.id + await db.execute(delete(MemoryCore).where(MemoryCore.user_id == uid)) + await db.execute(delete(MemoryAssociative).where(MemoryAssociative.user_id == uid)) + await db.execute(delete(MemoryEpisodic).where(MemoryEpisodic.user_id == uid)) + await db.execute(delete(MemoryProactive).where(MemoryProactive.user_id == uid)) + await db.execute(delete(MemoryRelation).where(MemoryRelation.user_id == uid)) + await db.execute(delete(ExtractionQueue).where(ExtractionQueue.user_id == uid)) + await db.commit() + logger.warning("memory: forget_all GDPR wipe user=%s", uid) diff --git a/app/billing/tier_manager.py b/app/billing/tier_manager.py index 859d378..aae46e3 100644 --- a/app/billing/tier_manager.py +++ b/app/billing/tier_manager.py @@ -25,8 +25,9 @@ FEATURES: dict[str, dict[str, Any]] = { "providers": 1, "batch_builder": False, "sso": False, - "real_embeddings": False, # keyword fallback only - "realtime_extraction": False, # batch queue (Phase 2) + "real_embeddings": False, # keyword fallback only + "realtime_extraction": False, # batch queue (Phase 2) + "relational_memory": False, # relational tier (Phase 3) — Pro+ }, "pro": { "agents": -1, # unlimited @@ -35,8 +36,9 @@ FEATURES: dict[str, dict[str, Any]] = { "providers": -1, "batch_builder": False, "sso": False, - "real_embeddings": True, # pgvector cosine search - "realtime_extraction": True, # fire-and-forget asyncio.create_task + "real_embeddings": True, # pgvector cosine search + "realtime_extraction": True, # fire-and-forget asyncio.create_task + "relational_memory": True, # person/project predicates }, "power": { "agents": -1, @@ -47,6 +49,7 @@ FEATURES: dict[str, dict[str, Any]] = { "sso": False, "real_embeddings": True, "realtime_extraction": True, + "relational_memory": True, # all predicates incl. custom }, "team": { "agents": -1, @@ -57,6 +60,7 @@ FEATURES: dict[str, dict[str, Any]] = { "sso": True, "real_embeddings": True, "realtime_extraction": True, + "relational_memory": True, # all predicates incl. custom }, } diff --git a/app/core/deep_agent.py b/app/core/deep_agent.py index 602d418..44a99be 100644 --- a/app/core/deep_agent.py +++ b/app/core/deep_agent.py @@ -55,6 +55,22 @@ def _language_instruction(context: dict[str, Any]) -> str: f"All your output text must be written in {lang}." ) +def _relational_memory_injection(context: dict[str, Any]) -> str: + """Return a system-prompt paragraph listing known people/projects from relational memory. + + Returns empty string when no relational rows or tier is Free. + Capped at 800 chars to control token spend. + """ + relations: list[str] = context.get("relational_memory") or [] + if not relations: + return "" + body = "\n".join(f"- {r}" for r in relations) + section = f"\n\nKnown people & projects:\n{body}" + if len(section) > 800: + section = section[:797] + "..." + return section + + _HOME_SYSTEM_PROMPT = ( "You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. " "Always use tools for factual data retrieval before answering. " @@ -904,6 +920,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str: system_prompt, langfuse_prompt = get_prompt_or_fallback( "home_system", _HOME_SYSTEM_PROMPT ) + system_prompt += _relational_memory_injection(context) system_prompt += _language_instruction(context) response = await _run_single_agent( user_id=user_id, @@ -922,6 +939,7 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t system_prompt, langfuse_prompt = get_prompt_or_fallback( "floating_system", _FLOATING_SYSTEM_PROMPT ) + system_prompt += _relational_memory_injection(context) system_prompt += _language_instruction(context) response = await _run_single_agent( user_id=user_id, @@ -946,6 +964,7 @@ async def run_home_stream( system_prompt, langfuse_prompt = get_prompt_or_fallback( "home_system", _HOME_SYSTEM_PROMPT ) + system_prompt += _relational_memory_injection(context) system_prompt += _language_instruction(context) text_chunks: list[str] = [] async for event in _run_single_agent_stream( @@ -979,6 +998,7 @@ async def run_floating_stream( system_prompt, langfuse_prompt = get_prompt_or_fallback( "floating_system", _FLOATING_SYSTEM_PROMPT ) + system_prompt += _relational_memory_injection(context) system_prompt += _language_instruction(context) sanitizer = _FloatingStreamSanitizer() emitted_sanitized = False diff --git a/app/core/memory_extraction.py b/app/core/memory_extraction.py index 1345b04..0c3bb85 100644 --- a/app/core/memory_extraction.py +++ b/app/core/memory_extraction.py @@ -366,7 +366,7 @@ async def _apply_candidate( if candidate.target_tier == "relational": # Always upsert relations — decide_action skipped (no neighbour search). if candidate.subject and candidate.predicate and candidate.object: - await _upsert_relation_stub( + await _upsert_relation( middleware, db, user_id, candidate, trace_id ) return @@ -396,35 +396,29 @@ def _content_to_key(content: str) -> str: return slug or "memory" -async def _upsert_relation_stub( +async def _upsert_relation( middleware: Any, db: AsyncSession, user_id: str, candidate: MemoryCandidate, trace_id: str | None, ) -> None: - """Stub: upsert_relation will be fully wired in Phase 3. - - Called here so Phase 2 extraction pipeline already routes relation candidates - correctly. Phase 3 replaces this with MemoryMiddleware.upsert_relation(). - """ - if hasattr(middleware, "upsert_relation"): - await middleware.upsert_relation( - user_id=user_id, - subject=candidate.subject, - subject_type="unknown", - predicate=candidate.predicate, - object_=candidate.object, - object_type="unknown", - confidence=candidate.confidence, - ) - else: - logger.info( - "memory_extraction: relation stub (Phase 3 not yet wired) subject=%s predicate=%s object=%s", - candidate.subject, - candidate.predicate, - candidate.object, - ) + """Upsert a relation row via MemoryMiddleware.upsert_relation (Phase 3).""" + await middleware.upsert_relation( + user_id=user_id, + subject=candidate.subject or "unknown", + subject_type="unknown", + predicate=candidate.predicate or "related_to", + object_=candidate.object or "unknown", + object_type="unknown", + confidence=candidate.confidence, + ) + logger.info( + "memory_extraction: upserted relation subject=%s predicate=%s object=%s", + candidate.subject, + candidate.predicate, + candidate.object, + ) async def _store_proactive_stub( diff --git a/app/core/memory_maintenance.py b/app/core/memory_maintenance.py new file mode 100644 index 0000000..c9a8ceb --- /dev/null +++ b/app/core/memory_maintenance.py @@ -0,0 +1,102 @@ +"""Memory maintenance jobs — Phase 3/5. + +Two entrypoints called by the scheduler (APScheduler) registered in app/main.py: + + drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5). + decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3). + +Both are safe to call manually or from tests; they never raise. +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models import MemoryRelation + +logger = logging.getLogger(__name__) + +# Decay parameters +_DECAY_FACTOR = 0.95 # multiply confidence by this every _DECAY_PERIOD days +_DECAY_PERIOD_DAYS = 30 # period for one decay step +_PRUNE_THRESHOLD = 0.2 # rows below this confidence are deleted + + +async def decay_relations(db: AsyncSession, user_id: str) -> None: + """Apply confidence decay to all relation rows for a user. + + Decay rule: confidence *= 0.95 for every 30 days since last_confirmed_at. + Rows whose confidence falls below 0.2 are deleted. + + Never raises — wraps in try/except. + """ + try: + await _decay_relations_inner(db, user_id) + except Exception as exc: + logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, exc) + + +async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None: + result = await db.execute( + select(MemoryRelation).where(MemoryRelation.user_id == user_id) + ) + rows = result.scalars().all() + now = datetime.now(timezone.utc) + deleted = 0 + decayed = 0 + + for row in rows: + reference = row.last_confirmed_at or row.created_at + if reference is None: + continue + # Ensure timezone-aware comparison + if reference.tzinfo is None: + reference = reference.replace(tzinfo=timezone.utc) + + days_elapsed = (now - reference).days + if days_elapsed < _DECAY_PERIOD_DAYS: + continue + + periods = days_elapsed // _DECAY_PERIOD_DAYS + new_confidence = row.confidence * (_DECAY_FACTOR ** periods) + + if new_confidence < _PRUNE_THRESHOLD: + await db.delete(row) + deleted += 1 + logger.info( + "memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s " + "confidence=%.3f (below threshold)", + row.id, user_id, row.subject_label, row.predicate, new_confidence, + ) + else: + row.confidence = new_confidence + decayed += 1 + + try: + await db.commit() + logger.info( + "memory_maintenance: decay_relations user=%s decayed=%d deleted=%d", + user_id, decayed, deleted, + ) + except Exception as exc: + logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, exc) + await db.rollback() + + +async def drain_extraction_queue(db: AsyncSession) -> None: + """Process pending ExtractionQueue rows for Free-tier users (Phase 5 stub). + + Full implementation wired in Phase 5 when APScheduler is registered. + Currently logs count and returns. + """ + try: + from app.models import ExtractionQueue # noqa: PLC0415 + result = await db.execute(select(ExtractionQueue)) + rows = result.scalars().all() + logger.info("memory_maintenance: drain_extraction_queue pending=%d (Phase 5 cron)", len(rows)) + except Exception as exc: + logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc) diff --git a/app/core/memory_middleware.py b/app/core/memory_middleware.py index 9780faa..02806c3 100644 --- a/app/core/memory_middleware.py +++ b/app/core/memory_middleware.py @@ -21,6 +21,7 @@ from __future__ import annotations import asyncio import logging import uuid +from datetime import datetime, timezone from typing import Any from cryptography.fernet import Fernet, InvalidToken @@ -33,11 +34,17 @@ from app.models import ( MemoryCore, MemoryEpisodic, MemoryProactive, + MemoryRelation, User, ) logger = logging.getLogger(__name__) + +def _now() -> datetime: + return datetime.now(timezone.utc) + + # Tuning constants _ASSOCIATIVE_TOP_K = 5 _EPISODIC_RECENT_N = 10 @@ -66,6 +73,7 @@ class MemoryMiddleware: associative_memory — [plaintext_content, ...] (top-k by keyword match) episodic_memory — [plaintext_summary, ...] (most recent N) proactive_hints — [plaintext_pattern, ...] (above threshold) + relational_memory — ["subject --predicate--> object", ...] (top 10, Pro+) """ fernet = await self._get_fernet(user_id) if fernet is None: @@ -78,9 +86,10 @@ class MemoryMiddleware: associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier) episodic = await self._load_episodic(user_id, fernet, session_id=session_id) proactive = await self._load_proactive(user_id, fernet) + relational = await self._load_relational(user_id, user_tier=user_tier) logger.info( - "memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d", + "memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d", trace_id or "-", user_id, user_tier, @@ -88,6 +97,7 @@ class MemoryMiddleware: len(associative), len(episodic), len(proactive), + len(relational), ) return { @@ -95,6 +105,7 @@ class MemoryMiddleware: "associative_memory": associative, "episodic_memory": episodic, "proactive_hints": proactive, + "relational_memory": relational, } async def store_episode( @@ -375,6 +386,99 @@ class MemoryMiddleware: logger.error("memory: store_associative failed user=%s: %s", user_id, exc) await self._db.rollback() + async def upsert_relation( + self, + user_id: str, + subject: str, + subject_type: str, + predicate: str, + object_: str, + object_type: str, + *, + confidence: float = 0.7, + source_episode_id: str | None = None, + notes: str | None = None, + ) -> None: + """Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label). + + subject_label / object_label are plaintext entity identifiers — not encrypted. + notes is optional; encrypted with user Fernet if provided. + """ + from app.billing.tier_manager import tier_manager # noqa: PLC0415 + + user_dbg = await self._get_user_debug(user_id) + user_tier = user_dbg.get("tier") or "free" + if not tier_manager.check_feature(user_tier, "relational_memory"): + logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier) + return + + notes_encrypted: bytes | None = None + if notes: + fernet = await self._get_fernet(user_id) + if fernet: + notes_encrypted = fernet.encrypt(notes.encode()) + + result = await self._db.execute( + select(MemoryRelation).where( + MemoryRelation.user_id == user_id, + MemoryRelation.subject_label == subject, + MemoryRelation.predicate == predicate, + MemoryRelation.object_label == object_, + ) + ) + existing = result.scalar_one_or_none() + + if existing is not None: + existing.subject_type = subject_type + existing.object_type = object_type + existing.confidence = confidence + existing.last_confirmed_at = _now() + if notes_encrypted is not None: + existing.notes_encrypted = notes_encrypted + else: + self._db.add(MemoryRelation( + id=str(uuid.uuid4()), + user_id=user_id, + subject_label=subject, + subject_type=subject_type, + predicate=predicate, + object_label=object_, + object_type=object_type, + confidence=confidence, + source_episode_id=source_episode_id, + notes_encrypted=notes_encrypted, + )) + + try: + await self._db.commit() + logger.info( + "memory: upsert_relation user=%s subject=%s predicate=%s object=%s", + user_id, subject, predicate, object_, + ) + except Exception as exc: + logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc) + await self._db.rollback() + + async def query_relations( + self, + user_id: str, + subject: str | None = None, + predicate: str | None = None, + object_: str | None = None, + limit: int = 20, + ) -> list[MemoryRelation]: + """Query relation rows for a user with optional filters.""" + q = select(MemoryRelation).where(MemoryRelation.user_id == user_id) + if subject is not None: + q = q.where(MemoryRelation.subject_label == subject) + if predicate is not None: + q = q.where(MemoryRelation.predicate == predicate) + if object_ is not None: + q = q.where(MemoryRelation.object_label == object_) + q = q.order_by(MemoryRelation.confidence.desc()).limit(limit) + result = await self._db.execute(q) + return list(result.scalars().all()) + async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None: """Insert a long-term archival memory entry.""" fernet = await self._get_fernet(user_id) @@ -463,13 +567,26 @@ class MemoryMiddleware: async def _get_user_debug(self, user_id: str) -> dict[str, str | None]: """Load lightweight user debug fields for trace logs.""" + from app.config.settings import settings # noqa: PLC0415 + from app.models import Subscription # noqa: PLC0415 + result = await self._db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if user is None: return {"tier": None} - return { - "tier": user.tier, - } + + sub_result = await self._db.execute( + select(Subscription.tier).where(Subscription.user_id == user_id) + ) + sub_tier: str | None = sub_result.scalar_one_or_none() + if sub_tier: + tier = sub_tier + elif settings.ENV == "dev": + tier = "power" + else: + tier = user.tier or "free" + + return {"tier": tier} async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]: result = await self._db.execute( @@ -563,6 +680,26 @@ class MemoryMiddleware: out.append(plaintext) return out + async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]: + """Return top-10 relation strings for Pro+ users; empty list for Free.""" + from app.billing.tier_manager import tier_manager # noqa: PLC0415 + + if not tier_manager.check_feature(user_tier, "relational_memory"): + return [] + + result = await self._db.execute( + select(MemoryRelation) + .where(MemoryRelation.user_id == user_id) + .order_by(MemoryRelation.confidence.desc()) + .limit(10) + ) + rows = result.scalars().all() + out = [ + f"{r.subject_label} --{r.predicate}--> {r.object_label}" + for r in rows + ] + return out + async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]: result = await self._db.execute( select(MemoryProactive) diff --git a/app/main.py b/app/main.py index 68fab9a..c22a1a8 100644 --- a/app/main.py +++ b/app/main.py @@ -50,13 +50,14 @@ def create_app() -> FastAPI: app.add_middleware(SanitizerMiddleware) app.add_middleware(TierRateLimitMiddleware) - from app.api.routes import agents, auth, billing, chat, device_ws + from app.api.routes import agents, auth, billing, chat, device_ws, memory app.include_router(auth.router, prefix="/api/v1") app.include_router(chat.router, prefix="/api/v1") app.include_router(billing.router, prefix="/api/v1") app.include_router(agents.router, prefix="/api/v1") app.include_router(device_ws.router, prefix="/api/v1") + app.include_router(memory.router, prefix="/api/v1") @app.get("/api/v1/health", tags=["health"]) async def health() -> dict: diff --git a/app/models.py b/app/models.py index d5f6f77..b00cec9 100644 --- a/app/models.py +++ b/app/models.py @@ -14,6 +14,7 @@ Table inventory: memory_associative — per-user semantic memory with embeddings (encrypted) memory_episodic — per-user session summaries (encrypted) memory_proactive — per-user behavioral patterns (encrypted) + memory_relations — per-user entity/relation graph (Mem0g-light, Phase 3) """ from __future__ import annotations @@ -30,6 +31,7 @@ from sqlalchemy import ( ForeignKey, Integer, JSON, + LargeBinary, String, Text, Uuid, @@ -373,6 +375,44 @@ class ExtractionQueue(Base): ) +class MemoryRelation(Base): + """Per-user entity/relation graph row (Mem0g-light, Phase 3). + + subject_label/object_label are plaintext entity identifiers (not user content). + notes_encrypted is optional Fernet-encrypted per-user commentary. + confidence in [0.0, 1.0] — decays 5 % per 30 days since last_confirmed_at. + """ + + __tablename__ = "memory_relations" + + id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid) + user_id: Mapped[str] = mapped_column( + Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, index=True, + ) + subject_label: Mapped[str] = mapped_column(String(128), nullable=False) + subject_type: Mapped[str] = mapped_column(String(32), nullable=False) + predicate: Mapped[str] = mapped_column(String(64), nullable=False) + object_label: Mapped[str] = mapped_column(String(128), nullable=False) + object_type: Mapped[str] = mapped_column(String(32), nullable=False) + confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.7) + source_episode_id: Mapped[str | None] = mapped_column( + Uuid(as_uuid=False), + ForeignKey("memory_episodic.id", ondelete="SET NULL"), + nullable=True, + ) + notes_encrypted: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + last_confirmed_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + + class Plugin(Base): """Plugin marketplace catalog entry.""" diff --git a/tests/test_memory_relations.py b/tests/test_memory_relations.py new file mode 100644 index 0000000..da0ec23 --- /dev/null +++ b/tests/test_memory_relations.py @@ -0,0 +1,220 @@ +"""Tests for Phase 3 — relational tier (Mem0g-light). + +Coverage: + 1. upsert_relation inserts a row and query_relations returns it + 2. upsert_relation updates existing row on duplicate (subject/predicate/object) + 3. tier gating: Free user gets empty list from query_relations + enrich_context + 4. enrich_context includes relational_memory key for Pro user + 5. decay_relations decays confidence and prunes rows below threshold +""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timedelta, timezone +from unittest.mock import patch + +import pytest +import pytest_asyncio +from cryptography.fernet import Fernet +from sqlalchemy import select + +from app.core.memory_maintenance import decay_relations +from app.core.memory_middleware import MemoryMiddleware +from app.db import get_session +from app.main import app +from app.models import MemoryRelation, User +from tests.conftest import TEST_USER_IDS + +PRO_USER_ID = TEST_USER_IDS["pro"] +FREE_USER_ID = TEST_USER_IDS["free"] +_FERNET_KEY = Fernet.generate_key().decode() + + +# ── DB override ─────────────────────────────────────────────────────────────── + +@pytest.fixture(autouse=True) +def _override_db(db_session): + async def _gen(): + yield db_session + + app.dependency_overrides[get_session] = _gen + yield + app.dependency_overrides.pop(get_session, None) + + +@pytest_asyncio.fixture +async def pro_user_with_key(db_session): + """Set encryption_key on the pro test user so Fernet works.""" + result = await db_session.execute(select(User).where(User.id == PRO_USER_ID)) + user = result.scalar_one() + user.encryption_key = _FERNET_KEY + await db_session.commit() + return user + + +@pytest_asyncio.fixture +async def free_user_with_key(db_session): + """Set encryption_key on the free test user.""" + result = await db_session.execute(select(User).where(User.id == FREE_USER_ID)) + user = result.scalar_one() + user.encryption_key = _FERNET_KEY + await db_session.commit() + return user + + +# ── Tests ───────────────────────────────────────────────────────────────────── + +@pytest.mark.asyncio +async def test_upsert_relation_inserts_and_queries(db_session, pro_user_with_key): + """upsert_relation inserts a row; query_relations returns it.""" + mm = MemoryMiddleware(db_session) + await mm.upsert_relation( + PRO_USER_ID, + subject="Giulia", + subject_type="person", + predicate="works_at", + object_="Acme Corp", + object_type="company", + confidence=0.9, + ) + rows = await mm.query_relations(PRO_USER_ID, subject="Giulia") + assert len(rows) == 1 + assert rows[0].subject_label == "Giulia" + assert rows[0].predicate == "works_at" + assert rows[0].object_label == "Acme Corp" + assert abs(rows[0].confidence - 0.9) < 0.001 + + +@pytest.mark.asyncio +async def test_upsert_relation_updates_on_duplicate(db_session, pro_user_with_key): + """Second upsert on same triple updates confidence and last_confirmed_at.""" + mm = MemoryMiddleware(db_session) + await mm.upsert_relation( + PRO_USER_ID, + subject="Marco", + subject_type="person", + predicate="stakeholder_of", + object_="Project Nexus", + object_type="project", + confidence=0.7, + ) + await mm.upsert_relation( + PRO_USER_ID, + subject="Marco", + subject_type="person", + predicate="stakeholder_of", + object_="Project Nexus", + object_type="project", + confidence=0.95, + ) + rows = await mm.query_relations(PRO_USER_ID, subject="Marco") + # Only one row despite two upserts + assert len(rows) == 1 + assert abs(rows[0].confidence - 0.95) < 0.001 + assert rows[0].last_confirmed_at is not None + + +@pytest.mark.asyncio +async def test_free_tier_relation_skipped(db_session, free_user_with_key): + """Free user: upsert_relation is silently skipped (no row created).""" + mm = MemoryMiddleware(db_session) + await mm.upsert_relation( + FREE_USER_ID, + subject="Alice", + subject_type="person", + predicate="reports_to", + object_="Bob", + object_type="person", + confidence=0.8, + ) + rows = await mm.query_relations(FREE_USER_ID, subject="Alice") + assert rows == [] + + +@pytest.mark.asyncio +async def test_enrich_context_includes_relational_memory(db_session, pro_user_with_key): + """enrich_context includes relational_memory key for Pro user.""" + mm = MemoryMiddleware(db_session) + await mm.upsert_relation( + PRO_USER_ID, + subject="Elena", + subject_type="person", + predicate="cfo_of", + object_="StartupXYZ", + object_type="company", + confidence=0.85, + ) + + with patch("app.core.memory_middleware.MemoryMiddleware._load_associative", return_value=[]): + ctx = await mm.enrich_context(PRO_USER_ID, "who is Elena?") + + assert "relational_memory" in ctx + assert any("Elena" in r for r in ctx["relational_memory"]) + + +@pytest.mark.asyncio +async def test_enrich_context_relational_empty_for_free(db_session, free_user_with_key): + """Free user: relational_memory is empty list in enrich_context.""" + mm = MemoryMiddleware(db_session) + + with patch("app.core.memory_middleware.MemoryMiddleware._load_associative", return_value=[]): + ctx = await mm.enrich_context(FREE_USER_ID, "test message") + + assert ctx.get("relational_memory") == [] + + +@pytest.mark.asyncio +async def test_decay_relations_reduces_confidence(db_session, pro_user_with_key): + """decay_relations reduces confidence on stale rows.""" + old_date = datetime.now(timezone.utc) - timedelta(days=35) + row = MemoryRelation( + id=str(uuid.uuid4()), + user_id=PRO_USER_ID, + subject_label="OldContact", + subject_type="person", + predicate="knows", + object_label="SomeProject", + object_type="project", + confidence=0.8, + last_confirmed_at=old_date, + ) + db_session.add(row) + await db_session.commit() + + await decay_relations(db_session, PRO_USER_ID) + + result = await db_session.execute( + select(MemoryRelation).where(MemoryRelation.subject_label == "OldContact") + ) + updated = result.scalar_one_or_none() + assert updated is not None + assert updated.confidence < 0.8 + + +@pytest.mark.asyncio +async def test_decay_relations_prunes_low_confidence(db_session, pro_user_with_key): + """decay_relations deletes rows whose confidence drops below 0.2 threshold.""" + # Start at 0.21 with 60-day-old last_confirmed_at → two decay periods → 0.21 * 0.95^2 ≈ 0.19 → pruned + old_date = datetime.now(timezone.utc) - timedelta(days=65) + row = MemoryRelation( + id=str(uuid.uuid4()), + user_id=PRO_USER_ID, + subject_label="ExpiredContact", + subject_type="person", + predicate="used_to_work_with", + object_label="OldCorp", + object_type="company", + confidence=0.21, + last_confirmed_at=old_date, + ) + db_session.add(row) + await db_session.commit() + + await decay_relations(db_session, PRO_USER_ID) + + result = await db_session.execute( + select(MemoryRelation).where(MemoryRelation.subject_label == "ExpiredContact") + ) + pruned = result.scalar_one_or_none() + assert pruned is None