- Add app/core/deep_agent.py with Home and Floating supervisor graphs using LangGraph create_react_agent (hierarchical pattern) - Strip ChatAgent classes from all 4 agent files, keep @tool functions - Rewrite output_formatter.py for event-based (token/tool_end/mutations) stream - Update device_ws.py to use run_home_stream/run_floating_stream - Rewrite chat.py REST route to use run_home - Add update_core_memory tool to both supervisors - Add langgraph>=0.3.0 to requirements.txt - Remove orchestrator.py, execution_plan.py, agent_registry.py, plans.py - Remove PlanAction, PlanStep, ExecutionPlan, execution_mode from schemas - Update all affected tests to match new API - Remove 6 deprecated test files for deleted modules - Clean up stale docstrings referencing removed orchestrator
232 lines
8.3 KiB
Python
232 lines
8.3 KiB
Python
"""Memory Middleware — enrich requests with memory context and store interactions.
|
|
|
|
Four-tier memory model (MemGPT-style):
|
|
core — persistent key/value user preferences, always injected
|
|
associative — semantic similarity search via pgvector (top-k); the current fallback returns the most recently updated entries instead
|
|
episodic — recent session summaries (last N)
|
|
proactive — behavioral patterns above confidence threshold
|
|
|
|
All memory content is encrypted at rest using the per-user Fernet key
|
|
stored in User.encryption_key. Decryption happens in-memory only.
|
|
|
|
Usage:
|
|
memory = MemoryMiddleware(db_session)
|
|
context = await memory.enrich_context(user_id, message)
|
|
# ... run agent ...
|
|
await memory.store_episode(user_id, session_id, message, response)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import uuid
|
|
from typing import Any
|
|
|
|
from cryptography.fernet import Fernet, InvalidToken
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models import (
|
|
MemoryAssociative,
|
|
MemoryCore,
|
|
MemoryEpisodic,
|
|
MemoryProactive,
|
|
User,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)

# ── Tuning constants ──────────────────────────────────────────────────────────
# Maximum number of associative memories loaded per enrich_context call.
_ASSOCIATIVE_TOP_K: int = 5
# Number of most-recent episodic summaries loaded per enrich_context call.
_EPISODIC_RECENT_N: int = 10
# Proactive patterns are only surfaced at or above this confidence.
_PROACTIVE_CONFIDENCE_THRESHOLD: float = 0.6
|
|
|
|
|
|
class MemoryMiddleware:
    """Enrich agent context with memory and persist interactions after.

    All reads and writes go through the ``AsyncSession`` given to the
    constructor. Memory content is stored encrypted with the user's Fernet key
    (``User.encryption_key``) and is only decrypted in memory.
    """

    def __init__(self, db: AsyncSession) -> None:
        self._db = db

    # ── Public API ────────────────────────────────────────────────────────────

    async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
        """Build the memory context dict to inject into the agent before the LLM call.

        Args:
            user_id: Owner of the memories to load.
            message: The incoming user message. It is forwarded to the
                associative loader but is not used for ranking by the current
                implementation.

        Returns:
            ``{}`` when the user has no usable encryption key. Otherwise a dict
            with keys:
                core_memory        — {key: plaintext_value, ...}
                associative_memory — [plaintext_content, ...] (up to top-k,
                                     most recently updated first)
                episodic_memory    — [plaintext_summary, ...] (most recent N)
                proactive_hints    — [plaintext_pattern, ...] (confidence at or
                                     above threshold, highest first)
            Rows that fail to decrypt are skipped.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return {}

        # Loaded one after another: SQLAlchemy's AsyncSession does not support
        # concurrent use, so these must not be gathered in parallel.
        core = await self._load_core(user_id, fernet)
        associative = await self._load_associative(user_id, message, fernet)
        episodic = await self._load_episodic(user_id, fernet)
        proactive = await self._load_proactive(user_id, fernet)

        return {
            "core_memory": core,
            "associative_memory": associative,
            "episodic_memory": episodic,
            "proactive_hints": proactive,
        }

    async def store_episode(
        self,
        user_id: str,
        session_id: str,
        message: str,
        response: str,
    ) -> None:
        """Summarise and store a completed interaction in episodic memory.

        The summary is a simple heuristic concatenation (no LLM call) to keep
        latency low: the first 200 characters of the message and of the
        response. Full LLM summarisation can be added in a later step.

        Does nothing when the user has no usable encryption key. A failed
        commit is logged and rolled back rather than raised.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return

        summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
        encrypted = _encrypt(fernet, summary)

        row = MemoryEpisodic(
            id=str(uuid.uuid4()),
            user_id=user_id,
            summary_encrypted=encrypted,
            session_id=session_id,
        )
        self._db.add(row)
        try:
            await self._db.commit()
        except Exception as exc:
            logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
            await self._db.rollback()

    async def update_core(self, user_id: str, key: str, value: str) -> None:
        """Upsert a core memory key/value for a user.

        The value is encrypted before storage. Does nothing when the user has
        no usable encryption key. A failed commit is logged and rolled back
        rather than raised.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return

        encrypted = _encrypt(fernet, value)

        result = await self._db.execute(
            select(MemoryCore).where(
                MemoryCore.user_id == user_id,
                MemoryCore.key == key,
            )
        )
        existing = result.scalar_one_or_none()
        if existing is not None:
            existing.value_encrypted = encrypted
        else:
            self._db.add(MemoryCore(
                id=str(uuid.uuid4()),
                user_id=user_id,
                key=key,
                value_encrypted=encrypted,
            ))
        try:
            await self._db.commit()
        except Exception as exc:
            logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
            await self._db.rollback()

    # ── Private helpers ───────────────────────────────────────────────────────

    async def _get_fernet(self, user_id: str) -> Fernet | None:
        """Load the user's Fernet key from the DB.

        Returns ``None`` (after logging a warning) when the user does not exist,
        has no ``encryption_key``, or the stored key is not a valid Fernet key.
        Callers then degrade to "no memory" instead of raising.
        """
        result = await self._db.execute(select(User).where(User.id == user_id))
        user = result.scalar_one_or_none()
        if user is None or not user.encryption_key:
            logger.warning("memory: no encryption_key for user=%s", user_id)
            return None
        try:
            return Fernet(user.encryption_key.encode())
        except ValueError:
            # Fernet() raises ValueError (incl. binascii.Error) for keys that are
            # not 32 url-safe base64-encoded bytes; treat as "no usable key".
            logger.warning("memory: invalid encryption_key for user=%s", user_id)
            return None

    @staticmethod
    def _decrypt_many(fernet: Fernet, ciphertexts: list[str]) -> list[str]:
        """Decrypt each ciphertext in order, dropping any that fail to decrypt."""
        plaintexts: list[str] = []
        for ciphertext in ciphertexts:
            plaintext = _safe_decrypt(fernet, ciphertext)
            if plaintext is not None:
                plaintexts.append(plaintext)
        return plaintexts

    async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
        """Return all of the user's core memories as ``{key: plaintext_value}``."""
        result = await self._db.execute(
            select(MemoryCore).where(MemoryCore.user_id == user_id)
        )
        out: dict[str, str] = {}
        for row in result.scalars().all():
            plaintext = _safe_decrypt(fernet, row.value_encrypted)
            if plaintext is not None:
                out[row.key] = plaintext
        return out

    async def _load_associative(
        self, user_id: str, message: str, fernet: Fernet
    ) -> list[str]:
        """Load up to ``_ASSOCIATIVE_TOP_K`` associative memories.

        Intended production behaviour is pgvector cosine similarity against an
        embedding of *message*. The current implementation does not embed or
        match *message* at all. It returns the user's most recently updated
        rows (``updated_at`` descending), so no external embedding call
        (e.g. a live OpenAI key) is needed.
        """
        result = await self._db.execute(
            select(MemoryAssociative)
            .where(MemoryAssociative.user_id == user_id)
            .order_by(MemoryAssociative.updated_at.desc())
            .limit(_ASSOCIATIVE_TOP_K)
        )
        rows = result.scalars().all()
        return self._decrypt_many(fernet, [row.content_encrypted for row in rows])

    async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
        """Return the ``_EPISODIC_RECENT_N`` newest episode summaries, newest first."""
        result = await self._db.execute(
            select(MemoryEpisodic)
            .where(MemoryEpisodic.user_id == user_id)
            .order_by(MemoryEpisodic.created_at.desc())
            .limit(_EPISODIC_RECENT_N)
        )
        rows = result.scalars().all()
        return self._decrypt_many(fernet, [row.summary_encrypted for row in rows])

    async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
        """Return proactive patterns with confidence >= threshold, highest first."""
        result = await self._db.execute(
            select(MemoryProactive)
            .where(
                MemoryProactive.user_id == user_id,
                MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
            )
            .order_by(MemoryProactive.confidence.desc())
        )
        rows = result.scalars().all()
        return self._decrypt_many(fernet, [row.pattern_encrypted for row in rows])
|
|
|
|
|
|
# ── Encryption helpers ────────────────────────────────────────────────────────
|
|
|
|
def _encrypt(fernet: Fernet, plaintext: str) -> str:
|
|
return fernet.encrypt(plaintext.encode()).decode()
|
|
|
|
|
|
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
    """Decrypt *ciphertext* with *fernet* and return the plaintext.

    Best-effort: returns ``None`` (after logging a warning) instead of raising,
    so a single corrupted row or a wrong key cannot break a whole memory load.
    """
    try:
        return fernet.decrypt(ciphertext.encode()).decode()
    except InvalidToken:
        # str(InvalidToken()) is empty, so logging the exception itself would
        # produce an uninformative "decrypt failed: " line.
        logger.warning("memory: decrypt failed: invalid token (corrupted data or wrong key)")
        return None
    except Exception as exc:  # deliberate best-effort fallback (e.g. None/non-UTF-8 data)
        logger.warning("memory: decrypt failed: %r", exc)
        return None
|