step-2: add agent streaming and tool result capture (agent_registry.py)
- ChatAgent.__init__: adds tool_results: list[dict] = [] - _tool_loop: wraps execution in a result collector; populates self.tool_results with raw execute_on_client dicts after each run - _tool_loop_stream: streaming variant — uses ainvoke for tool-call iterations, llm.astream() for the final answer; same result capture - ws_context.py: adds _tool_result_collector ContextVar + set/clear helpers; execute_on_client appends to collector when set Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -34,6 +35,11 @@ class BaseAgent(ABC):
|
||||
class ChatAgent(BaseAgent):
|
||||
"""Base class for LLM-powered chat agents."""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
# Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results.
|
||||
self.tool_results: list[dict] = []
|
||||
|
||||
@abstractmethod
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
"""Process a user query and return a text response."""
|
||||
@@ -55,34 +61,98 @@ class ChatAgent(BaseAgent):
|
||||
|
||||
Binds *tools* to *llm*, invokes iteratively until the model stops
|
||||
requesting tool calls or *max_iter* is reached, and returns the
|
||||
final text response.
|
||||
final text response. Captures raw execute_on_client results in
|
||||
``self.tool_results``.
|
||||
"""
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
||||
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
||||
|
||||
for _ in range(max_iter):
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
messages.append(response)
|
||||
collector: list[dict] = []
|
||||
set_tool_result_collector(collector)
|
||||
try:
|
||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
||||
|
||||
if not response.tool_calls:
|
||||
return str(response.content)
|
||||
for _ in range(max_iter):
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
messages.append(response)
|
||||
|
||||
# Execute each requested tool call
|
||||
tool_map = {t.name: t for t in tools}
|
||||
for call in response.tool_calls:
|
||||
tool_fn = tool_map.get(call["name"])
|
||||
if tool_fn is None:
|
||||
result = f"Unknown tool: {call['name']}"
|
||||
else:
|
||||
result = await tool_fn.ainvoke(call["args"])
|
||||
messages.append(
|
||||
ToolMessage(content=str(result), tool_call_id=call["id"])
|
||||
)
|
||||
if not response.tool_calls:
|
||||
return str(response.content)
|
||||
|
||||
# Exhausted iterations — ask model for a final answer without tools
|
||||
response = await llm.ainvoke(messages)
|
||||
return str(response.content)
|
||||
# Execute each requested tool call
|
||||
tool_map = {t.name: t for t in tools}
|
||||
for call in response.tool_calls:
|
||||
tool_fn = tool_map.get(call["name"])
|
||||
if tool_fn is None:
|
||||
result = f"Unknown tool: {call['name']}"
|
||||
else:
|
||||
result = await tool_fn.ainvoke(call["args"])
|
||||
messages.append(
|
||||
ToolMessage(content=str(result), tool_call_id=call["id"])
|
||||
)
|
||||
|
||||
# Exhausted iterations — ask model for a final answer without tools
|
||||
response = await llm.ainvoke(messages)
|
||||
return str(response.content)
|
||||
finally:
|
||||
clear_tool_result_collector()
|
||||
self.tool_results = collector
|
||||
|
||||
async def _tool_loop_stream(
|
||||
self,
|
||||
llm: Any,
|
||||
messages: list[Any],
|
||||
tools: list[Any],
|
||||
max_iter: int = 5,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Streaming variant of ``_tool_loop``.
|
||||
|
||||
Behaves identically for tool-calling iterations (uses ainvoke to parse
|
||||
tool calls). For the final response — when the model produces no further
|
||||
tool calls — switches to ``llm.astream()`` and yields text tokens.
|
||||
Captures raw execute_on_client results in ``self.tool_results``.
|
||||
"""
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
||||
|
||||
collector: list[dict] = []
|
||||
set_tool_result_collector(collector)
|
||||
try:
|
||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
||||
|
||||
for _ in range(max_iter):
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
|
||||
if not response.tool_calls:
|
||||
# Stream the final answer — don't keep the ainvoke result.
|
||||
async for chunk in llm.astream(messages):
|
||||
if chunk.content:
|
||||
yield str(chunk.content)
|
||||
return
|
||||
|
||||
messages.append(response)
|
||||
|
||||
# Execute each requested tool call
|
||||
tool_map = {t.name: t for t in tools}
|
||||
for call in response.tool_calls:
|
||||
tool_fn = tool_map.get(call["name"])
|
||||
if tool_fn is None:
|
||||
result = f"Unknown tool: {call['name']}"
|
||||
else:
|
||||
result = await tool_fn.ainvoke(call["args"])
|
||||
messages.append(
|
||||
ToolMessage(content=str(result), tool_call_id=call["id"])
|
||||
)
|
||||
|
||||
# Exhausted iterations — stream a final answer without tools
|
||||
async for chunk in llm.astream(messages):
|
||||
if chunk.content:
|
||||
yield str(chunk.content)
|
||||
finally:
|
||||
clear_tool_result_collector()
|
||||
self.tool_results = collector
|
||||
|
||||
|
||||
class AgentRegistry:
|
||||
|
||||
@@ -17,6 +17,22 @@ _client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = Cont
|
||||
"_client_executor"
|
||||
)
|
||||
|
||||
# Optional collector that captures raw execute_on_client results.
|
||||
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
|
||||
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||
"_tool_result_collector", default=None
|
||||
)
|
||||
|
||||
|
||||
def set_tool_result_collector(lst: list[dict]) -> None:
|
||||
"""Register *lst* as the collector for this async context."""
|
||||
_tool_result_collector.set(lst)
|
||||
|
||||
|
||||
def clear_tool_result_collector() -> None:
|
||||
"""Clear the collector (best-effort)."""
|
||||
_tool_result_collector.set(None)
|
||||
|
||||
|
||||
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
|
||||
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
|
||||
@@ -65,4 +81,8 @@ async def execute_on_client(
|
||||
if limit is not None:
|
||||
payload["limit"] = limit
|
||||
|
||||
return await callback(payload)
|
||||
result = await callback(payload)
|
||||
collector = _tool_result_collector.get(None)
|
||||
if collector is not None:
|
||||
collector.append(result)
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user