Compare commits
1 Commits
feature/ba
...
2d8abb6311
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2d8abb6311 |
54
alembic/versions/005_associative_pgvector.py
Normal file
54
alembic/versions/005_associative_pgvector.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""Phase 1 — confirm pgvector activation on memory_associative.
|
||||||
|
|
||||||
|
Migration 004 created the embedding column as vector(1536) and added the
|
||||||
|
IVFFlat index. This migration is the Phase-1 checkpoint:
|
||||||
|
1. Ensures the pgvector extension is enabled (idempotent).
|
||||||
|
2. Ensures the canonical Phase-1 IVFFlat index exists under the name
|
||||||
|
memory_associative_embedding_idx (creates it only if absent).
|
||||||
|
|
||||||
|
Revision ID: 005
|
||||||
|
Revises: 9a1f2d0b6c7e
|
||||||
|
Create Date: 2026-04-15
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "005"
|
||||||
|
down_revision: Union[str, None] = "e04100e88ace"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Ensure pgvector extension is enabled (also done in 004, idempotent).
|
||||||
|
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||||
|
|
||||||
|
# Ensure the canonical Phase-1 IVFFlat index exists.
|
||||||
|
# 004 may have created ix_memory_associative_embedding; this adds the
|
||||||
|
# Phase-1 name memory_associative_embedding_idx if it is missing.
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM pg_indexes
|
||||||
|
WHERE tablename = 'memory_associative'
|
||||||
|
AND indexname = 'memory_associative_embedding_idx'
|
||||||
|
) THEN
|
||||||
|
CREATE INDEX memory_associative_embedding_idx
|
||||||
|
ON memory_associative
|
||||||
|
USING ivfflat (embedding vector_cosine_ops)
|
||||||
|
WITH (lists = 100);
|
||||||
|
END IF;
|
||||||
|
END $$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("DROP INDEX IF EXISTS memory_associative_embedding_idx;")
|
||||||
@@ -25,6 +25,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"providers": 1,
|
"providers": 1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
|
"real_embeddings": False, # keyword fallback only
|
||||||
},
|
},
|
||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
@@ -33,6 +34,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
|
"real_embeddings": True, # pgvector cosine search
|
||||||
},
|
},
|
||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
@@ -41,6 +43,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
|
"real_embeddings": True,
|
||||||
},
|
},
|
||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
@@ -49,6 +52,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
"sso": True,
|
"sso": True,
|
||||||
|
"real_embeddings": True,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
34
app/core/embeddings.py
Normal file
34
app/core/embeddings.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""OpenAI embedding helper for associative memory tier.
|
||||||
|
|
||||||
|
Single public function: ``embed_text(text) -> list[float] | None``.
|
||||||
|
Returns None on any failure — callers must implement a keyword fallback.
|
||||||
|
Never raises; all exceptions are logged as warnings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_MAX_INPUT_CHARS = 8000
|
||||||
|
_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||||
|
|
||||||
|
|
||||||
|
async def embed_text(text: str) -> list[float] | None:
|
||||||
|
"""Call OpenAI text-embedding-3-small. Return None on failure (caller falls back to keyword)."""
|
||||||
|
try:
|
||||||
|
client = AsyncOpenAI()
|
||||||
|
truncated = text[:_MAX_INPUT_CHARS]
|
||||||
|
response = await client.embeddings.create(
|
||||||
|
input=truncated,
|
||||||
|
model=_EMBEDDING_MODEL,
|
||||||
|
)
|
||||||
|
result: list[float] = response.data[0].embedding
|
||||||
|
logger.debug("embeddings: embed_text dims=%d", len(result))
|
||||||
|
return result
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("embeddings: embed_text failed: %s", exc)
|
||||||
|
return None
|
||||||
@@ -69,17 +69,19 @@ class MemoryMiddleware:
|
|||||||
if fernet is None:
|
if fernet is None:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
user_tier: str = user_dbg.get("tier") or "free"
|
||||||
|
|
||||||
core = await self._load_core(user_id, fernet)
|
core = await self._load_core(user_id, fernet)
|
||||||
associative = await self._load_associative(user_id, message, fernet)
|
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
|
||||||
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
||||||
trace_id or "-",
|
trace_id or "-",
|
||||||
user_id,
|
user_id,
|
||||||
user_dbg.get("tier") or "-",
|
user_tier,
|
||||||
len(core),
|
len(core),
|
||||||
len(associative),
|
len(associative),
|
||||||
len(episodic),
|
len(episodic),
|
||||||
@@ -255,6 +257,50 @@ class MemoryMiddleware:
|
|||||||
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def store_associative(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
content: str,
|
||||||
|
entity_type: str | None = None,
|
||||||
|
entity_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Store associative memory; embed if user tier has real_embeddings."""
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||||
|
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, content)
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
user_tier = user_dbg.get("tier") or "free"
|
||||||
|
|
||||||
|
embedding: list[float] | None = None
|
||||||
|
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||||
|
embedding = await embed_text(content)
|
||||||
|
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
content_encrypted=encrypted,
|
||||||
|
embedding=embedding,
|
||||||
|
entity_type=entity_type,
|
||||||
|
entity_id=entity_id,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory: store_associative user=%s embedded=%s",
|
||||||
|
user_id,
|
||||||
|
embedding is not None,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||||
"""Insert a long-term archival memory entry."""
|
"""Insert a long-term archival memory entry."""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
@@ -364,14 +410,49 @@ class MemoryMiddleware:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_associative(
|
async def _load_associative(
|
||||||
self, user_id: str, message: str, fernet: Fernet
|
self, user_id: str, message: str, fernet: Fernet, *, user_tier: str = "free"
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Load top-k associative memories.
|
"""Load top-k associative memories.
|
||||||
|
|
||||||
Production: uses pgvector cosine similarity on the message embedding.
|
Pro+: pgvector cosine similarity on the message embedding (real_embeddings feature).
|
||||||
Current implementation: keyword-based fallback (no external embedding call)
|
Free / embedding failure: keyword-ordered fallback (most recent rows).
|
||||||
so tests pass without a live OpenAI key.
|
|
||||||
"""
|
"""
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||||
|
|
||||||
|
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||||
|
vec = await embed_text(message)
|
||||||
|
if vec is not None:
|
||||||
|
try:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(
|
||||||
|
MemoryAssociative.user_id == user_id,
|
||||||
|
MemoryAssociative.embedding.isnot(None),
|
||||||
|
)
|
||||||
|
.order_by(MemoryAssociative.embedding.cosine_distance(vec))
|
||||||
|
.limit(_ASSOCIATIVE_TOP_K)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
logger.info(
|
||||||
|
"memory: _load_associative user=%s mode=vector hits=%d",
|
||||||
|
user_id,
|
||||||
|
len(out),
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory: vector search failed user=%s, falling back to keyword: %s",
|
||||||
|
user_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keyword fallback: most recent rows
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryAssociative)
|
select(MemoryAssociative)
|
||||||
.where(MemoryAssociative.user_id == user_id)
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
@@ -379,7 +460,7 @@ class MemoryMiddleware:
|
|||||||
.limit(_ASSOCIATIVE_TOP_K)
|
.limit(_ASSOCIATIVE_TOP_K)
|
||||||
)
|
)
|
||||||
rows = result.scalars().all()
|
rows = result.scalars().all()
|
||||||
out: list[str] = []
|
out = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
if plaintext is not None:
|
if plaintext is not None:
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from __future__ import annotations
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from pgvector.sqlalchemy import Vector
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Boolean,
|
Boolean,
|
||||||
DateTime,
|
DateTime,
|
||||||
@@ -299,8 +300,8 @@ class MemoryAssociative(Base):
|
|||||||
nullable=False, index=True,
|
nullable=False, index=True,
|
||||||
)
|
)
|
||||||
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
|
# vector(1536) via pgvector; SQLite tests use NULL embeddings so no dialect issue.
|
||||||
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
embedding: Mapped[list | None] = mapped_column(Vector(1536), nullable=True)
|
||||||
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
@@ -348,3 +349,25 @@ class MemoryProactive(Base):
|
|||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(Base):
|
||||||
|
"""Plugin marketplace catalog entry."""
|
||||||
|
|
||||||
|
__tablename__ = "plugins"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
version: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
|
author_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
category: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]")
|
||||||
|
status: Mapped[str] = mapped_column(String(50), nullable=False, default="pending")
|
||||||
|
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ google-auth-oauthlib>=1.2.0
|
|||||||
google-auth-httplib2>=0.2.0
|
google-auth-httplib2>=0.2.0
|
||||||
msal>=1.28.0
|
msal>=1.28.0
|
||||||
cryptography>=42.0.0
|
cryptography>=42.0.0
|
||||||
|
pgvector>=0.2.5
|
||||||
langfuse>=2.0.0
|
langfuse>=2.0.0
|
||||||
beautifulsoup4>=4.12.0
|
beautifulsoup4>=4.12.0
|
||||||
lxml>=5.0.0
|
lxml>=5.0.0
|
||||||
|
|||||||
@@ -12,13 +12,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from unittest.mock import patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.core.embeddings import embed_text
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.main import app
|
from app.main import app
|
||||||
@@ -341,3 +342,33 @@ def test_home_request_calls_memory_middleware(client):
|
|||||||
stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
|
stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
|
||||||
assert stored_session_id == session_id
|
assert stored_session_id == session_id
|
||||||
assert stored_message == "Show tasks"
|
assert stored_message == "Show tasks"
|
||||||
|
|
||||||
|
|
||||||
|
# ── embed_text ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embed_text_returns_1536_floats():
|
||||||
|
"""embed_text returns a 1536-dim float list when OpenAI responds successfully."""
|
||||||
|
fake_embedding = [0.1] * 1536
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data = [MagicMock(embedding=fake_embedding)]
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.embeddings.AsyncOpenAI", return_value=mock_client):
|
||||||
|
result = await embed_text("test text")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert len(result) == 1536
|
||||||
|
assert all(isinstance(x, float) for x in result)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embed_text_returns_none_on_failure():
|
||||||
|
"""embed_text returns None when OpenAI raises; must not propagate the exception."""
|
||||||
|
with patch("app.core.embeddings.AsyncOpenAI", side_effect=Exception("no key")):
|
||||||
|
result = await embed_text("test text")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|||||||
Reference in New Issue
Block a user