refactor: move memory updates from synthesizer to orchestrator node

This commit is contained in:
2026-03-12 23:03:38 +01:00
parent d667e43c73
commit f7404b6f66

View File

@@ -36,6 +36,11 @@ class WorkerTask(BaseModel):
instruction: str instruction: str
class MemoryUpdate(BaseModel):
key: str = Field(description="The memory key to set or update.")
value: str = Field(description="The persistent fact or preference value.")
class WorkerSummary(BaseModel): class WorkerSummary(BaseModel):
summary: str = Field(description="Strictly concise summary of tool findings. Max 3 sentences.") summary: str = Field(description="Strictly concise summary of tool findings. Max 3 sentences.")
@@ -43,6 +48,7 @@ class WorkerSummary(BaseModel):
class WorkerPlan(BaseModel): class WorkerPlan(BaseModel):
tasks: list[WorkerTask] = Field(default_factory=list) tasks: list[WorkerTask] = Field(default_factory=list)
floating_domain: FloatingDomain | None = None floating_domain: FloatingDomain | None = None
memory_updates: list[MemoryUpdate] = Field(default_factory=list, description="Update long-term core memory with persistent user preferences/facts learned from this message.")
class WorkerResult(TypedDict): class WorkerResult(TypedDict):
@@ -345,37 +351,12 @@ async def _stream_with_memory_tool(
system_prompt: str, system_prompt: str,
user_prompt: str, user_prompt: str,
) -> str: ) -> str:
@tool
async def update_core_memory(key: str, value: str) -> str:
"""Save stable user preference/profile data to core memory."""
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.update_core(user_id, key, value)
return f"Saved core memory key '{key}'."
llm = get_llm() llm = get_llm()
messages: list[Any] = [ messages: list[Any] = [
SystemMessage(content=system_prompt), SystemMessage(content=system_prompt),
HumanMessage(content=user_prompt), HumanMessage(content=user_prompt),
] ]
llm_with_tools = llm.bind_tools([update_core_memory])
for _ in range(2):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
break
for call in response.tool_calls:
if call["name"] != "update_core_memory":
messages.append(ToolMessage(content="Unsupported tool.", tool_call_id=call["id"]))
continue
tool_output = await update_core_memory.ainvoke(call.get("args", {}))
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
chunks: list[str] = [] chunks: list[str] = []
async for chunk in llm.astream(messages): async for chunk in llm.astream(messages):
token = _as_text(getattr(chunk, "content", "")) token = _as_text(getattr(chunk, "content", ""))
@@ -402,13 +383,31 @@ def _synthesizer_node(floating: bool):
return _node return _node
async def _apply_memory_updates(user_id: str, updates: list[MemoryUpdate], current_memory: dict[str, Any]) -> dict[str, Any]:
if not updates:
return current_memory
new_memory = dict(current_memory)
async with async_session() as db:
memory = MemoryMiddleware(db)
for update in updates:
await memory.update_core(user_id, update.key, update.value)
new_memory[update.key] = update.value
return new_memory
async def _orchestrator_node_home(state: GraphState) -> GraphState: async def _orchestrator_node_home(state: GraphState) -> GraphState:
if state.get("plan"): if state.get("plan"):
return {} return {}
context = {**state.get("context", {}), **state.get("memory_context", {})} context = {**state.get("context", {}), **state.get("memory_context", {})}
plan = await _plan_with_llm(str(state.get("user_message", "")), context, floating=False) plan = await _plan_with_llm(str(state.get("user_message", "")), context, floating=False)
return {"plan": [task.model_dump() for task in plan.tasks]}
new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {}))
return {
"plan": [task.model_dump() for task in plan.tasks],
"memory_context": new_memory
}
async def _orchestrator_node_floating(state: GraphState) -> GraphState: async def _orchestrator_node_floating(state: GraphState) -> GraphState:
@@ -421,9 +420,12 @@ async def _orchestrator_node_floating(state: GraphState) -> GraphState:
if floating_domain is None and plan.tasks: if floating_domain is None and plan.tasks:
floating_domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"] floating_domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {}))
return { return {
"plan": [task.model_dump() for task in plan.tasks], "plan": [task.model_dump() for task in plan.tasks],
"floating_domain": floating_domain or "tasks", "floating_domain": floating_domain or "tasks",
"memory_context": new_memory
} }
@@ -482,13 +484,14 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]: async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
plan = await _plan_with_llm(message, context, floating=True) plan = await _plan_with_llm(message, context, floating=True)
domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"] domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
new_memory = await _apply_memory_updates(user_id, plan.memory_updates, context)
state = await FLOATING_GRAPH.ainvoke( state = await FLOATING_GRAPH.ainvoke(
{ {
"user_id": user_id, "user_id": user_id,
"user_message": message, "user_message": message,
"context": context, "context": context,
"memory_context": context, "memory_context": new_memory,
"plan": [task.model_dump() for task in plan.tasks], "plan": [task.model_dump() for task in plan.tasks],
"floating_domain": domain, "floating_domain": domain,
"worker_results": [], "worker_results": [],
@@ -531,11 +534,13 @@ async def run_floating_stream(
domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"] domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
yield "floating_domain", domain yield "floating_domain", domain
new_memory = await _apply_memory_updates(user_id, plan.memory_updates, context)
state_input = { state_input = {
"user_id": user_id, "user_id": user_id,
"user_message": message, "user_message": message,
"context": context, "context": context,
"memory_context": context, "memory_context": new_memory,
"plan": [t.model_dump() for t in plan.tasks], "plan": [t.model_dump() for t in plan.tasks],
"floating_domain": domain, "floating_domain": domain,
"worker_results": [], "worker_results": [],