refactor: use native LangGraph streaming and enforce structured summary on workers

This commit is contained in:
2026-03-12 22:50:32 +01:00
parent fe085a7951
commit d667e43c73

View File

@@ -36,6 +36,10 @@ class WorkerTask(BaseModel):
instruction: str
class WorkerSummary(BaseModel):
summary: str = Field(description="Strictly concise summary of tool findings. Max 3 sentences.")
class WorkerPlan(BaseModel):
tasks: list[WorkerTask] = Field(default_factory=list)
floating_domain: FloatingDomain | None = None
@@ -58,7 +62,6 @@ class OrchestratorState(TypedDict, total=False):
task: dict[str, Any]
worker_results: list[WorkerResult]
final_response: str
stream_callback: Callable[[str], Awaitable[None]] | None
class GraphState(OrchestratorState):
@@ -276,8 +279,13 @@ async def _run_tool_loop(
tool_output = await tool_fn.ainvoke(call.get("args", {}))
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
final = await llm.ainvoke(messages)
return _as_text(final.content), collected
structured_llm = llm.with_structured_output(WorkerSummary)
messages.append(SystemMessage(content="You have finished using tools. Summarize findings in max 3 sentences."))
final_summary = await structured_llm.ainvoke(messages)
if isinstance(final_summary, WorkerSummary):
return final_summary.summary, collected
return str(final_summary), collected
finally:
clear_tool_result_collector()
@@ -336,7 +344,6 @@ async def _stream_with_memory_tool(
user_id: str,
system_prompt: str,
user_prompt: str,
stream_callback: Callable[[str], Awaitable[None]] | None,
) -> str:
@tool
async def update_core_memory(key: str, value: str) -> str:
@@ -375,8 +382,6 @@ async def _stream_with_memory_tool(
if not token:
continue
chunks.append(token)
if stream_callback is not None:
await stream_callback(token)
return "".join(chunks)
@@ -390,7 +395,6 @@ def _synthesizer_node(floating: bool):
user_id=str(state.get("user_id", "")),
system_prompt=system_prompt,
user_prompt=prompt,
stream_callback=state.get("stream_callback"),
)
return {"final_response": final_response}
@@ -471,12 +475,10 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
"context": context,
"memory_context": context,
"worker_results": [],
"stream_callback": None,
}
)
return str(state.get("final_response", ""))
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
plan = await _plan_with_llm(message, context, floating=True)
domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
@@ -490,7 +492,6 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
"plan": [task.model_dump() for task in plan.tasks],
"floating_domain": domain,
"worker_results": [],
"stream_callback": None,
}
)
return str(state.get("final_response", "")), str(domain)
@@ -501,37 +502,25 @@ async def run_home_stream(
message: str,
context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
queue: asyncio.Queue[str] = asyncio.Queue()
state_input = {
"user_id": user_id,
"user_message": message,
"context": context,
"memory_context": context,
"worker_results": [],
}
async def _on_token(token: str) -> None:
await queue.put(token)
async for event in HOME_GRAPH.astream_events(state_input, version="v2"):
kind = event["event"]
task = asyncio.create_task(
HOME_GRAPH.ainvoke(
{
"user_id": user_id,
"user_message": message,
"context": context,
"memory_context": context,
"worker_results": [],
"stream_callback": _on_token,
}
)
)
emitted = False
while not task.done() or not queue.empty():
try:
token = await asyncio.wait_for(queue.get(), timeout=0.15)
emitted = True
yield "token", token
except asyncio.TimeoutError:
continue
final_state = await task
if not emitted and final_state.get("final_response"):
yield "token", str(final_state["final_response"])
if kind == "on_chat_model_stream":
node_name = event.get("metadata", {}).get("langgraph_node")
if node_name == "synthesizer":
chunk = event["data"]["chunk"]
token = _as_text(getattr(chunk, "content", ""))
if token:
yield "token", token
async def run_floating_stream(
user_id: str,
@@ -542,35 +531,24 @@ async def run_floating_stream(
domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
yield "floating_domain", domain
queue: asyncio.Queue[str] = asyncio.Queue()
state_input = {
"user_id": user_id,
"user_message": message,
"context": context,
"memory_context": context,
"plan": [t.model_dump() for t in plan.tasks],
"floating_domain": domain,
"worker_results": [],
}
async def _on_token(token: str) -> None:
await queue.put(token)
async for event in FLOATING_GRAPH.astream_events(state_input, version="v2"):
kind = event["event"]
task = asyncio.create_task(
FLOATING_GRAPH.ainvoke(
{
"user_id": user_id,
"user_message": message,
"context": context,
"memory_context": context,
"plan": [t.model_dump() for t in plan.tasks],
"floating_domain": domain,
"worker_results": [],
"stream_callback": _on_token,
}
)
)
if kind == "on_chat_model_stream":
node_name = event.get("metadata", {}).get("langgraph_node")
emitted = False
while not task.done() or not queue.empty():
try:
token = await asyncio.wait_for(queue.get(), timeout=0.15)
emitted = True
yield "token", token
except asyncio.TimeoutError:
continue
final_state = await task
if not emitted and final_state.get("final_response"):
yield "token", str(final_state["final_response"])
if node_name == "synthesizer":
chunk = event["data"]["chunk"]
token = _as_text(getattr(chunk, "content", ""))
if token:
yield "token", token