- Add app/core/deep_agent.py with Home and Floating supervisor graphs using LangGraph create_react_agent (hierarchical pattern)
- Strip ChatAgent classes from all 4 agent files, keep @tool functions
- Rewrite output_formatter.py for event-based (token/tool_end/mutations) stream
- Update device_ws.py to use run_home_stream/run_floating_stream
- Rewrite chat.py REST route to use run_home
- Add update_core_memory tool to both supervisors
- Add langgraph>=0.3.0 to requirements.txt
- Remove orchestrator.py, execution_plan.py, agent_registry.py, plans.py
- Remove PlanAction, PlanStep, ExecutionPlan, execution_mode from schemas
- Update all affected tests to match new API
- Remove 6 deprecated test files for deleted modules
- Clean up stale docstrings referencing removed orchestrator
285 lines
9.7 KiB
Python
285 lines
9.7 KiB
Python
"""Tests for Step 7 — MemoryMiddleware.
|
|
|
|
Coverage:
|
|
1. enrich_context returns core prefs + associative + episodic + proactive
|
|
2. store_episode creates an encrypted row decryptable with the user's key
|
|
3. update_core upserts correctly
|
|
4. User with no encryption_key returns empty context (no crash)
|
|
5. End-to-end: home_request WS frame results in an episodic row being stored
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from cryptography.fernet import Fernet
|
|
from sqlalchemy import select
|
|
|
|
from app.core.memory_middleware import MemoryMiddleware, _PROACTIVE_CONFIDENCE_THRESHOLD
|
|
from app.db import get_session
|
|
from app.main import app
|
|
from app.models import (
|
|
MemoryAssociative,
|
|
MemoryCore,
|
|
MemoryEpisodic,
|
|
MemoryProactive,
|
|
User,
|
|
)
|
|
from tests.conftest import TEST_USER_IDS, make_jwt
|
|
|
|
|
|
# Id of the seeded "power" test user; every row created in this module belongs to it.
USER_ID = TEST_USER_IDS["power"]
|
|
# Fernet key generated at import time (so it differs between runs), kept as text.
# The user_with_key fixture assigns it to the user's encryption_key.
_FERNET_KEY = Fernet.generate_key().decode()
|
|
|
|
|
|
# ── DB override ───────────────────────────────────────────────────────────────
|
|
|
|
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Route the app's ``get_session`` dependency to the test's ``db_session``.

    Runs automatically for every test in this module. The override is
    installed before the test body and removed again afterwards.
    """

    async def _session_provider():
        # Dependency-style async generator that hands out the shared session.
        yield db_session

    app.dependency_overrides[get_session] = _session_provider
    yield
    # Tolerate the key already being gone (e.g. removed by the test itself).
    app.dependency_overrides.pop(get_session, None)
|
|
|
|
|
|
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
|
|
|
@pytest_asyncio.fixture
async def user_with_key(db_session):
    """Assign the module's Fernet key to the seeded power user and return that user."""
    lookup = select(User).where(User.id == USER_ID)
    power_user = (await db_session.execute(lookup)).scalar_one()
    power_user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return power_user
|
|
|
|
|
|
def _fernet():
    """Build a Fernet cipher from the module-level test key."""
    key_bytes = _FERNET_KEY.encode()
    return Fernet(key_bytes)
|
|
|
|
|
|
def _enc(plaintext: str) -> str:
    """Encrypt *plaintext* with the test key and return the token as text."""
    token = _fernet().encrypt(plaintext.encode())
    return token.decode()
|
|
|
|
|
|
def _dec(ciphertext: str) -> str:
    """Decrypt a text token produced with the test key and return the plaintext."""
    raw = _fernet().decrypt(ciphertext.encode())
    return raw.decode()
|
|
|
|
|
|
# ── enrich_context ────────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_returns_core_memory(db_session, user_with_key):
    """A stored core-memory row comes back in plaintext under ``core_memory``."""
    # Seed one encrypted core preference for the user.
    timezone_pref = MemoryCore(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        key="timezone",
        value_encrypted=_enc("UTC"),
    )
    db_session.add(timezone_pref)
    await db_session.commit()

    enriched = await MemoryMiddleware(db_session).enrich_context(
        USER_ID, "What are my tasks?"
    )

    assert "core_memory" in enriched
    assert enriched["core_memory"]["timezone"] == "UTC"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_returns_episodic_memory(db_session, user_with_key):
    """A stored episode summary comes back in plaintext under ``episodic_memory``."""
    episode_session = str(uuid.uuid4())
    episode = MemoryEpisodic(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        summary_encrypted=_enc("User asked about Q1 tasks"),
        session_id=episode_session,
    )
    db_session.add(episode)
    await db_session.commit()

    enriched = await MemoryMiddleware(db_session).enrich_context(
        USER_ID, "any message"
    )

    assert "episodic_memory" in enriched
    # At least one returned summary must contain the seeded text.
    matching = [s for s in enriched["episodic_memory"] if "Q1 tasks" in s]
    assert matching
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
    """Only the high-confidence proactive pattern appears in ``proactive_hints``."""
    # One pattern above the threshold and one below it.
    # NOTE(review): 0.9 / 0.1 are presumably meant to straddle
    # _PROACTIVE_CONFIDENCE_THRESHOLD — confirm against memory_middleware.
    seeded_patterns = (
        ("User prefers short summaries", 0.9),
        ("User likes dark mode", 0.1),
    )
    for pattern_text, confidence in seeded_patterns:
        db_session.add(
            MemoryProactive(
                id=str(uuid.uuid4()),
                user_id=USER_ID,
                pattern_encrypted=_enc(pattern_text),
                confidence=confidence,
                source="inferred",
            )
        )
    await db_session.commit()

    enriched = await MemoryMiddleware(db_session).enrich_context(
        USER_ID, "any message"
    )

    assert "proactive_hints" in enriched
    hints = enriched["proactive_hints"]
    assert any("short summaries" in hint for hint in hints)
    assert all("dark mode" not in hint for hint in hints)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_returns_associative_memory(db_session, user_with_key):
    """An associative row stored without an embedding is still returned in plaintext."""
    note = MemoryAssociative(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        content_encrypted=_enc("Related memory about meetings"),
        embedding=None,
        entity_type="note",
    )
    db_session.add(note)
    await db_session.commit()

    enriched = await MemoryMiddleware(db_session).enrich_context(USER_ID, "meetings")

    assert "associative_memory" in enriched
    matching = [m for m in enriched["associative_memory"] if "meetings" in m]
    assert matching
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_empty_for_user_without_key(db_session):
    """A user whose encryption_key is None gets ``{}`` back and nothing raises."""
    lookup = select(User).where(User.id == USER_ID)
    keyless_user = (await db_session.execute(lookup)).scalar_one()
    # Clear the key explicitly so the test does not depend on the seeded value.
    keyless_user.encryption_key = None
    await db_session.commit()

    enriched = await MemoryMiddleware(db_session).enrich_context(USER_ID, "hello")
    assert enriched == {}
|
|
|
|
|
|
# ── store_episode ─────────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_store_episode_creates_encrypted_row(db_session, user_with_key):
    """store_episode writes one row whose summary decrypts to include message and response."""
    episode_session = str(uuid.uuid4())
    await MemoryMiddleware(db_session).store_episode(
        USER_ID, episode_session, "hello", "world"
    )

    query = select(MemoryEpisodic).where(MemoryEpisodic.session_id == episode_session)
    stored = (await db_session.execute(query)).scalar_one()
    summary = _dec(stored.summary_encrypted)
    for fragment in ("hello", "world"):
        assert fragment in summary
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_store_episode_decryptable(db_session, user_with_key):
    """The stored summary decrypts with the user's key to a non-empty string."""
    episode_session = str(uuid.uuid4())
    await MemoryMiddleware(db_session).store_episode(
        USER_ID, episode_session, "msg", "resp"
    )

    query = select(MemoryEpisodic).where(MemoryEpisodic.session_id == episode_session)
    stored = (await db_session.execute(query)).scalar_one()
    # Decryption with the same key must succeed; a failure raises here.
    summary = _dec(stored.summary_encrypted)
    assert summary
|
|
|
|
|
|
# ── update_core ───────────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_update_core_insert(db_session, user_with_key):
    """update_core on a new key creates a row whose value decrypts to what was given."""
    await MemoryMiddleware(db_session).update_core(USER_ID, "lang", "en")

    query = select(MemoryCore).where(
        MemoryCore.user_id == USER_ID,
        MemoryCore.key == "lang",
    )
    stored = (await db_session.execute(query)).scalar_one()
    assert _dec(stored.value_encrypted) == "en"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_core_upsert(db_session, user_with_key):
    """A second update_core for the same key overwrites rather than adding a row."""
    middleware = MemoryMiddleware(db_session)
    for language in ("en", "fr"):
        await middleware.update_core(USER_ID, "lang", language)

    query = select(MemoryCore).where(
        MemoryCore.user_id == USER_ID,
        MemoryCore.key == "lang",
    )
    stored_rows = (await db_session.execute(query)).scalars().all()
    assert len(stored_rows) == 1
    assert _dec(stored_rows[0].value_encrypted) == "fr"
|
|
|
|
|
|
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
|
|
|
def test_home_request_calls_memory_middleware(client):
    """home_request triggers enrich_context before and store_episode after the LLM.

    ``MemoryMiddleware`` and ``run_home_stream`` are patched in
    ``app.api.routes.device_ws``. The test sends a ``device_hello`` frame
    and a ``home_request`` frame over the device WebSocket, reads frames
    until ``stream_end``, then checks that:

    * ``enrich_context`` was called once with the user id and message,
    * the enriched context reached the patched stream,
    * ``store_episode`` was called once with the request's session id and
      message.
    """
    enrich_calls: list[tuple] = []
    store_calls: list[tuple] = []
    # ``core_memory`` values seen by the patched stream, one per invocation.
    injected_core_memory: list = []

    class _MockMiddleware:
        """Stand-in for MemoryMiddleware that only records how it is called."""

        def __init__(self, db):
            pass

        async def enrich_context(self, user_id, message):
            enrich_calls.append((user_id, message))
            return {"core_memory": {"tz": "UTC"}}

        async def store_episode(self, user_id, session_id, message, response):
            store_calls.append((user_id, session_id, message, response))

    token = make_jwt("power", user_id=USER_ID)
    session_id = str(uuid.uuid4())

    async def _mock_stream(user_id, message, context, db_session_factory=None):
        # Record instead of asserting here: this generator runs inside the
        # WebSocket handler, so an AssertionError raised at this point would
        # surface in the server code rather than in the test body.
        injected_core_memory.append(context.get("core_memory"))
        yield ("token", "Done")
        yield ("mutations", [])

    with (
        patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
        patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_stream),
    ):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(json.dumps({
                "type": "device_hello", "device_id": "dev-mem", "agent_ids": []
            }))
            ws.send_text(json.dumps({
                "type": "home_request",
                "request_id": "r-mem",
                "session_id": session_id,
                "message": "Show tasks",
            }))
            # Bounded read so a misbehaving server cannot spin the test forever.
            for _ in range(20):
                frame = json.loads(ws.receive_text())
                if frame.get("type") == "stream_end":
                    break
            else:
                # Previously the loop fell through silently in this case.
                pytest.fail("no stream_end frame received within 20 frames")

    assert len(enrich_calls) == 1
    assert enrich_calls[0] == (USER_ID, "Show tasks")

    # Verify memory context was injected into the (patched) LLM stream.
    assert injected_core_memory, "run_home_stream was never consumed"
    assert all(seen == {"tz": "UTC"} for seen in injected_core_memory)

    assert len(store_calls) == 1
    stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
    assert stored_session_id == session_id
    assert stored_message == "Show tasks"
|