refactor: move memory updates from synthesizer to orchestrator node
This commit is contained in:
@@ -36,6 +36,11 @@ class WorkerTask(BaseModel):
|
||||
instruction: str
|
||||
|
||||
|
||||
class MemoryUpdate(BaseModel):
|
||||
key: str = Field(description="The memory key to set or update.")
|
||||
value: str = Field(description="The persistent fact or preference value.")
|
||||
|
||||
|
||||
class WorkerSummary(BaseModel):
|
||||
summary: str = Field(description="Strictly concise summary of tool findings. Max 3 sentences.")
|
||||
|
||||
@@ -43,6 +48,7 @@ class WorkerSummary(BaseModel):
|
||||
class WorkerPlan(BaseModel):
|
||||
tasks: list[WorkerTask] = Field(default_factory=list)
|
||||
floating_domain: FloatingDomain | None = None
|
||||
memory_updates: list[MemoryUpdate] = Field(default_factory=list, description="Update long-term core memory with persistent user preferences/facts learned from this message.")
|
||||
|
||||
|
||||
class WorkerResult(TypedDict):
|
||||
@@ -345,37 +351,12 @@ async def _stream_with_memory_tool(
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
) -> str:
|
||||
@tool
|
||||
async def update_core_memory(key: str, value: str) -> str:
|
||||
"""Save stable user preference/profile data to core memory."""
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.update_core(user_id, key, value)
|
||||
return f"Saved core memory key '{key}'."
|
||||
|
||||
llm = get_llm()
|
||||
messages: list[Any] = [
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content=user_prompt),
|
||||
]
|
||||
|
||||
llm_with_tools = llm.bind_tools([update_core_memory])
|
||||
|
||||
for _ in range(2):
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
messages.append(response)
|
||||
|
||||
if not response.tool_calls:
|
||||
break
|
||||
|
||||
for call in response.tool_calls:
|
||||
if call["name"] != "update_core_memory":
|
||||
messages.append(ToolMessage(content="Unsupported tool.", tool_call_id=call["id"]))
|
||||
continue
|
||||
|
||||
tool_output = await update_core_memory.ainvoke(call.get("args", {}))
|
||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||
|
||||
chunks: list[str] = []
|
||||
async for chunk in llm.astream(messages):
|
||||
token = _as_text(getattr(chunk, "content", ""))
|
||||
@@ -402,13 +383,31 @@ def _synthesizer_node(floating: bool):
|
||||
return _node
|
||||
|
||||
|
||||
async def _apply_memory_updates(user_id: str, updates: list[MemoryUpdate], current_memory: dict[str, Any]) -> dict[str, Any]:
|
||||
if not updates:
|
||||
return current_memory
|
||||
|
||||
new_memory = dict(current_memory)
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
for update in updates:
|
||||
await memory.update_core(user_id, update.key, update.value)
|
||||
new_memory[update.key] = update.value
|
||||
return new_memory
|
||||
|
||||
async def _orchestrator_node_home(state: GraphState) -> GraphState:
|
||||
if state.get("plan"):
|
||||
return {}
|
||||
|
||||
context = {**state.get("context", {}), **state.get("memory_context", {})}
|
||||
plan = await _plan_with_llm(str(state.get("user_message", "")), context, floating=False)
|
||||
return {"plan": [task.model_dump() for task in plan.tasks]}
|
||||
|
||||
new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {}))
|
||||
|
||||
return {
|
||||
"plan": [task.model_dump() for task in plan.tasks],
|
||||
"memory_context": new_memory
|
||||
}
|
||||
|
||||
|
||||
async def _orchestrator_node_floating(state: GraphState) -> GraphState:
|
||||
@@ -421,9 +420,12 @@ async def _orchestrator_node_floating(state: GraphState) -> GraphState:
|
||||
if floating_domain is None and plan.tasks:
|
||||
floating_domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
|
||||
|
||||
new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {}))
|
||||
|
||||
return {
|
||||
"plan": [task.model_dump() for task in plan.tasks],
|
||||
"floating_domain": floating_domain or "tasks",
|
||||
"memory_context": new_memory
|
||||
}
|
||||
|
||||
|
||||
@@ -482,13 +484,14 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
||||
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
|
||||
plan = await _plan_with_llm(message, context, floating=True)
|
||||
domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
|
||||
new_memory = await _apply_memory_updates(user_id, plan.memory_updates, context)
|
||||
|
||||
state = await FLOATING_GRAPH.ainvoke(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"user_message": message,
|
||||
"context": context,
|
||||
"memory_context": context,
|
||||
"memory_context": new_memory,
|
||||
"plan": [task.model_dump() for task in plan.tasks],
|
||||
"floating_domain": domain,
|
||||
"worker_results": [],
|
||||
@@ -531,11 +534,13 @@ async def run_floating_stream(
|
||||
domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
|
||||
yield "floating_domain", domain
|
||||
|
||||
new_memory = await _apply_memory_updates(user_id, plan.memory_updates, context)
|
||||
|
||||
state_input = {
|
||||
"user_id": user_id,
|
||||
"user_message": message,
|
||||
"context": context,
|
||||
"memory_context": context,
|
||||
"memory_context": new_memory,
|
||||
"plan": [t.model_dump() for t in plan.tasks],
|
||||
"floating_domain": domain,
|
||||
"worker_results": [],
|
||||
|
||||
Reference in New Issue
Block a user