WS Gateway:
- WebSocket lifecycle handler with RS256 JWT auth
- Redis bridge: device registry, frame publishing, tool_result routing
- Inbound routing: tool_result→LPUSH, home/floating→chat pub/sub
- Outbound: subscribes to ws:out:{user_id}, forwards to Electron
- Single-worker Dockerfile (long-lived WS connections)
Chat Service:
- Redis consumer: subscribes to chat:request:* pattern
- Redis-based ws_context: tool_call→publish, BRPOP tool_result (30s timeout)
- deep_agent: single-agent runner with home/floating/stream variants
- memory_middleware: core/associative/episodic/proactive memory with Fernet
- Domain agents: task (8 tools), note (5), project (6), timeline (4)
- LLM factory via LiteLLM (100+ providers)
- Output formatter (StreamFormatter)
- POST /chat REST fallback with Traefik header auth
- Multi-worker Dockerfile with 120s timeout for LLM calls
296 lines
11 KiB
Python
296 lines
11 KiB
Python
"""Memory Middleware — adapted for Chat Service.
|
|
|
|
Uses shared.models instead of app.models. Otherwise identical to the
|
|
monolith's app/core/memory_middleware.py.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import uuid
|
|
from collections.abc import Iterable
from typing import Any
|
|
|
|
from cryptography.fernet import Fernet, InvalidToken
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from shared.models import (
|
|
MemoryAssociative,
|
|
MemoryCore,
|
|
MemoryEpisodic,
|
|
MemoryProactive,
|
|
User,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_ASSOCIATIVE_TOP_K = 5
|
|
_EPISODIC_RECENT_N = 10
|
|
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
|
|
|
|
|
class MemoryMiddleware:
    """Read/write access to a user's Fernet-encrypted memory rows.

    Covers four stores: core (labelled key/value blocks), associative (also
    exposed as "archival"), episodic (per-session exchange summaries) and
    proactive (confidence-scored patterns). Every stored value is encrypted
    with the user's own Fernet key (``User.encryption_key``). When the user
    has no usable key, reads return empty results and writes are skipped.

    Write methods commit ``db`` themselves. That also commits any other
    pending changes on the same session. Commit failures are logged and
    rolled back rather than raised.
    """

    # Maximum number of most-recent rows scanned by the substring searches
    # (search_archival / search_recall) before matching in Python.
    _SEARCH_SCAN_LIMIT = 100

    def __init__(self, db: AsyncSession) -> None:
        """Bind the middleware to an async SQLAlchemy session."""
        self._db = db

    async def enrich_context(
        self,
        user_id: str,
        message: str,
        trace_id: str | None = None,
        session_id: str | None = None,
    ) -> dict[str, Any]:
        """Load all memory stores for ``user_id`` into one context dict.

        Args:
            user_id: Owner of the memory rows.
            message: Current user message. It is forwarded to the associative
                loader, which currently does not use it for ranking.
            trace_id: Optional correlation id, used only in the log line.
            session_id: When given, episodic memory is limited to this session.

        Returns:
            A dict with keys ``core_memory``, ``associative_memory``,
            ``episodic_memory`` and ``proactive_hints``. Returns ``{}`` when
            the user has no usable encryption key.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return {}

        core = await self._load_core(user_id, fernet)
        associative = await self._load_associative(user_id, message, fernet)
        episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
        proactive = await self._load_proactive(user_id, fernet)

        logger.info(
            "memory: enrich_context trace=%s user=%s core=%d assoc=%d episodic=%d proactive=%d",
            trace_id or "-", user_id, len(core), len(associative), len(episodic), len(proactive),
        )

        return {
            "core_memory": core,
            "associative_memory": associative,
            "episodic_memory": episodic,
            "proactive_hints": proactive,
        }

    async def store_episode(
        self, user_id: str, session_id: str, message: str, response: str,
        trace_id: str | None = None,
    ) -> None:
        """Persist an encrypted summary of one user/assistant exchange.

        Both sides are truncated to their first 200 characters. ``trace_id``
        is accepted for signature parity but is currently unused. This is a
        best-effort write: commit errors are logged and rolled back.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return

        summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
        encrypted = _encrypt(fernet, summary)

        row = MemoryEpisodic(
            id=str(uuid.uuid4()),
            user_id=user_id,
            summary_encrypted=encrypted,
            session_id=session_id,
        )
        self._db.add(row)
        try:
            await self._db.commit()
        except Exception as exc:
            logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
            await self._db.rollback()

    async def update_core(
        self, user_id: str, key: str, value: str, trace_id: str | None = None,
    ) -> bool:
        """Create or overwrite the core block ``key`` with ``value`` (encrypted).

        ``trace_id`` is currently unused.

        Returns:
            True if the change was committed. False if the user has no usable
            key or the commit failed (the failure is logged and rolled back).
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return False

        encrypted = _encrypt(fernet, value)
        result = await self._db.execute(
            select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == key)
        )
        existing = result.scalar_one_or_none()
        if existing is not None:
            existing.value_encrypted = encrypted
        else:
            self._db.add(MemoryCore(
                id=str(uuid.uuid4()), user_id=user_id, key=key, value_encrypted=encrypted,
            ))
        try:
            await self._db.commit()
        except Exception as exc:
            logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
            await self._db.rollback()
            return False
        return True

    async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
        """Return all decryptable core blocks as ``{"label", "value"}`` dicts, sorted by label.

        Rows that fail to decrypt are omitted.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return []
        result = await self._db.execute(
            select(MemoryCore).where(MemoryCore.user_id == user_id).order_by(MemoryCore.key.asc())
        )
        out: list[dict[str, str]] = []
        for row in result.scalars().all():
            plaintext = _safe_decrypt(fernet, row.value_encrypted)
            if plaintext is not None:
                out.append({"label": row.key, "value": plaintext})
        return out

    async def get_core_block(self, user_id: str, label: str) -> str | None:
        """Return the decrypted core block ``label``.

        Returns None when there is no key, no such block, or the block
        cannot be decrypted.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return None
        result = await self._db.execute(
            select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label)
        )
        row = result.scalar_one_or_none()
        if row is None:
            return None
        return _safe_decrypt(fernet, row.value_encrypted)

    async def delete_core(self, user_id: str, label: str) -> bool:
        """Delete the core block ``label``. Returns True only if a row was deleted and committed.

        This does not require the user's encryption key.
        """
        result = await self._db.execute(
            select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label)
        )
        row = result.scalar_one_or_none()
        if row is None:
            return False
        await self._db.delete(row)
        try:
            await self._db.commit()
            return True
        except Exception as exc:
            logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
            await self._db.rollback()
            return False

    async def append_core(self, user_id: str, label: str, content: str) -> None:
        """Append ``content`` on a new line to core block ``label``, creating it if absent.

        NOTE(review): a block that exists but cannot be decrypted also reads
        as absent here, so it is overwritten with ``content`` alone.
        """
        current = await self.get_core_block(user_id, label)
        if current is None:
            await self.update_core(user_id, label, content)
            return
        await self.update_core(user_id, label, f"{current}\n{content}")

    async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
        """Replace the first occurrence of ``old`` with ``new`` in core block ``label``.

        Returns:
            True only if ``old`` is non-empty, the block exists and contains
            it, and the updated value was committed. False otherwise.
        """
        # An empty ``old`` is "in" every string, and str.replace would then
        # prepend ``new``. Reject it instead of silently editing the block.
        if not old:
            return False
        current = await self.get_core_block(user_id, label)
        if current is None or old not in current:
            return False
        return await self.update_core(user_id, label, current.replace(old, new, 1))

    async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
        """Store ``content`` as an encrypted associative row tagged with ``source``.

        ``source`` is stored in the ``entity_type`` column, and no embedding is
        computed. This is a best-effort write: commit errors are logged and
        rolled back.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return
        encrypted = _encrypt(fernet, content)
        row = MemoryAssociative(
            id=str(uuid.uuid4()), user_id=user_id,
            content_encrypted=encrypted, embedding=None,
            entity_type=source, entity_id=None,
        )
        self._db.add(row)
        try:
            await self._db.commit()
        except Exception as exc:
            logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
            await self._db.rollback()

    async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
        """Case-insensitive substring search over the most recently updated associative rows.

        Only the newest ``_SEARCH_SCAN_LIMIT`` rows are scanned. At least one
        result is requested even when ``top_k < 1``, and a blank ``query``
        matches everything.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return []
        result = await self._db.execute(
            select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
            .order_by(MemoryAssociative.updated_at.desc()).limit(self._SEARCH_SCAN_LIMIT)
        )
        # Lazy generator: rows are decrypted only until top_k matches are found.
        plaintexts = (
            _safe_decrypt(fernet, row.content_encrypted) for row in result.scalars().all()
        )
        return self._match_texts(plaintexts, query, top_k)

    async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
        """Case-insensitive substring search over the most recent episodic summaries.

        Only the newest ``_SEARCH_SCAN_LIMIT`` rows (by ``created_at``) are
        scanned, across all sessions. ``top_k`` and blank-query semantics
        match ``search_archival``.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return []
        result = await self._db.execute(
            select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
            .order_by(MemoryEpisodic.created_at.desc()).limit(self._SEARCH_SCAN_LIMIT)
        )
        plaintexts = (
            _safe_decrypt(fernet, row.summary_encrypted) for row in result.scalars().all()
        )
        return self._match_texts(plaintexts, query, top_k)

    # ── Private ───────────────────────────────────────────────────────

    @staticmethod
    def _match_texts(plaintexts: Iterable[str | None], query: str, top_k: int) -> list[str]:
        """Collect up to ``max(top_k, 1)`` texts that contain ``query``, ignoring case.

        ``None`` entries (failed decrypts) are skipped, and a blank query
        matches every text. Iteration stops as soon as enough matches are
        collected, so a generator source is never advanced past that point.
        """
        needle = query.strip().lower()
        limit = max(top_k, 1)
        out: list[str] = []
        for text in plaintexts:
            if text is None:
                continue
            if not needle or needle in text.lower():
                out.append(text)
                if len(out) >= limit:
                    break
        return out

    async def _get_fernet(self, user_id: str) -> Fernet | None:
        """Build a Fernet from the user's stored key.

        Returns None (and logs) when the user is missing, has no key, or the
        key is malformed.
        """
        result = await self._db.execute(select(User).where(User.id == user_id))
        user = result.scalar_one_or_none()
        if user is None or not user.encryption_key:
            logger.warning("memory: no encryption_key for user=%s", user_id)
            return None
        try:
            return Fernet(user.encryption_key.encode())
        except ValueError:
            # Fernet() raises ValueError for keys that are not 32 url-safe
            # base64 bytes. Degrade like a missing key instead of failing the
            # whole request, and never log the key itself.
            logger.error("memory: malformed encryption_key for user=%s", user_id)
            return None

    async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
        """Return every decryptable core block as ``{key: plaintext}``."""
        result = await self._db.execute(
            select(MemoryCore).where(MemoryCore.user_id == user_id)
        )
        out: dict[str, str] = {}
        for row in result.scalars().all():
            plaintext = _safe_decrypt(fernet, row.value_encrypted)
            if plaintext is not None:
                out[row.key] = plaintext
        return out

    async def _load_associative(self, user_id: str, message: str, fernet: Fernet) -> list[str]:
        """Return up to ``_ASSOCIATIVE_TOP_K`` most recently updated associative texts.

        NOTE(review): ``message`` is not used. Retrieval is recency-only, with
        no relevance ranking against the current message.
        """
        result = await self._db.execute(
            select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
            .order_by(MemoryAssociative.updated_at.desc()).limit(_ASSOCIATIVE_TOP_K)
        )
        out: list[str] = []
        for row in result.scalars().all():
            plaintext = _safe_decrypt(fernet, row.content_encrypted)
            if plaintext is not None:
                out.append(plaintext)
        return out

    async def _load_episodic(
        self, user_id: str, fernet: Fernet, session_id: str | None = None,
    ) -> list[str]:
        """Return up to ``_EPISODIC_RECENT_N`` newest episodic summaries.

        Results are newest first. They are limited to ``session_id`` when it
        is truthy.
        """
        query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
        if session_id:
            query = query.where(MemoryEpisodic.session_id == session_id)
        result = await self._db.execute(
            query.order_by(MemoryEpisodic.created_at.desc()).limit(_EPISODIC_RECENT_N)
        )
        out: list[str] = []
        for row in result.scalars().all():
            plaintext = _safe_decrypt(fernet, row.summary_encrypted)
            if plaintext is not None:
                out.append(plaintext)
        return out

    async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
        """Return proactive patterns with confidence >= threshold, highest confidence first.

        NOTE(review): no row limit is applied, so every qualifying pattern is loaded.
        """
        result = await self._db.execute(
            select(MemoryProactive).where(
                MemoryProactive.user_id == user_id,
                MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
            ).order_by(MemoryProactive.confidence.desc())
        )
        out: list[str] = []
        for row in result.scalars().all():
            plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
            if plaintext is not None:
                out.append(plaintext)
        return out
|
|
|
|
|
|
def _encrypt(fernet: Fernet, plaintext: str) -> str:
|
|
return fernet.encrypt(plaintext.encode()).decode()
|
|
|
|
|
|
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
|
|
try:
|
|
return fernet.decrypt(ciphertext.encode()).decode()
|
|
except (InvalidToken, Exception) as exc:
|
|
logger.warning("memory: decrypt failed: %s", exc)
|
|
return None
|