refactor: move memory updates from synthesizer to orchestrator node
This commit is contained in:
@@ -36,6 +36,11 @@ class WorkerTask(BaseModel):
|
|||||||
instruction: str
|
instruction: str
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryUpdate(BaseModel):
|
||||||
|
key: str = Field(description="The memory key to set or update.")
|
||||||
|
value: str = Field(description="The persistent fact or preference value.")
|
||||||
|
|
||||||
|
|
||||||
class WorkerSummary(BaseModel):
|
class WorkerSummary(BaseModel):
|
||||||
summary: str = Field(description="Strictly concise summary of tool findings. Max 3 sentences.")
|
summary: str = Field(description="Strictly concise summary of tool findings. Max 3 sentences.")
|
||||||
|
|
||||||
@@ -43,6 +48,7 @@ class WorkerSummary(BaseModel):
|
|||||||
class WorkerPlan(BaseModel):
|
class WorkerPlan(BaseModel):
|
||||||
tasks: list[WorkerTask] = Field(default_factory=list)
|
tasks: list[WorkerTask] = Field(default_factory=list)
|
||||||
floating_domain: FloatingDomain | None = None
|
floating_domain: FloatingDomain | None = None
|
||||||
|
memory_updates: list[MemoryUpdate] = Field(default_factory=list, description="Update long-term core memory with persistent user preferences/facts learned from this message.")
|
||||||
|
|
||||||
|
|
||||||
class WorkerResult(TypedDict):
|
class WorkerResult(TypedDict):
|
||||||
@@ -345,37 +351,12 @@ async def _stream_with_memory_tool(
|
|||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
user_prompt: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
@tool
|
|
||||||
async def update_core_memory(key: str, value: str) -> str:
|
|
||||||
"""Save stable user preference/profile data to core memory."""
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
await memory.update_core(user_id, key, value)
|
|
||||||
return f"Saved core memory key '{key}'."
|
|
||||||
|
|
||||||
llm = get_llm()
|
llm = get_llm()
|
||||||
messages: list[Any] = [
|
messages: list[Any] = [
|
||||||
SystemMessage(content=system_prompt),
|
SystemMessage(content=system_prompt),
|
||||||
HumanMessage(content=user_prompt),
|
HumanMessage(content=user_prompt),
|
||||||
]
|
]
|
||||||
|
|
||||||
llm_with_tools = llm.bind_tools([update_core_memory])
|
|
||||||
|
|
||||||
for _ in range(2):
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
break
|
|
||||||
|
|
||||||
for call in response.tool_calls:
|
|
||||||
if call["name"] != "update_core_memory":
|
|
||||||
messages.append(ToolMessage(content="Unsupported tool.", tool_call_id=call["id"]))
|
|
||||||
continue
|
|
||||||
|
|
||||||
tool_output = await update_core_memory.ainvoke(call.get("args", {}))
|
|
||||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
|
||||||
|
|
||||||
chunks: list[str] = []
|
chunks: list[str] = []
|
||||||
async for chunk in llm.astream(messages):
|
async for chunk in llm.astream(messages):
|
||||||
token = _as_text(getattr(chunk, "content", ""))
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
@@ -402,13 +383,31 @@ def _synthesizer_node(floating: bool):
|
|||||||
return _node
|
return _node
|
||||||
|
|
||||||
|
|
||||||
|
async def _apply_memory_updates(user_id: str, updates: list[MemoryUpdate], current_memory: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
if not updates:
|
||||||
|
return current_memory
|
||||||
|
|
||||||
|
new_memory = dict(current_memory)
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
for update in updates:
|
||||||
|
await memory.update_core(user_id, update.key, update.value)
|
||||||
|
new_memory[update.key] = update.value
|
||||||
|
return new_memory
|
||||||
|
|
||||||
async def _orchestrator_node_home(state: GraphState) -> GraphState:
|
async def _orchestrator_node_home(state: GraphState) -> GraphState:
|
||||||
if state.get("plan"):
|
if state.get("plan"):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
context = {**state.get("context", {}), **state.get("memory_context", {})}
|
context = {**state.get("context", {}), **state.get("memory_context", {})}
|
||||||
plan = await _plan_with_llm(str(state.get("user_message", "")), context, floating=False)
|
plan = await _plan_with_llm(str(state.get("user_message", "")), context, floating=False)
|
||||||
return {"plan": [task.model_dump() for task in plan.tasks]}
|
|
||||||
|
new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {}))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"plan": [task.model_dump() for task in plan.tasks],
|
||||||
|
"memory_context": new_memory
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def _orchestrator_node_floating(state: GraphState) -> GraphState:
|
async def _orchestrator_node_floating(state: GraphState) -> GraphState:
|
||||||
@@ -421,9 +420,12 @@ async def _orchestrator_node_floating(state: GraphState) -> GraphState:
|
|||||||
if floating_domain is None and plan.tasks:
|
if floating_domain is None and plan.tasks:
|
||||||
floating_domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
|
floating_domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
|
||||||
|
|
||||||
|
new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {}))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"plan": [task.model_dump() for task in plan.tasks],
|
"plan": [task.model_dump() for task in plan.tasks],
|
||||||
"floating_domain": floating_domain or "tasks",
|
"floating_domain": floating_domain or "tasks",
|
||||||
|
"memory_context": new_memory
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -482,13 +484,14 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
|||||||
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
|
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
|
||||||
plan = await _plan_with_llm(message, context, floating=True)
|
plan = await _plan_with_llm(message, context, floating=True)
|
||||||
domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
|
domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
|
||||||
|
new_memory = await _apply_memory_updates(user_id, plan.memory_updates, context)
|
||||||
|
|
||||||
state = await FLOATING_GRAPH.ainvoke(
|
state = await FLOATING_GRAPH.ainvoke(
|
||||||
{
|
{
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"user_message": message,
|
"user_message": message,
|
||||||
"context": context,
|
"context": context,
|
||||||
"memory_context": context,
|
"memory_context": new_memory,
|
||||||
"plan": [task.model_dump() for task in plan.tasks],
|
"plan": [task.model_dump() for task in plan.tasks],
|
||||||
"floating_domain": domain,
|
"floating_domain": domain,
|
||||||
"worker_results": [],
|
"worker_results": [],
|
||||||
@@ -531,11 +534,13 @@ async def run_floating_stream(
|
|||||||
domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
|
domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
|
||||||
yield "floating_domain", domain
|
yield "floating_domain", domain
|
||||||
|
|
||||||
|
new_memory = await _apply_memory_updates(user_id, plan.memory_updates, context)
|
||||||
|
|
||||||
state_input = {
|
state_input = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"user_message": message,
|
"user_message": message,
|
||||||
"context": context,
|
"context": context,
|
||||||
"memory_context": context,
|
"memory_context": new_memory,
|
||||||
"plan": [t.model_dump() for t in plan.tasks],
|
"plan": [t.model_dump() for t in plan.tasks],
|
||||||
"floating_domain": domain,
|
"floating_domain": domain,
|
||||||
"worker_results": [],
|
"worker_results": [],
|
||||||
|
|||||||
Reference in New Issue
Block a user