refactor: use native LangGraph streaming and enforce structured summary on workers
This commit is contained in:
@@ -36,6 +36,10 @@ class WorkerTask(BaseModel):
|
|||||||
instruction: str
|
instruction: str
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerSummary(BaseModel):
|
||||||
|
summary: str = Field(description="Strictly concise summary of tool findings. Max 3 sentences.")
|
||||||
|
|
||||||
|
|
||||||
class WorkerPlan(BaseModel):
|
class WorkerPlan(BaseModel):
|
||||||
tasks: list[WorkerTask] = Field(default_factory=list)
|
tasks: list[WorkerTask] = Field(default_factory=list)
|
||||||
floating_domain: FloatingDomain | None = None
|
floating_domain: FloatingDomain | None = None
|
||||||
@@ -58,7 +62,6 @@ class OrchestratorState(TypedDict, total=False):
|
|||||||
task: dict[str, Any]
|
task: dict[str, Any]
|
||||||
worker_results: list[WorkerResult]
|
worker_results: list[WorkerResult]
|
||||||
final_response: str
|
final_response: str
|
||||||
stream_callback: Callable[[str], Awaitable[None]] | None
|
|
||||||
|
|
||||||
|
|
||||||
class GraphState(OrchestratorState):
|
class GraphState(OrchestratorState):
|
||||||
@@ -276,8 +279,13 @@ async def _run_tool_loop(
|
|||||||
tool_output = await tool_fn.ainvoke(call.get("args", {}))
|
tool_output = await tool_fn.ainvoke(call.get("args", {}))
|
||||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
final = await llm.ainvoke(messages)
|
structured_llm = llm.with_structured_output(WorkerSummary)
|
||||||
return _as_text(final.content), collected
|
messages.append(SystemMessage(content="You have finished using tools. Summarize findings in max 3 sentences."))
|
||||||
|
final_summary = await structured_llm.ainvoke(messages)
|
||||||
|
|
||||||
|
if isinstance(final_summary, WorkerSummary):
|
||||||
|
return final_summary.summary, collected
|
||||||
|
return str(final_summary), collected
|
||||||
finally:
|
finally:
|
||||||
clear_tool_result_collector()
|
clear_tool_result_collector()
|
||||||
|
|
||||||
@@ -336,7 +344,6 @@ async def _stream_with_memory_tool(
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
user_prompt: str,
|
||||||
stream_callback: Callable[[str], Awaitable[None]] | None,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
@tool
|
@tool
|
||||||
async def update_core_memory(key: str, value: str) -> str:
|
async def update_core_memory(key: str, value: str) -> str:
|
||||||
@@ -375,8 +382,6 @@ async def _stream_with_memory_tool(
|
|||||||
if not token:
|
if not token:
|
||||||
continue
|
continue
|
||||||
chunks.append(token)
|
chunks.append(token)
|
||||||
if stream_callback is not None:
|
|
||||||
await stream_callback(token)
|
|
||||||
|
|
||||||
return "".join(chunks)
|
return "".join(chunks)
|
||||||
|
|
||||||
@@ -390,7 +395,6 @@ def _synthesizer_node(floating: bool):
|
|||||||
user_id=str(state.get("user_id", "")),
|
user_id=str(state.get("user_id", "")),
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
user_prompt=prompt,
|
user_prompt=prompt,
|
||||||
stream_callback=state.get("stream_callback"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"final_response": final_response}
|
return {"final_response": final_response}
|
||||||
@@ -471,12 +475,10 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
|||||||
"context": context,
|
"context": context,
|
||||||
"memory_context": context,
|
"memory_context": context,
|
||||||
"worker_results": [],
|
"worker_results": [],
|
||||||
"stream_callback": None,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return str(state.get("final_response", ""))
|
return str(state.get("final_response", ""))
|
||||||
|
|
||||||
|
|
||||||
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
|
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
|
||||||
plan = await _plan_with_llm(message, context, floating=True)
|
plan = await _plan_with_llm(message, context, floating=True)
|
||||||
domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
|
domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
|
||||||
@@ -490,7 +492,6 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
|
|||||||
"plan": [task.model_dump() for task in plan.tasks],
|
"plan": [task.model_dump() for task in plan.tasks],
|
||||||
"floating_domain": domain,
|
"floating_domain": domain,
|
||||||
"worker_results": [],
|
"worker_results": [],
|
||||||
"stream_callback": None,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return str(state.get("final_response", "")), str(domain)
|
return str(state.get("final_response", "")), str(domain)
|
||||||
@@ -501,37 +502,25 @@ async def run_home_stream(
|
|||||||
message: str,
|
message: str,
|
||||||
context: dict[str, Any],
|
context: dict[str, Any],
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
queue: asyncio.Queue[str] = asyncio.Queue()
|
state_input = {
|
||||||
|
|
||||||
async def _on_token(token: str) -> None:
|
|
||||||
await queue.put(token)
|
|
||||||
|
|
||||||
task = asyncio.create_task(
|
|
||||||
HOME_GRAPH.ainvoke(
|
|
||||||
{
|
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"user_message": message,
|
"user_message": message,
|
||||||
"context": context,
|
"context": context,
|
||||||
"memory_context": context,
|
"memory_context": context,
|
||||||
"worker_results": [],
|
"worker_results": [],
|
||||||
"stream_callback": _on_token,
|
|
||||||
}
|
}
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
emitted = False
|
async for event in HOME_GRAPH.astream_events(state_input, version="v2"):
|
||||||
while not task.done() or not queue.empty():
|
kind = event["event"]
|
||||||
try:
|
|
||||||
token = await asyncio.wait_for(queue.get(), timeout=0.15)
|
if kind == "on_chat_model_stream":
|
||||||
emitted = True
|
node_name = event.get("metadata", {}).get("langgraph_node")
|
||||||
|
|
||||||
|
if node_name == "synthesizer":
|
||||||
|
chunk = event["data"]["chunk"]
|
||||||
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
|
if token:
|
||||||
yield "token", token
|
yield "token", token
|
||||||
except asyncio.TimeoutError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
final_state = await task
|
|
||||||
if not emitted and final_state.get("final_response"):
|
|
||||||
yield "token", str(final_state["final_response"])
|
|
||||||
|
|
||||||
|
|
||||||
async def run_floating_stream(
|
async def run_floating_stream(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
@@ -542,14 +531,7 @@ async def run_floating_stream(
|
|||||||
domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
|
domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
|
||||||
yield "floating_domain", domain
|
yield "floating_domain", domain
|
||||||
|
|
||||||
queue: asyncio.Queue[str] = asyncio.Queue()
|
state_input = {
|
||||||
|
|
||||||
async def _on_token(token: str) -> None:
|
|
||||||
await queue.put(token)
|
|
||||||
|
|
||||||
task = asyncio.create_task(
|
|
||||||
FLOATING_GRAPH.ainvoke(
|
|
||||||
{
|
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"user_message": message,
|
"user_message": message,
|
||||||
"context": context,
|
"context": context,
|
||||||
@@ -557,20 +539,16 @@ async def run_floating_stream(
|
|||||||
"plan": [t.model_dump() for t in plan.tasks],
|
"plan": [t.model_dump() for t in plan.tasks],
|
||||||
"floating_domain": domain,
|
"floating_domain": domain,
|
||||||
"worker_results": [],
|
"worker_results": [],
|
||||||
"stream_callback": _on_token,
|
|
||||||
}
|
}
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
emitted = False
|
async for event in FLOATING_GRAPH.astream_events(state_input, version="v2"):
|
||||||
while not task.done() or not queue.empty():
|
kind = event["event"]
|
||||||
try:
|
|
||||||
token = await asyncio.wait_for(queue.get(), timeout=0.15)
|
if kind == "on_chat_model_stream":
|
||||||
emitted = True
|
node_name = event.get("metadata", {}).get("langgraph_node")
|
||||||
|
|
||||||
|
if node_name == "synthesizer":
|
||||||
|
chunk = event["data"]["chunk"]
|
||||||
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
|
if token:
|
||||||
yield "token", token
|
yield "token", token
|
||||||
except asyncio.TimeoutError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
final_state = await task
|
|
||||||
if not emitted and final_state.get("final_response"):
|
|
||||||
yield "token", str(final_state["final_response"])
|
|
||||||
|
|||||||
Reference in New Issue
Block a user