feat: add letta-style memory tools with request/user debug tracing
This commit is contained in:
@@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.agents.note_agent import NOTE_TOOLS
|
||||
from app.agents.project_agent import PROJECT_TOOLS
|
||||
@@ -24,17 +25,19 @@ logger = logging.getLogger(__name__)
|
||||
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]
|
||||
|
||||
_HOME_SINGLE_AGENT_SYSTEM = (
|
||||
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines. "
|
||||
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||
"Always use tools for factual data retrieval before answering. "
|
||||
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||
"Return markdown and embed inline tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
||||
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>."
|
||||
)
|
||||
|
||||
_FLOATING_SINGLE_AGENT_SYSTEM = (
|
||||
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines. "
|
||||
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||
"Stay focused on the floating scope in context.scope and answer concisely. "
|
||||
"Always use tools for factual data retrieval before answering. "
|
||||
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||
"Return markdown and embed inline tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
||||
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>."
|
||||
@@ -118,6 +121,158 @@ def _all_tools() -> list[Any]:
|
||||
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
|
||||
|
||||
|
||||
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
||||
debug = context.get("_debug")
|
||||
if isinstance(debug, dict):
|
||||
request_id = debug.get("request_id")
|
||||
if isinstance(request_id, str) and request_id:
|
||||
return request_id
|
||||
return None
|
||||
|
||||
|
||||
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
||||
sanitized = dict(context)
|
||||
sanitized.pop("_debug", None)
|
||||
return sanitized
|
||||
|
||||
|
||||
def _normalize_memory_label(path_or_label: str) -> str:
|
||||
value = path_or_label.strip()
|
||||
if value.startswith("/memories/"):
|
||||
value = value[len("/memories/"):]
|
||||
value = value.strip("/")
|
||||
return value
|
||||
|
||||
|
||||
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||
@tool
|
||||
async def memory_list_blocks() -> str:
|
||||
"""List all core memory blocks currently stored for the user."""
|
||||
logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id)
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
blocks = await memory.list_core_blocks(user_id)
|
||||
if not blocks:
|
||||
return "No memory blocks found."
|
||||
lines = [f"- {b['label']}: {b['value']}" for b in blocks]
|
||||
return "Memory blocks:\n" + "\n".join(lines)
|
||||
|
||||
@tool
|
||||
async def memory_get(path_or_label: str) -> str:
|
||||
"""Get one memory block by label or /memories/<label> path."""
|
||||
label = _normalize_memory_label(path_or_label)
|
||||
logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||
if not label:
|
||||
return "Invalid memory label."
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
value = await memory.get_core_block(user_id, label)
|
||||
if value is None:
|
||||
return f"Memory block '{label}' not found."
|
||||
return f"Memory block '{label}':\n{value}"
|
||||
|
||||
@tool
|
||||
async def memory_create(path_or_label: str, value: str) -> str:
|
||||
"""Create or overwrite a memory block value by label or /memories/<label> path."""
|
||||
label = _normalize_memory_label(path_or_label)
|
||||
logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||
if not label:
|
||||
return "Invalid memory label."
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.update_core(user_id, label, value, trace_id=trace_id)
|
||||
return f"Memory block '{label}' saved."
|
||||
|
||||
@tool
|
||||
async def memory_append(path_or_label: str, content: str) -> str:
|
||||
"""Append content to a memory block, creating it if missing."""
|
||||
label = _normalize_memory_label(path_or_label)
|
||||
logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||
if not label:
|
||||
return "Invalid memory label."
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.append_core(user_id, label, content)
|
||||
return f"Memory block '{label}' appended."
|
||||
|
||||
@tool
|
||||
async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
|
||||
"""Replace one exact string in a memory block."""
|
||||
label = _normalize_memory_label(path_or_label)
|
||||
logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||
if not label:
|
||||
return "Invalid memory label."
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
changed = await memory.replace_core(user_id, label, old_string, new_string)
|
||||
if not changed:
|
||||
return f"No replacement made in '{label}' (old string not found)."
|
||||
return f"Memory block '{label}' updated."
|
||||
|
||||
@tool
|
||||
async def memory_delete(path_or_label: str) -> str:
|
||||
"""Delete a memory block by label or /memories/<label> path."""
|
||||
label = _normalize_memory_label(path_or_label)
|
||||
logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||
if not label:
|
||||
return "Invalid memory label."
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
deleted = await memory.delete_core(user_id, label)
|
||||
if not deleted:
|
||||
return f"Memory block '{label}' not found."
|
||||
return f"Memory block '{label}' deleted."
|
||||
|
||||
@tool
|
||||
async def archival_memory_insert(content: str) -> str:
|
||||
"""Insert a long-term archival memory entry."""
|
||||
logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_id or "-", user_id)
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.insert_archival(user_id, content, source="assistant")
|
||||
return "Archival memory saved."
|
||||
|
||||
@tool
|
||||
async def archival_memory_search(query: str, top_k: int = 5) -> str:
|
||||
"""Search long-term archival memory by semantic fallback (keyword currently)."""
|
||||
logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
results = await memory.search_archival(user_id, query, top_k=top_k)
|
||||
if not results:
|
||||
return "No archival memory results found."
|
||||
lines = [f"- {item}" for item in results]
|
||||
return "Archival memory results:\n" + "\n".join(lines)
|
||||
|
||||
@tool
|
||||
async def conversation_search(query: str, top_k: int = 5) -> str:
|
||||
"""Search recall memory from prior episodic conversation summaries."""
|
||||
logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
results = await memory.search_recall(user_id, query, top_k=top_k)
|
||||
if not results:
|
||||
return "No recall memory results found."
|
||||
lines = [f"- {item}" for item in results]
|
||||
return "Recall memory results:\n" + "\n".join(lines)
|
||||
|
||||
return [
|
||||
memory_list_blocks,
|
||||
memory_get,
|
||||
memory_create,
|
||||
memory_append,
|
||||
memory_replace,
|
||||
memory_delete,
|
||||
archival_memory_insert,
|
||||
archival_memory_search,
|
||||
conversation_search,
|
||||
]
|
||||
|
||||
|
||||
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
|
||||
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
|
||||
|
||||
|
||||
def _infer_floating_domain(message: str, context: dict[str, Any]) -> FloatingDomain:
|
||||
scope = context.get("scope") if isinstance(context, dict) else None
|
||||
if isinstance(scope, dict):
|
||||
@@ -143,20 +298,24 @@ def _infer_floating_domain(message: str, context: dict[str, Any]) -> FloatingDom
|
||||
|
||||
async def _run_single_agent(
|
||||
*,
|
||||
user_id: str,
|
||||
system_prompt: str,
|
||||
message: str,
|
||||
context: dict[str, Any],
|
||||
max_steps: int = 6,
|
||||
) -> str:
|
||||
trace_id = _trace_id_from_context(context)
|
||||
llm = get_llm()
|
||||
tools = _all_tools()
|
||||
tools = _all_tools_for_user(user_id, trace_id)
|
||||
model_context = _context_for_model(context)
|
||||
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
messages: list[Any] = [
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(
|
||||
content=(
|
||||
f"User message:\n{message}\n\n"
|
||||
f"Context:\n{json.dumps({'context': context}, ensure_ascii=True)[:3500]}"
|
||||
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
||||
)
|
||||
),
|
||||
]
|
||||
@@ -206,20 +365,24 @@ async def _run_single_agent(
|
||||
|
||||
async def _run_single_agent_stream(
|
||||
*,
|
||||
user_id: str,
|
||||
system_prompt: str,
|
||||
message: str,
|
||||
context: dict[str, Any],
|
||||
max_steps: int = 6,
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
trace_id = _trace_id_from_context(context)
|
||||
llm = get_llm()
|
||||
tools = _all_tools()
|
||||
tools = _all_tools_for_user(user_id, trace_id)
|
||||
model_context = _context_for_model(context)
|
||||
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
messages: list[Any] = [
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(
|
||||
content=(
|
||||
f"User message:\n{message}\n\n"
|
||||
f"Context:\n{json.dumps({'context': context}, ensure_ascii=True)[:3500]}"
|
||||
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
||||
)
|
||||
),
|
||||
]
|
||||
@@ -276,6 +439,7 @@ async def _run_single_agent_stream(
|
||||
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
||||
prepared_context = await _prepare_context(message, context)
|
||||
return await _run_single_agent(
|
||||
user_id=user_id,
|
||||
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
|
||||
message=message,
|
||||
context=prepared_context,
|
||||
@@ -286,6 +450,7 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
|
||||
domain = _infer_floating_domain(message, context)
|
||||
prepared_context = await _prepare_context(message, context)
|
||||
response = await _run_single_agent(
|
||||
user_id=user_id,
|
||||
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
||||
message=message,
|
||||
context=prepared_context,
|
||||
@@ -300,6 +465,7 @@ async def run_home_stream(
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
prepared_context = await _prepare_context(message, context)
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
|
||||
message=message,
|
||||
context=prepared_context,
|
||||
@@ -317,6 +483,7 @@ async def run_floating_stream(
|
||||
|
||||
prepared_context = await _prepare_context(message, context)
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
||||
message=message,
|
||||
context=prepared_context,
|
||||
|
||||
Reference in New Issue
Block a user