feat: add letta-style memory tools with request/user debug tracing
This commit is contained in:
@@ -223,10 +223,11 @@ async def _handle_home_request(
|
|||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(user_id, message)
|
memory_context = await memory.enrich_context(user_id, message, trace_id=request_id)
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {
|
||||||
"conversation_history": frame.get("conversation_history", []),
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
**memory_context,
|
**memory_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -253,7 +254,7 @@ async def _handle_home_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks)
|
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -271,9 +272,13 @@ async def _handle_floating_request(
|
|||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(user_id, message)
|
memory_context = await memory.enrich_context(user_id, message, trace_id=request_id)
|
||||||
|
|
||||||
context: dict = {"scope": scope, **memory_context}
|
context: dict = {
|
||||||
|
"scope": scope,
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
@@ -297,7 +302,7 @@ async def _handle_floating_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks)
|
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator
|
|||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.agents.note_agent import NOTE_TOOLS
|
from app.agents.note_agent import NOTE_TOOLS
|
||||||
from app.agents.project_agent import PROJECT_TOOLS
|
from app.agents.project_agent import PROJECT_TOOLS
|
||||||
@@ -24,17 +25,19 @@ logger = logging.getLogger(__name__)
|
|||||||
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]
|
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]
|
||||||
|
|
||||||
_HOME_SINGLE_AGENT_SYSTEM = (
|
_HOME_SINGLE_AGENT_SYSTEM = (
|
||||||
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines. "
|
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
"Always use tools for factual data retrieval before answering. "
|
"Always use tools for factual data retrieval before answering. "
|
||||||
|
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||||
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||||
"Return markdown and embed inline tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
"Return markdown and embed inline tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
||||||
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>."
|
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>."
|
||||||
)
|
)
|
||||||
|
|
||||||
_FLOATING_SINGLE_AGENT_SYSTEM = (
|
_FLOATING_SINGLE_AGENT_SYSTEM = (
|
||||||
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines. "
|
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
"Stay focused on the floating scope in context.scope and answer concisely. "
|
"Stay focused on the floating scope in context.scope and answer concisely. "
|
||||||
"Always use tools for factual data retrieval before answering. "
|
"Always use tools for factual data retrieval before answering. "
|
||||||
|
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||||
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||||
"Return markdown and embed inline tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
"Return markdown and embed inline tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
||||||
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>."
|
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>."
|
||||||
@@ -118,6 +121,158 @@ def _all_tools() -> list[Any]:
|
|||||||
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
|
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
|
||||||
|
|
||||||
|
|
||||||
|
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
||||||
|
debug = context.get("_debug")
|
||||||
|
if isinstance(debug, dict):
|
||||||
|
request_id = debug.get("request_id")
|
||||||
|
if isinstance(request_id, str) and request_id:
|
||||||
|
return request_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
sanitized = dict(context)
|
||||||
|
sanitized.pop("_debug", None)
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_memory_label(path_or_label: str) -> str:
|
||||||
|
value = path_or_label.strip()
|
||||||
|
if value.startswith("/memories/"):
|
||||||
|
value = value[len("/memories/"):]
|
||||||
|
value = value.strip("/")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
@tool
|
||||||
|
async def memory_list_blocks() -> str:
|
||||||
|
"""List all core memory blocks currently stored for the user."""
|
||||||
|
logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
blocks = await memory.list_core_blocks(user_id)
|
||||||
|
if not blocks:
|
||||||
|
return "No memory blocks found."
|
||||||
|
lines = [f"- {b['label']}: {b['value']}" for b in blocks]
|
||||||
|
return "Memory blocks:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_get(path_or_label: str) -> str:
|
||||||
|
"""Get one memory block by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
value = await memory.get_core_block(user_id, label)
|
||||||
|
if value is None:
|
||||||
|
return f"Memory block '{label}' not found."
|
||||||
|
return f"Memory block '{label}':\n{value}"
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_create(path_or_label: str, value: str) -> str:
|
||||||
|
"""Create or overwrite a memory block value by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.update_core(user_id, label, value, trace_id=trace_id)
|
||||||
|
return f"Memory block '{label}' saved."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_append(path_or_label: str, content: str) -> str:
|
||||||
|
"""Append content to a memory block, creating it if missing."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.append_core(user_id, label, content)
|
||||||
|
return f"Memory block '{label}' appended."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
|
||||||
|
"""Replace one exact string in a memory block."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
changed = await memory.replace_core(user_id, label, old_string, new_string)
|
||||||
|
if not changed:
|
||||||
|
return f"No replacement made in '{label}' (old string not found)."
|
||||||
|
return f"Memory block '{label}' updated."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def memory_delete(path_or_label: str) -> str:
|
||||||
|
"""Delete a memory block by label or /memories/<label> path."""
|
||||||
|
label = _normalize_memory_label(path_or_label)
|
||||||
|
logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||||
|
if not label:
|
||||||
|
return "Invalid memory label."
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
deleted = await memory.delete_core(user_id, label)
|
||||||
|
if not deleted:
|
||||||
|
return f"Memory block '{label}' not found."
|
||||||
|
return f"Memory block '{label}' deleted."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def archival_memory_insert(content: str) -> str:
|
||||||
|
"""Insert a long-term archival memory entry."""
|
||||||
|
logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_id or "-", user_id)
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.insert_archival(user_id, content, source="assistant")
|
||||||
|
return "Archival memory saved."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def archival_memory_search(query: str, top_k: int = 5) -> str:
|
||||||
|
"""Search long-term archival memory by semantic fallback (keyword currently)."""
|
||||||
|
logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
results = await memory.search_archival(user_id, query, top_k=top_k)
|
||||||
|
if not results:
|
||||||
|
return "No archival memory results found."
|
||||||
|
lines = [f"- {item}" for item in results]
|
||||||
|
return "Archival memory results:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def conversation_search(query: str, top_k: int = 5) -> str:
|
||||||
|
"""Search recall memory from prior episodic conversation summaries."""
|
||||||
|
logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
results = await memory.search_recall(user_id, query, top_k=top_k)
|
||||||
|
if not results:
|
||||||
|
return "No recall memory results found."
|
||||||
|
lines = [f"- {item}" for item in results]
|
||||||
|
return "Recall memory results:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
return [
|
||||||
|
memory_list_blocks,
|
||||||
|
memory_get,
|
||||||
|
memory_create,
|
||||||
|
memory_append,
|
||||||
|
memory_replace,
|
||||||
|
memory_delete,
|
||||||
|
archival_memory_insert,
|
||||||
|
archival_memory_search,
|
||||||
|
conversation_search,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
|
||||||
|
|
||||||
|
|
||||||
def _infer_floating_domain(message: str, context: dict[str, Any]) -> FloatingDomain:
|
def _infer_floating_domain(message: str, context: dict[str, Any]) -> FloatingDomain:
|
||||||
scope = context.get("scope") if isinstance(context, dict) else None
|
scope = context.get("scope") if isinstance(context, dict) else None
|
||||||
if isinstance(scope, dict):
|
if isinstance(scope, dict):
|
||||||
@@ -143,20 +298,24 @@ def _infer_floating_domain(message: str, context: dict[str, Any]) -> FloatingDom
|
|||||||
|
|
||||||
async def _run_single_agent(
|
async def _run_single_agent(
|
||||||
*,
|
*,
|
||||||
|
user_id: str,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
message: str,
|
message: str,
|
||||||
context: dict[str, Any],
|
context: dict[str, Any],
|
||||||
max_steps: int = 6,
|
max_steps: int = 6,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
trace_id = _trace_id_from_context(context)
|
||||||
llm = get_llm()
|
llm = get_llm()
|
||||||
tools = _all_tools()
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
|
model_context = _context_for_model(context)
|
||||||
|
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
messages: list[Any] = [
|
messages: list[Any] = [
|
||||||
SystemMessage(content=system_prompt),
|
SystemMessage(content=system_prompt),
|
||||||
HumanMessage(
|
HumanMessage(
|
||||||
content=(
|
content=(
|
||||||
f"User message:\n{message}\n\n"
|
f"User message:\n{message}\n\n"
|
||||||
f"Context:\n{json.dumps({'context': context}, ensure_ascii=True)[:3500]}"
|
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
@@ -206,20 +365,24 @@ async def _run_single_agent(
|
|||||||
|
|
||||||
async def _run_single_agent_stream(
|
async def _run_single_agent_stream(
|
||||||
*,
|
*,
|
||||||
|
user_id: str,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
message: str,
|
message: str,
|
||||||
context: dict[str, Any],
|
context: dict[str, Any],
|
||||||
max_steps: int = 6,
|
max_steps: int = 6,
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
trace_id = _trace_id_from_context(context)
|
||||||
llm = get_llm()
|
llm = get_llm()
|
||||||
tools = _all_tools()
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
|
model_context = _context_for_model(context)
|
||||||
|
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
messages: list[Any] = [
|
messages: list[Any] = [
|
||||||
SystemMessage(content=system_prompt),
|
SystemMessage(content=system_prompt),
|
||||||
HumanMessage(
|
HumanMessage(
|
||||||
content=(
|
content=(
|
||||||
f"User message:\n{message}\n\n"
|
f"User message:\n{message}\n\n"
|
||||||
f"Context:\n{json.dumps({'context': context}, ensure_ascii=True)[:3500]}"
|
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
@@ -276,6 +439,7 @@ async def _run_single_agent_stream(
|
|||||||
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
||||||
prepared_context = await _prepare_context(message, context)
|
prepared_context = await _prepare_context(message, context)
|
||||||
return await _run_single_agent(
|
return await _run_single_agent(
|
||||||
|
user_id=user_id,
|
||||||
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
|
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
|
||||||
message=message,
|
message=message,
|
||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
@@ -286,6 +450,7 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
|
|||||||
domain = _infer_floating_domain(message, context)
|
domain = _infer_floating_domain(message, context)
|
||||||
prepared_context = await _prepare_context(message, context)
|
prepared_context = await _prepare_context(message, context)
|
||||||
response = await _run_single_agent(
|
response = await _run_single_agent(
|
||||||
|
user_id=user_id,
|
||||||
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
||||||
message=message,
|
message=message,
|
||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
@@ -300,6 +465,7 @@ async def run_home_stream(
|
|||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
prepared_context = await _prepare_context(message, context)
|
prepared_context = await _prepare_context(message, context)
|
||||||
async for event in _run_single_agent_stream(
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
|
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
|
||||||
message=message,
|
message=message,
|
||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
@@ -317,6 +483,7 @@ async def run_floating_stream(
|
|||||||
|
|
||||||
prepared_context = await _prepare_context(message, context)
|
prepared_context = await _prepare_context(message, context)
|
||||||
async for event in _run_single_agent_stream(
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
||||||
message=message,
|
message=message,
|
||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class MemoryMiddleware:
|
|||||||
|
|
||||||
# ── Public API ────────────────────────────────────────────────────────────
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
|
async def enrich_context(self, user_id: str, message: str, trace_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Build memory context dict to inject into the orchestrator before LLM call.
|
"""Build memory context dict to inject into the orchestrator before LLM call.
|
||||||
|
|
||||||
Returns a dict with keys:
|
Returns a dict with keys:
|
||||||
@@ -68,6 +68,19 @@ class MemoryMiddleware:
|
|||||||
episodic = await self._load_episodic(user_id, fernet)
|
episodic = await self._load_episodic(user_id, fernet)
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: enrich_context trace=%s user=%s email=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("email") or "-",
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
len(core),
|
||||||
|
len(associative),
|
||||||
|
len(episodic),
|
||||||
|
len(proactive),
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"core_memory": core,
|
"core_memory": core,
|
||||||
"associative_memory": associative,
|
"associative_memory": associative,
|
||||||
@@ -81,6 +94,7 @@ class MemoryMiddleware:
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
response: str,
|
response: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Summarise and store a completed interaction in episodic memory.
|
"""Summarise and store a completed interaction in episodic memory.
|
||||||
|
|
||||||
@@ -103,11 +117,20 @@ class MemoryMiddleware:
|
|||||||
self._db.add(row)
|
self._db.add(row)
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: store_episode trace=%s user=%s email=%s tier=%s session=%s",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("email") or "-",
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
async def update_core(self, user_id: str, key: str, value: str) -> None:
|
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||||
"""Upsert a core memory key/value for a user."""
|
"""Upsert a core memory key/value for a user."""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
@@ -133,10 +156,177 @@ class MemoryMiddleware:
|
|||||||
))
|
))
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: update_core trace=%s user=%s email=%s tier=%s key=%s",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("email") or "-",
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
key,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
||||||
|
"""Return core memory as editable blocks (label/value)."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore)
|
||||||
|
.where(MemoryCore.user_id == user_id)
|
||||||
|
.order_by(MemoryCore.key.asc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[dict[str, str]] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append({"label": row.key, "value": plaintext})
|
||||||
|
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
||||||
|
"""Return a single core memory block value by label."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == label,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
|
||||||
|
return None
|
||||||
|
value = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def delete_core(self, user_id: str, label: str) -> bool:
|
||||||
|
"""Delete a core memory block by label. Returns True if deleted."""
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == label,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
|
||||||
|
return False
|
||||||
|
|
||||||
|
await self._db.delete(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info("memory: delete_core user=%s label=%s", user_id, label)
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
||||||
|
"""Append content to a core block, creating it if missing."""
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None:
|
||||||
|
await self.update_core(user_id, label, content)
|
||||||
|
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
|
||||||
|
return
|
||||||
|
await self.update_core(user_id, label, f"{current}\n{content}")
|
||||||
|
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
|
||||||
|
|
||||||
|
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
||||||
|
"""Replace one exact string inside a core block. Returns False if not found."""
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None or old not in current:
|
||||||
|
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
|
||||||
|
return False
|
||||||
|
await self.update_core(user_id, label, current.replace(old, new, 1))
|
||||||
|
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||||
|
"""Insert a long-term archival memory entry."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, content)
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
content_encrypted=encrypted,
|
||||||
|
embedding=None,
|
||||||
|
entity_type=source,
|
||||||
|
entity_id=None,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc())
|
||||||
|
.limit(100)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
"""Search recall memory (episodic summaries) by keyword."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryEpisodic)
|
||||||
|
.where(MemoryEpisodic.user_id == user_id)
|
||||||
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
|
.limit(100)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
# ── Private helpers ───────────────────────────────────────────────────────
|
# ── Private helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
@@ -148,6 +338,19 @@ class MemoryMiddleware:
|
|||||||
return None
|
return None
|
||||||
return Fernet(user.encryption_key.encode())
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
|
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
||||||
|
"""Load lightweight user debug fields for trace logs."""
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
return {"email": None, "tier": None, "name": None, "surname": None}
|
||||||
|
return {
|
||||||
|
"email": user.email,
|
||||||
|
"tier": user.tier,
|
||||||
|
"name": user.name,
|
||||||
|
"surname": user.surname,
|
||||||
|
}
|
||||||
|
|
||||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
|
|||||||
@@ -229,6 +229,40 @@ async def test_update_core_upsert(db_session, user_with_key):
|
|||||||
assert _dec(rows[0].value_encrypted) == "fr"
|
assert _dec(rows[0].value_encrypted) == "fr"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_core_block_edit_ops(db_session, user_with_key):
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
|
||||||
|
await middleware.update_core(USER_ID, "human", "Name: Roberto")
|
||||||
|
await middleware.append_core(USER_ID, "human", "Timezone: Europe/Rome")
|
||||||
|
replaced = await middleware.replace_core(USER_ID, "human", "Roberto", "Robert")
|
||||||
|
|
||||||
|
blocks = await middleware.list_core_blocks(USER_ID)
|
||||||
|
human = next(b for b in blocks if b["label"] == "human")
|
||||||
|
|
||||||
|
assert replaced is True
|
||||||
|
assert "Name: Robert" in human["value"]
|
||||||
|
assert "Timezone: Europe/Rome" in human["value"]
|
||||||
|
|
||||||
|
deleted = await middleware.delete_core(USER_ID, "human")
|
||||||
|
assert deleted is True
|
||||||
|
assert await middleware.get_core_block(USER_ID, "human") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_archival_and_recall_search_helpers(db_session, user_with_key):
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
|
||||||
|
await middleware.insert_archival(USER_ID, "Project whitelist has release risk", source="assistant")
|
||||||
|
await middleware.store_episode(USER_ID, str(uuid.uuid4()), "How is whitelist?", "Whitelist is delayed")
|
||||||
|
|
||||||
|
arch = await middleware.search_archival(USER_ID, "whitelist", top_k=3)
|
||||||
|
rec = await middleware.search_recall(USER_ID, "delayed", top_k=3)
|
||||||
|
|
||||||
|
assert any("whitelist" in item.lower() for item in arch)
|
||||||
|
assert any("delayed" in item.lower() for item in rec)
|
||||||
|
|
||||||
|
|
||||||
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
||||||
|
|
||||||
def test_home_request_calls_memory_middleware(client):
|
def test_home_request_calls_memory_middleware(client):
|
||||||
|
|||||||
Reference in New Issue
Block a user