diff --git a/app/core/deep_agent.py b/app/core/deep_agent.py index 8a8bd29..9d8f70d 100644 --- a/app/core/deep_agent.py +++ b/app/core/deep_agent.py @@ -36,6 +36,11 @@ class WorkerTask(BaseModel): instruction: str +class MemoryUpdate(BaseModel): + key: str = Field(description="The memory key to set or update.") + value: str = Field(description="The persistent fact or preference value.") + + class WorkerSummary(BaseModel): summary: str = Field(description="Strictly concise summary of tool findings. Max 3 sentences.") @@ -43,6 +48,7 @@ class WorkerSummary(BaseModel): class WorkerPlan(BaseModel): tasks: list[WorkerTask] = Field(default_factory=list) floating_domain: FloatingDomain | None = None + memory_updates: list[MemoryUpdate] = Field(default_factory=list, description="Update long-term core memory with persistent user preferences/facts learned from this message.") class WorkerResult(TypedDict): @@ -345,37 +351,12 @@ async def _stream_with_memory_tool( system_prompt: str, user_prompt: str, ) -> str: - @tool - async def update_core_memory(key: str, value: str) -> str: - """Save stable user preference/profile data to core memory.""" - async with async_session() as db: - memory = MemoryMiddleware(db) - await memory.update_core(user_id, key, value) - return f"Saved core memory key '{key}'." - llm = get_llm() messages: list[Any] = [ SystemMessage(content=system_prompt), HumanMessage(content=user_prompt), ] - llm_with_tools = llm.bind_tools([update_core_memory]) - - for _ in range(2): - response: AIMessage = await llm_with_tools.ainvoke(messages) - messages.append(response) - - if not response.tool_calls: - break - - for call in response.tool_calls: - if call["name"] != "update_core_memory": - messages.append(ToolMessage(content="Unsupported tool.", tool_call_id=call["id"])) - continue - - tool_output = await update_core_memory.ainvoke(call.get("args", {})) - messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) - chunks: list[str] = [] async for chunk in llm.astream(messages): token = _as_text(getattr(chunk, "content", "")) @@ -402,13 +383,31 @@ def _synthesizer_node(floating: bool): return _node +async def _apply_memory_updates(user_id: str, updates: list[MemoryUpdate], current_memory: dict[str, Any]) -> dict[str, Any]: + if not updates: + return current_memory + + new_memory = dict(current_memory) + async with async_session() as db: + memory = MemoryMiddleware(db) + for update in updates: + await memory.update_core(user_id, update.key, update.value) + new_memory[update.key] = update.value + return new_memory + async def _orchestrator_node_home(state: GraphState) -> GraphState: if state.get("plan"): return {} context = {**state.get("context", {}), **state.get("memory_context", {})} plan = await _plan_with_llm(str(state.get("user_message", "")), context, floating=False) - return {"plan": [task.model_dump() for task in plan.tasks]} + + new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {})) + + return { + "plan": [task.model_dump() for task in plan.tasks], + "memory_context": new_memory + } async def _orchestrator_node_floating(state: GraphState) -> GraphState: @@ -421,9 +420,12 @@ async def _orchestrator_node_floating(state: GraphState) -> GraphState: if floating_domain is None and plan.tasks: floating_domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"] + new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {})) + return { "plan": [task.model_dump() for task in plan.tasks], "floating_domain": floating_domain or "tasks", + "memory_context": new_memory } @@ -482,13 +484,14 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str: async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]: plan = await _plan_with_llm(message, context, floating=True) domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"] + new_memory = await _apply_memory_updates(user_id, plan.memory_updates, context) state = await FLOATING_GRAPH.ainvoke( { "user_id": user_id, "user_message": message, "context": context, - "memory_context": context, + "memory_context": new_memory, "plan": [task.model_dump() for task in plan.tasks], "floating_domain": domain, "worker_results": [], @@ -531,11 +534,13 @@ async def run_floating_stream( domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"] yield "floating_domain", domain + new_memory = await _apply_memory_updates(user_id, plan.memory_updates, context) + state_input = { "user_id": user_id, "user_message": message, "context": context, - "memory_context": context, + "memory_context": new_memory, "plan": [t.model_dump() for t in plan.tasks], "floating_domain": domain, "worker_results": [],