diff --git a/app/core/deep_agent.py b/app/core/deep_agent.py index d388ca4..8a8bd29 100644 --- a/app/core/deep_agent.py +++ b/app/core/deep_agent.py @@ -36,6 +36,10 @@ class WorkerTask(BaseModel): instruction: str +class WorkerSummary(BaseModel): + summary: str = Field(description="Strictly concise summary of tool findings. Max 3 sentences.") + + class WorkerPlan(BaseModel): tasks: list[WorkerTask] = Field(default_factory=list) floating_domain: FloatingDomain | None = None @@ -58,7 +62,6 @@ class OrchestratorState(TypedDict, total=False): task: dict[str, Any] worker_results: list[WorkerResult] final_response: str - stream_callback: Callable[[str], Awaitable[None]] | None class GraphState(OrchestratorState): @@ -276,8 +279,13 @@ async def _run_tool_loop( tool_output = await tool_fn.ainvoke(call.get("args", {})) messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) - final = await llm.ainvoke(messages) - return _as_text(final.content), collected + structured_llm = llm.with_structured_output(WorkerSummary) + messages.append(SystemMessage(content="You have finished using tools. Summarize findings in max 3 sentences.")) + final_summary = await structured_llm.ainvoke(messages) + + if isinstance(final_summary, WorkerSummary): + return final_summary.summary, collected + return str(final_summary), collected finally: clear_tool_result_collector() @@ -336,7 +344,6 @@ async def _stream_with_memory_tool( user_id: str, system_prompt: str, user_prompt: str, - stream_callback: Callable[[str], Awaitable[None]] | None, ) -> str: @tool async def update_core_memory(key: str, value: str) -> str: @@ -375,8 +382,6 @@ async def _stream_with_memory_tool( if not token: continue chunks.append(token) - if stream_callback is not None: - await stream_callback(token) return "".join(chunks) @@ -390,7 +395,6 @@ def _synthesizer_node(floating: bool): user_id=str(state.get("user_id", "")), system_prompt=system_prompt, user_prompt=prompt, - stream_callback=state.get("stream_callback"), ) return {"final_response": final_response} @@ -471,12 +475,10 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str: "context": context, "memory_context": context, "worker_results": [], - "stream_callback": None, } ) return str(state.get("final_response", "")) - async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]: plan = await _plan_with_llm(message, context, floating=True) domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"] @@ -490,7 +492,6 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t "plan": [task.model_dump() for task in plan.tasks], "floating_domain": domain, "worker_results": [], - "stream_callback": None, } ) return str(state.get("final_response", "")), str(domain) @@ -501,37 +502,25 @@ async def run_home_stream( message: str, context: dict[str, Any], ) -> AsyncGenerator[tuple[str, Any], None]: - queue: asyncio.Queue[str] = asyncio.Queue() - - async def _on_token(token: str) -> None: - await queue.put(token) - - task = asyncio.create_task( - HOME_GRAPH.ainvoke( - { - "user_id": user_id, - "user_message": message, - "context": context, - "memory_context": context, - "worker_results": [], - "stream_callback": _on_token, - } - ) - ) - - emitted = False - while not task.done() or not queue.empty(): - try: - token = await asyncio.wait_for(queue.get(), timeout=0.15) - emitted = True - yield "token", token - except asyncio.TimeoutError: - continue - - final_state = await task - if not emitted and final_state.get("final_response"): - yield "token", str(final_state["final_response"]) + state_input = { + "user_id": user_id, + "user_message": message, + "context": context, + "memory_context": context, + "worker_results": [], + } + async for event in HOME_GRAPH.astream_events(state_input, version="v2"): + kind = event["event"] + + if kind == "on_chat_model_stream": + node_name = event.get("metadata", {}).get("langgraph_node") + + if node_name == "synthesizer": + chunk = event["data"]["chunk"] + token = _as_text(getattr(chunk, "content", "")) + if token: + yield "token", token async def run_floating_stream( user_id: str, @@ -542,35 +531,24 @@ async def run_floating_stream( domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"] yield "floating_domain", domain - queue: asyncio.Queue[str] = asyncio.Queue() + state_input = { + "user_id": user_id, + "user_message": message, + "context": context, + "memory_context": context, + "plan": [t.model_dump() for t in plan.tasks], + "floating_domain": domain, + "worker_results": [], + } - async def _on_token(token: str) -> None: - await queue.put(token) - - task = asyncio.create_task( - FLOATING_GRAPH.ainvoke( - { - "user_id": user_id, - "user_message": message, - "context": context, - "memory_context": context, - "plan": [t.model_dump() for t in plan.tasks], - "floating_domain": domain, - "worker_results": [], - "stream_callback": _on_token, - } - ) - ) - - emitted = False - while not task.done() or not queue.empty(): - try: - token = await asyncio.wait_for(queue.get(), timeout=0.15) - emitted = True - yield "token", token - except asyncio.TimeoutError: - continue - - final_state = await task - if not emitted and final_state.get("final_response"): - yield "token", str(final_state["final_response"]) + async for event in FLOATING_GRAPH.astream_events(state_input, version="v2"): + kind = event["event"] + + if kind == "on_chat_model_stream": + node_name = event.get("metadata", {}).get("langgraph_node") + + if node_name == "synthesizer": + chunk = event["data"]["chunk"] + token = _as_text(getattr(chunk, "content", "")) + if token: + yield "token", token