memory evolution phase 1
This commit is contained in:
54
alembic/versions/005_associative_pgvector.py
Normal file
54
alembic/versions/005_associative_pgvector.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Phase 1 — confirm pgvector activation on memory_associative.
|
||||
|
||||
Migration 004 created the embedding column as vector(1536) and added the
|
||||
IVFFlat index. This migration is the Phase-1 checkpoint:
|
||||
1. Ensures the pgvector extension is enabled (idempotent).
|
||||
2. Ensures the canonical Phase-1 IVFFlat index exists under the name
|
||||
memory_associative_embedding_idx (creates it only if absent).
|
||||
|
||||
Revision ID: 005
|
||||
Revises: 9a1f2d0b6c7e
|
||||
Create Date: 2026-04-15
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "005"
|
||||
down_revision: Union[str, None] = "e04100e88ace"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Ensure pgvector extension is enabled (also done in 004, idempotent).
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||
|
||||
# Ensure the canonical Phase-1 IVFFlat index exists.
|
||||
# 004 may have created ix_memory_associative_embedding; this adds the
|
||||
# Phase-1 name memory_associative_embedding_idx if it is missing.
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_indexes
|
||||
WHERE tablename = 'memory_associative'
|
||||
AND indexname = 'memory_associative_embedding_idx'
|
||||
) THEN
|
||||
CREATE INDEX memory_associative_embedding_idx
|
||||
ON memory_associative
|
||||
USING ivfflat (embedding vector_cosine_ops)
|
||||
WITH (lists = 100);
|
||||
END IF;
|
||||
END $$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS memory_associative_embedding_idx;")
|
||||
@@ -25,6 +25,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"providers": 1,
|
||||
"batch_builder": False,
|
||||
"sso": False,
|
||||
"real_embeddings": False, # keyword fallback only
|
||||
},
|
||||
"pro": {
|
||||
"agents": -1, # unlimited
|
||||
@@ -33,6 +34,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"providers": -1,
|
||||
"batch_builder": False,
|
||||
"sso": False,
|
||||
"real_embeddings": True, # pgvector cosine search
|
||||
},
|
||||
"power": {
|
||||
"agents": -1,
|
||||
@@ -41,6 +43,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"providers": -1,
|
||||
"batch_builder": True,
|
||||
"sso": False,
|
||||
"real_embeddings": True,
|
||||
},
|
||||
"team": {
|
||||
"agents": -1,
|
||||
@@ -49,6 +52,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"providers": -1,
|
||||
"batch_builder": True,
|
||||
"sso": True,
|
||||
"real_embeddings": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ class Settings(BaseSettings):
|
||||
|
||||
ENV: Literal["dev", "prod"] = "dev"
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
34
app/core/embeddings.py
Normal file
34
app/core/embeddings.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""OpenAI embedding helper for associative memory tier.
|
||||
|
||||
Single public function: ``embed_text(text) -> list[float] | None``.
|
||||
Returns None on any failure — callers must implement a keyword fallback.
|
||||
Never raises; all exceptions are logged as warnings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MAX_INPUT_CHARS = 8000
|
||||
_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
|
||||
|
||||
async def embed_text(text: str) -> list[float] | None:
|
||||
"""Call OpenAI text-embedding-3-small. Return None on failure (caller falls back to keyword)."""
|
||||
try:
|
||||
client = AsyncOpenAI()
|
||||
truncated = text[:_MAX_INPUT_CHARS]
|
||||
response = await client.embeddings.create(
|
||||
input=truncated,
|
||||
model=_EMBEDDING_MODEL,
|
||||
)
|
||||
result: list[float] = response.data[0].embedding
|
||||
logger.debug("embeddings: embed_text dims=%d", len(result))
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.warning("embeddings: embed_text failed: %s", exc)
|
||||
return None
|
||||
@@ -69,17 +69,19 @@ class MemoryMiddleware:
|
||||
if fernet is None:
|
||||
return {}
|
||||
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
user_tier: str = user_dbg.get("tier") or "free"
|
||||
|
||||
core = await self._load_core(user_id, fernet)
|
||||
associative = await self._load_associative(user_id, message, fernet)
|
||||
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
|
||||
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||
proactive = await self._load_proactive(user_id, fernet)
|
||||
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
logger.info(
|
||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
user_dbg.get("tier") or "-",
|
||||
user_tier,
|
||||
len(core),
|
||||
len(associative),
|
||||
len(episodic),
|
||||
@@ -255,6 +257,50 @@ class MemoryMiddleware:
|
||||
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
||||
return True
|
||||
|
||||
async def store_associative(
|
||||
self,
|
||||
user_id: str,
|
||||
content: str,
|
||||
entity_type: str | None = None,
|
||||
entity_id: str | None = None,
|
||||
) -> None:
|
||||
"""Store associative memory; embed if user tier has real_embeddings."""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return
|
||||
|
||||
encrypted = _encrypt(fernet, content)
|
||||
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
user_tier = user_dbg.get("tier") or "free"
|
||||
|
||||
embedding: list[float] | None = None
|
||||
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||
embedding = await embed_text(content)
|
||||
|
||||
row = MemoryAssociative(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
content_encrypted=encrypted,
|
||||
embedding=embedding,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
self._db.add(row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info(
|
||||
"memory: store_associative user=%s embedded=%s",
|
||||
user_id,
|
||||
embedding is not None,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||
"""Insert a long-term archival memory entry."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
@@ -364,14 +410,49 @@ class MemoryMiddleware:
|
||||
return out
|
||||
|
||||
async def _load_associative(
|
||||
self, user_id: str, message: str, fernet: Fernet
|
||||
self, user_id: str, message: str, fernet: Fernet, *, user_tier: str = "free"
|
||||
) -> list[str]:
|
||||
"""Load top-k associative memories.
|
||||
|
||||
Production: uses pgvector cosine similarity on the message embedding.
|
||||
Current implementation: keyword-based fallback (no external embedding call)
|
||||
so tests pass without a live OpenAI key.
|
||||
Pro+: pgvector cosine similarity on the message embedding (real_embeddings feature).
|
||||
Free / embedding failure: keyword-ordered fallback (most recent rows).
|
||||
"""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||
|
||||
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||
vec = await embed_text(message)
|
||||
if vec is not None:
|
||||
try:
|
||||
result = await self._db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(
|
||||
MemoryAssociative.user_id == user_id,
|
||||
MemoryAssociative.embedding.isnot(None),
|
||||
)
|
||||
.order_by(MemoryAssociative.embedding.cosine_distance(vec))
|
||||
.limit(_ASSOCIATIVE_TOP_K)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append(plaintext)
|
||||
logger.info(
|
||||
"memory: _load_associative user=%s mode=vector hits=%d",
|
||||
user_id,
|
||||
len(out),
|
||||
)
|
||||
return out
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory: vector search failed user=%s, falling back to keyword: %s",
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
# Keyword fallback: most recent rows
|
||||
result = await self._db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(MemoryAssociative.user_id == user_id)
|
||||
@@ -379,7 +460,7 @@ class MemoryMiddleware:
|
||||
.limit(_ASSOCIATIVE_TOP_K)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[str] = []
|
||||
out = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||
if plaintext is not None:
|
||||
|
||||
@@ -21,6 +21,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
DateTime,
|
||||
@@ -299,8 +300,8 @@ class MemoryAssociative(Base):
|
||||
nullable=False, index=True,
|
||||
)
|
||||
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
|
||||
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||
# vector(1536) via pgvector; SQLite tests use NULL embeddings so no dialect issue.
|
||||
embedding: Mapped[list | None] = mapped_column(Vector(1536), nullable=True)
|
||||
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
@@ -348,3 +349,25 @@ class MemoryProactive(Base):
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class Plugin(Base):
|
||||
"""Plugin marketplace catalog entry."""
|
||||
|
||||
__tablename__ = "plugins"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
version: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
author_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]")
|
||||
status: Mapped[str] = mapped_column(String(50), nullable=False, default="pending")
|
||||
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
@@ -32,6 +32,7 @@ google-auth-oauthlib>=1.2.0
|
||||
google-auth-httplib2>=0.2.0
|
||||
msal>=1.28.0
|
||||
cryptography>=42.0.0
|
||||
pgvector>=0.2.5
|
||||
langfuse>=2.0.0
|
||||
beautifulsoup4>=4.12.0
|
||||
lxml>=5.0.0
|
||||
|
||||
@@ -12,13 +12,14 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.embeddings import embed_text
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
@@ -341,3 +342,33 @@ def test_home_request_calls_memory_middleware(client):
|
||||
stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
|
||||
assert stored_session_id == session_id
|
||||
assert stored_message == "Show tasks"
|
||||
|
||||
|
||||
# ── embed_text ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_text_returns_1536_floats():
|
||||
"""embed_text returns a 1536-dim float list when OpenAI responds successfully."""
|
||||
fake_embedding = [0.1] * 1536
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [MagicMock(embedding=fake_embedding)]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("app.core.embeddings.AsyncOpenAI", return_value=mock_client):
|
||||
result = await embed_text("test text")
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 1536
|
||||
assert all(isinstance(x, float) for x in result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_text_returns_none_on_failure():
|
||||
"""embed_text returns None when OpenAI raises; must not propagate the exception."""
|
||||
with patch("app.core.embeddings.AsyncOpenAI", side_effect=Exception("no key")):
|
||||
result = await embed_text("test text")
|
||||
|
||||
assert result is None
|
||||
|
||||
Reference in New Issue
Block a user