step-6: add memory models and migration (models.py, alembic)
- User.encryption_key: per-user Fernet key generated on registration
- MemoryCore: encrypted key/value preferences
- MemoryAssociative: encrypted semantic memory + pgvector(1536) embedding
- MemoryEpisodic: encrypted session summaries
- MemoryProactive: encrypted behavioral patterns with confidence score
- Migration 004: enables pgvector extension, creates all 4 tables + ivfflat index
- auth.py register: generates Fernet key for new users
- 8 unit tests pass (SQLite in-memory, JSON embedding fallback)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -285,7 +285,7 @@ pytest tests/test_memory_models.py
|
||||
```
|
||||
|
||||
**Status**:
|
||||
- [ ] Step 6 complete
|
||||
- [x] Step 6 complete
|
||||
|
||||
**Commit**: After tests pass, commit with:
|
||||
```
|
||||
|
||||
144
alembic/versions/004_add_memory_tables.py
Normal file
144
alembic/versions/004_add_memory_tables.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Add memory tables and user encryption_key column.
|
||||
|
||||
Memory tables:
|
||||
memory_core — per-user key/value preferences (encrypted)
|
||||
memory_associative — semantic memory with pgvector embedding (encrypted)
|
||||
memory_episodic — session summaries (encrypted)
|
||||
memory_proactive — behavioral patterns (encrypted)
|
||||
|
||||
Also adds encryption_key column to users table.
|
||||
|
||||
Revision ID: 004
|
||||
Revises: 003
|
||||
Create Date: 2026-03-08
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "004"
|
||||
down_revision: Union[str, None] = "003"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _user_id_column() -> sa.Column:
    """Return a new non-null ``user_id`` column referencing ``users.id``.

    A fresh Column is built on every call because a Column instance can be
    attached to only one table. The column deliberately carries no
    ``index=True``: each table's ``ix_<table>_user_id`` index is created
    explicitly with ``op.create_index``, and declaring it in both places
    would try to create the same index twice.
    """
    return sa.Column(
        "user_id",
        sa.String(36),
        sa.ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
    )


def _timestamp_column(name: str) -> sa.Column:
    """Return a new non-null, timezone-aware timestamp column defaulting to ``now()``."""
    return sa.Column(
        name,
        sa.DateTime(timezone=True),
        nullable=False,
        server_default=sa.func.now(),
    )


def upgrade() -> None:
    """Apply revision 004.

    Enables the pgvector extension, adds ``users.encryption_key`` and creates
    the four memory tables together with their indexes.
    """
    # ── Enable pgvector extension (idempotent) ────────────────────────────────
    op.execute("CREATE EXTENSION IF NOT EXISTS vector;")

    # ── Add encryption_key to users ───────────────────────────────────────────
    # Nullable, so existing user rows need no backfill.
    op.add_column(
        "users",
        sa.Column("encryption_key", sa.String(64), nullable=True),
    )

    # ── memory_core ───────────────────────────────────────────────────────────
    op.create_table(
        "memory_core",
        sa.Column("id", sa.String(36), primary_key=True),
        _user_id_column(),
        sa.Column("key", sa.String(255), nullable=False),
        sa.Column("value_encrypted", sa.Text, nullable=False),
        _timestamp_column("updated_at"),
    )
    op.create_index("ix_memory_core_user_id", "memory_core", ["user_id"])

    # ── memory_associative ────────────────────────────────────────────────────
    # The embedding column uses pgvector's vector(1536) type.
    op.create_table(
        "memory_associative",
        sa.Column("id", sa.String(36), primary_key=True),
        _user_id_column(),
        sa.Column("content_encrypted", sa.Text, nullable=False),
        sa.Column("entity_type", sa.String(100), nullable=True),
        sa.Column("entity_id", sa.String(255), nullable=True),
        _timestamp_column("updated_at"),
    )
    # Add the pgvector column separately (not supported by generic sa types)
    op.execute(
        "ALTER TABLE memory_associative ADD COLUMN embedding vector(1536);"
    )
    op.create_index("ix_memory_associative_user_id", "memory_associative", ["user_id"])
    # IVFFlat index for approximate nearest-neighbour search.
    # NOTE(review): the index is built while the table is still empty; the
    # pgvector documentation recommends building ivfflat indexes once data is
    # present, so a later REINDEX may be needed — confirm.
    op.execute(
        "CREATE INDEX ix_memory_associative_embedding "
        "ON memory_associative USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);"
    )

    # ── memory_episodic ───────────────────────────────────────────────────────
    op.create_table(
        "memory_episodic",
        sa.Column("id", sa.String(36), primary_key=True),
        _user_id_column(),
        sa.Column("summary_encrypted", sa.Text, nullable=False),
        sa.Column("session_id", sa.String(255), nullable=False),
        _timestamp_column("created_at"),
    )
    op.create_index("ix_memory_episodic_user_id", "memory_episodic", ["user_id"])
    op.create_index("ix_memory_episodic_session_id", "memory_episodic", ["session_id"])

    # ── memory_proactive ──────────────────────────────────────────────────────
    op.create_table(
        "memory_proactive",
        sa.Column("id", sa.String(36), primary_key=True),
        _user_id_column(),
        sa.Column("pattern_encrypted", sa.Text, nullable=False),
        sa.Column("confidence", sa.Float, nullable=False, server_default="0.5"),
        sa.Column("source", sa.String(50), nullable=False, server_default="inferred"),
        _timestamp_column("created_at"),
    )
    op.create_index("ix_memory_proactive_user_id", "memory_proactive", ["user_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert revision 004.

    Drops the four memory tables (and the ivfflat embedding index) and then
    removes ``users.encryption_key``. The pgvector extension is left in place.
    """
    for table_name in ("memory_proactive", "memory_episodic"):
        op.drop_table(table_name)

    # The embedding index is dropped explicitly before its table.
    op.drop_index("ix_memory_associative_embedding", "memory_associative")

    for table_name in ("memory_associative", "memory_core"):
        op.drop_table(table_name)

    op.drop_column("users", "encryption_key")
|
||||
@@ -13,6 +13,7 @@ import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import bcrypt
|
||||
from cryptography.fernet import Fernet
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel
|
||||
@@ -94,6 +95,7 @@ async def register(
|
||||
email=body.email,
|
||||
password_hash=_hash_password(body.password),
|
||||
tier="free",
|
||||
encryption_key=Fernet.generate_key().decode(),
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush() # get user.id without committing
|
||||
|
||||
@@ -14,6 +14,10 @@ Table inventory:
|
||||
plugin_installations — per-user install records
|
||||
plugin_reviews — admin review decisions
|
||||
revenue_events — Stripe Connect 70/30 split ledger
|
||||
memory_core — per-user persistent key/value preferences (encrypted)
|
||||
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||
memory_episodic — per-user session summaries (encrypted)
|
||||
memory_proactive — per-user behavioral patterns (encrypted)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -74,6 +78,9 @@ class User(Base):
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
||||
# Used to encrypt/decrypt all memory rows for this user.
|
||||
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
@@ -375,3 +382,93 @@ class AgentRunLog(Base):
|
||||
foreign_keys="AgentRunLog.agent_id",
|
||||
overlaps="run_logs,local_agent",
|
||||
)
|
||||
|
||||
|
||||
# ── Memory models ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class MemoryCore(Base):
    """Encrypted per-user key/value preference store.

    Typical keys: preferred_language, timezone, work_style.
    Values are held as ciphertext and are meant to be decrypted in memory
    only, using ``User.encryption_key``.
    """

    __tablename__ = "memory_core"

    # Primary key, generated by the module's ``_uuid`` default.
    id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        primary_key=True,
        default=_uuid,
    )
    # Owning user; rows are removed when the user is deleted (ON DELETE CASCADE).
    user_id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Preference name (stored in clear text).
    key: Mapped[str] = mapped_column(String(255), nullable=False)
    # Preference value, stored encrypted.
    value_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
    # Set by the database on insert and refreshed on every ORM update.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
        onupdate=func.now(),
    )
|
||||
|
||||
|
||||
class MemoryAssociative(Base):
    """Encrypted per-user semantic memory with an embedding for similarity search.

    In production the ``embedding`` column is ``vector(1536)`` (pgvector);
    in the SQLite test database it is stored as a JSON list.
    """

    __tablename__ = "memory_associative"

    # Primary key, generated by the module's ``_uuid`` default.
    id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        primary_key=True,
        default=_uuid,
    )
    # Owning user; rows are removed when the user is deleted (ON DELETE CASCADE).
    user_id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Memory text, stored encrypted.
    content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
    # JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
    # NOTE(review): the ORM maps this column as JSON while the migration creates
    # it as vector(1536) — confirm reads/writes against Postgres work as mapped.
    embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
    # Optional classification of what the memory refers to.
    entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
    entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
    # Set by the database on insert and refreshed on every ORM update.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
        onupdate=func.now(),
    )
|
||||
|
||||
|
||||
class MemoryEpisodic(Base):
    """Encrypted per-user session summaries.

    Each row records one session interaction and is used to recall recent
    conversations.
    """

    __tablename__ = "memory_episodic"

    # Primary key, generated by the module's ``_uuid`` default.
    id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        primary_key=True,
        default=_uuid,
    )
    # Owning user; rows are removed when the user is deleted (ON DELETE CASCADE).
    user_id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Session summary, stored encrypted.
    summary_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
    # Indexed so summaries can be looked up by session.
    session_id: Mapped[str] = mapped_column(
        String(255),
        nullable=False,
        index=True,
    )
    # Set by the database when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
|
||||
|
||||
|
||||
class MemoryProactive(Base):
    """Per-user inferred behavioral patterns, encrypted at rest.

    Confidence in [0.0, 1.0]; only patterns above threshold are injected.
    Source: 'inferred' (from episodes) or 'explicit' (user-stated).
    """

    __tablename__ = "memory_proactive"

    # Primary key, generated by the module's ``_uuid`` default.
    id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        primary_key=True,
        default=_uuid,
    )
    # Owning user; rows are removed when the user is deleted (ON DELETE CASCADE).
    user_id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Pattern description, stored encrypted.
    pattern_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
    # The server defaults below mirror the ones declared in the migration, so
    # schemas created from this metadata behave the same for rows inserted
    # without going through the ORM.
    confidence: Mapped[float] = mapped_column(
        Float,
        nullable=False,
        default=0.5,
        server_default="0.5",
    )
    source: Mapped[str] = mapped_column(
        String(50),
        nullable=False,
        default="inferred",
        server_default="inferred",
    )
    # Set by the database when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
|
||||
|
||||
205
tests/test_memory_models.py
Normal file
205
tests/test_memory_models.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Tests for Step 6 — memory ORM models and User.encryption_key.
|
||||
|
||||
Uses the SQLite in-memory test DB (from conftest). The pgvector embedding
|
||||
column is stored as JSON in tests (SQLite-compatible).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models import MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive, User
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
|
||||
USER_ID = TEST_USER_IDS["power"]
|
||||
|
||||
|
||||
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _fernet_key() -> str:
    """Generate a new Fernet key and return it as text."""
    raw_key = Fernet.generate_key()
    return raw_key.decode()
|
||||
|
||||
|
||||
def _encrypt(key: str, plaintext: str) -> str:
    """Encrypt ``plaintext`` with the Fernet ``key`` and return the token as text."""
    cipher = Fernet(key.encode())
    token = cipher.encrypt(plaintext.encode())
    return token.decode()
|
||||
|
||||
|
||||
def _decrypt(key: str, ciphertext: str) -> str:
    """Decrypt the Fernet token ``ciphertext`` with ``key`` and return the text."""
    cipher = Fernet(key.encode())
    recovered = cipher.decrypt(ciphertext.encode())
    return recovered.decode()
|
||||
|
||||
|
||||
# ── User.encryption_key ───────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
async def test_user_encryption_key_column_exists(db_session):
    """The seeded user row exposes an ``encryption_key`` attribute."""
    lookup = select(User).where(User.id == USER_ID)
    seeded_user = (await db_session.execute(lookup)).scalar_one()
    # Seeded users may have no key yet, so only the attribute's presence is checked.
    assert hasattr(seeded_user, "encryption_key")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_user_encryption_key_can_be_set(db_session):
    """A key assigned to a user is returned by a later query after commit."""
    new_key = _fernet_key()
    lookup = select(User).where(User.id == USER_ID)

    account = (await db_session.execute(lookup)).scalar_one()
    account.encryption_key = new_key
    await db_session.commit()

    reloaded = (await db_session.execute(lookup)).scalar_one()
    assert reloaded.encryption_key == new_key
|
||||
|
||||
|
||||
# ── MemoryCore ────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
async def test_memory_core_create_and_read(db_session):
    """A MemoryCore row round-trips and its value decrypts to the original text."""
    secret = _fernet_key()

    preference = MemoryCore(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        key="timezone",
        value_encrypted=_encrypt(secret, "UTC"),
    )
    db_session.add(preference)
    await db_session.commit()

    query = select(MemoryCore).where(MemoryCore.user_id == USER_ID)
    stored = (await db_session.execute(query)).scalar_one()

    assert stored.key == "timezone"
    assert _decrypt(secret, stored.value_encrypted) == "UTC"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_memory_core_cascade_delete(db_session):
    """Deleting a user leaves no memory_core rows for that user."""
    db_session.add(
        MemoryCore(
            id=str(uuid.uuid4()),
            user_id=USER_ID,
            key="lang",
            value_encrypted="enc",
        )
    )
    await db_session.commit()

    owner_result = await db_session.execute(select(User).where(User.id == USER_ID))
    owner = owner_result.scalar_one()
    await db_session.delete(owner)
    await db_session.commit()

    leftover_result = await db_session.execute(
        select(MemoryCore).where(MemoryCore.user_id == USER_ID)
    )
    leftovers = leftover_result.scalars().all()
    assert leftovers == []
|
||||
|
||||
|
||||
# ── MemoryAssociative ─────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
async def test_memory_associative_create_and_read(db_session):
    """A MemoryAssociative row round-trips its content, entity type and embedding."""
    secret = _fernet_key()
    fake_vector = [0.1] * 1536  # fake embedding

    memory = MemoryAssociative(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        content_encrypted=_encrypt(secret, "User prefers morning meetings"),
        embedding=fake_vector,
        entity_type="preference",
        entity_id=None,
    )
    db_session.add(memory)
    await db_session.commit()

    query = select(MemoryAssociative).where(MemoryAssociative.user_id == USER_ID)
    stored = (await db_session.execute(query)).scalar_one()

    assert stored.entity_type == "preference"
    assert _decrypt(secret, stored.content_encrypted) == "User prefers morning meetings"
    assert len(stored.embedding) == 1536
|
||||
|
||||
|
||||
# ── MemoryEpisodic ────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
async def test_memory_episodic_create_and_read(db_session):
    """A MemoryEpisodic row is found by session id and gets a creation timestamp."""
    secret = _fernet_key()
    session_key = str(uuid.uuid4())

    episode = MemoryEpisodic(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        summary_encrypted=_encrypt(secret, "User asked about Q1 tasks"),
        session_id=session_key,
    )
    db_session.add(episode)
    await db_session.commit()

    query = select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_key)
    stored = (await db_session.execute(query)).scalar_one()

    assert _decrypt(secret, stored.summary_encrypted) == "User asked about Q1 tasks"
    # created_at is filled in by the database default.
    assert isinstance(stored.created_at, datetime)
|
||||
|
||||
|
||||
# ── MemoryProactive ───────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
async def test_memory_proactive_create_and_read(db_session):
    """A MemoryProactive row round-trips its confidence, source and pattern."""
    secret = _fernet_key()

    habit = MemoryProactive(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        pattern_encrypted=_encrypt(secret, "User always assigns tasks to self"),
        confidence=0.85,
        source="inferred",
    )
    db_session.add(habit)
    await db_session.commit()

    query = select(MemoryProactive).where(MemoryProactive.user_id == USER_ID)
    stored = (await db_session.execute(query)).scalar_one()

    assert stored.confidence == pytest.approx(0.85)
    assert stored.source == "inferred"
    assert _decrypt(secret, stored.pattern_encrypted) == "User always assigns tasks to self"
|
||||
|
||||
|
||||
# ── Auth registration generates encryption_key ───────────────────────────────
|
||||
|
||||
def test_register_sets_encryption_key(client):
    """Smoke-test registration now that it generates a Fernet key.

    POST /api/v1/auth/register must succeed and the returned access token
    must authenticate against /api/v1/auth/me. The key itself is not exposed
    by the API, so its value is not inspected here.
    """
    payload = {"email": "newuser@test.com", "password": "testpassword123"}
    register_resp = client.post("/api/v1/auth/register", json=payload)
    assert register_resp.status_code == 201

    # Use the issued access token to fetch the newly created user.
    access_token = register_resp.json()["access_token"]
    auth_headers = {"Authorization": f"Bearer {access_token}"}
    profile_resp = client.get("/api/v1/auth/me", headers=auth_headers)
    assert profile_resp.status_code == 200
    # encryption_key is not part of the API response (not in UserProfile);
    # this only verifies that registration with key generation did not crash.
|
||||
Reference in New Issue
Block a user