step-2: add agent streaming and tool result capture (agent_registry.py)

- ChatAgent.__init__: adds tool_results: list[dict] = []
- _tool_loop: wraps execution in a result collector; populates
  self.tool_results with raw execute_on_client dicts after each run
- _tool_loop_stream: streaming variant — uses ainvoke for tool-call
  iterations, llm.astream() for the final answer; same result capture
- ws_context.py: adds _tool_result_collector ContextVar +
  set/clear helpers; execute_on_client appends to collector when set

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-08 21:37:15 +01:00
parent 7efaeba283
commit 7cb384fa63
4 changed files with 543 additions and 24 deletions

View File

@@ -7,6 +7,18 @@
--- ---
## General Rules
**Code Cleanup**: As you implement each step, remove any code that becomes unused or obsolete. This includes:
- Old functions/methods that are superseded by new ones
- Deprecated imports or modules
- Dead code paths
- Old test files no longer needed
This keeps the codebase clean and prevents confusion. When removing code, note it in the commit message if significant.
---
## Decisions Log ## Decisions Log
| Topic | Decision | | Topic | Decision |
@@ -74,7 +86,7 @@ pytest tests/test_agent_streaming.py
``` ```
**Status**: **Status**:
- [ ] Step 2 complete - [x] Step 2 complete
**Commit**: After tests pass, commit with: **Commit**: After tests pass, commit with:
``` ```
@@ -222,8 +234,9 @@ git commit -m "step-4: add output formatting layer (output_formatter.py)"
- Each request gets a `request_id` (UUID) for frame correlation. - Each request gets a `request_id` (UUID) for frame correlation.
- Concurrent requests from same client are supported (each runs as an async task). - Concurrent requests from same client are supported (each runs as an async task).
- `app/api/routes/chat.py`: - `app/api/routes/chat.py`:
- Remove `chat_stream` WS endpoint. - Remove `chat_stream` WS endpoint and any related helper functions that were only used by it.
- Keep `POST /chat` endpoint unchanged (REST fallback). - Keep `POST /chat` endpoint unchanged (REST fallback).
- Clean up any unused imports.
- `app/main.py`: - `app/main.py`:
- No change needed (device_ws router already registered). - No change needed (device_ws router already registered).

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any from typing import Any
@@ -34,6 +35,11 @@ class BaseAgent(ABC):
class ChatAgent(BaseAgent): class ChatAgent(BaseAgent):
"""Base class for LLM-powered chat agents.""" """Base class for LLM-powered chat agents."""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
# Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results.
self.tool_results: list[dict] = []
@abstractmethod @abstractmethod
async def handle(self, query: str, context: dict[str, Any]) -> str: async def handle(self, query: str, context: dict[str, Any]) -> str:
"""Process a user query and return a text response.""" """Process a user query and return a text response."""
@@ -55,34 +61,98 @@ class ChatAgent(BaseAgent):
Binds *tools* to *llm*, invokes iteratively until the model stops Binds *tools* to *llm*, invokes iteratively until the model stops
requesting tool calls or *max_iter* is reached, and returns the requesting tool calls or *max_iter* is reached, and returns the
final text response. final text response. Captures raw execute_on_client results in
``self.tool_results``.
""" """
from langchain_core.messages import AIMessage, ToolMessage from langchain_core.messages import AIMessage, ToolMessage
llm_with_tools = llm.bind_tools(tools) if tools else llm from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
for _ in range(max_iter): collector: list[dict] = []
response: AIMessage = await llm_with_tools.ainvoke(messages) set_tool_result_collector(collector)
messages.append(response) try:
llm_with_tools = llm.bind_tools(tools) if tools else llm
if not response.tool_calls: for _ in range(max_iter):
return str(response.content) response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
# Execute each requested tool call if not response.tool_calls:
tool_map = {t.name: t for t in tools} return str(response.content)
for call in response.tool_calls:
tool_fn = tool_map.get(call["name"])
if tool_fn is None:
result = f"Unknown tool: {call['name']}"
else:
result = await tool_fn.ainvoke(call["args"])
messages.append(
ToolMessage(content=str(result), tool_call_id=call["id"])
)
# Exhausted iterations — ask model for a final answer without tools # Execute each requested tool call
response = await llm.ainvoke(messages) tool_map = {t.name: t for t in tools}
return str(response.content) for call in response.tool_calls:
tool_fn = tool_map.get(call["name"])
if tool_fn is None:
result = f"Unknown tool: {call['name']}"
else:
result = await tool_fn.ainvoke(call["args"])
messages.append(
ToolMessage(content=str(result), tool_call_id=call["id"])
)
# Exhausted iterations — ask model for a final answer without tools
response = await llm.ainvoke(messages)
return str(response.content)
finally:
clear_tool_result_collector()
self.tool_results = collector
async def _tool_loop_stream(
self,
llm: Any,
messages: list[Any],
tools: list[Any],
max_iter: int = 5,
) -> AsyncGenerator[str, None]:
"""Streaming variant of ``_tool_loop``.
Behaves identically for tool-calling iterations (uses ainvoke to parse
tool calls). For the final response — when the model produces no further
tool calls — switches to ``llm.astream()`` and yields text tokens.
Captures raw execute_on_client results in ``self.tool_results``.
"""
from langchain_core.messages import AIMessage, ToolMessage
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
collector: list[dict] = []
set_tool_result_collector(collector)
try:
llm_with_tools = llm.bind_tools(tools) if tools else llm
for _ in range(max_iter):
response: AIMessage = await llm_with_tools.ainvoke(messages)
if not response.tool_calls:
# Stream the final answer — don't keep the ainvoke result.
async for chunk in llm.astream(messages):
if chunk.content:
yield str(chunk.content)
return
messages.append(response)
# Execute each requested tool call
tool_map = {t.name: t for t in tools}
for call in response.tool_calls:
tool_fn = tool_map.get(call["name"])
if tool_fn is None:
result = f"Unknown tool: {call['name']}"
else:
result = await tool_fn.ainvoke(call["args"])
messages.append(
ToolMessage(content=str(result), tool_call_id=call["id"])
)
# Exhausted iterations — stream a final answer without tools
async for chunk in llm.astream(messages):
if chunk.content:
yield str(chunk.content)
finally:
clear_tool_result_collector()
self.tool_results = collector
class AgentRegistry: class AgentRegistry:

View File

@@ -17,6 +17,22 @@ _client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = Cont
"_client_executor" "_client_executor"
) )
# Optional collector that captures raw execute_on_client results.
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
"_tool_result_collector", default=None
)
def set_tool_result_collector(lst: list[dict]) -> None:
"""Register *lst* as the collector for this async context."""
_tool_result_collector.set(lst)
def clear_tool_result_collector() -> None:
"""Clear the collector (best-effort)."""
_tool_result_collector.set(None)
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None: def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
"""Bind *fn* as the executor for the current async context (task/coroutine).""" """Bind *fn* as the executor for the current async context (task/coroutine)."""
@@ -65,4 +81,8 @@ async def execute_on_client(
if limit is not None: if limit is not None:
payload["limit"] = limit payload["limit"] = limit
return await callback(payload) result = await callback(payload)
collector = _tool_result_collector.get(None)
if collector is not None:
collector.append(result)
return result

View File

@@ -0,0 +1,416 @@
"""Tests for ChatAgent streaming and tool result capture (Step 2)."""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from app.core.agent_registry import ChatAgent, registry
# ── Minimal concrete agent for testing ───────────────────────────────
class _EchoAgent(ChatAgent):
def get_name(self) -> str:
return "_echo"
def get_description(self) -> str:
return "Echo agent for tests"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return query
# ── Helpers ───────────────────────────────────────────────────────────
def _make_ai_message(content: str = "", tool_calls: list | None = None) -> AIMessage:
msg = AIMessage(content=content)
if tool_calls:
msg.tool_calls = tool_calls
else:
msg.tool_calls = []
return msg
def _make_tool(name: str, return_value: Any) -> MagicMock:
t = MagicMock()
t.name = name
t.ainvoke = AsyncMock(return_value=return_value)
return t
def _make_stream_chunks(tokens: list[str]) -> list[MagicMock]:
chunks = []
for tok in tokens:
c = MagicMock()
c.content = tok
chunks.append(c)
return chunks
async def _collect_stream(agent: ChatAgent, llm: Any, messages: list, tools: list) -> list[str]:
tokens: list[str] = []
async for tok in agent._tool_loop_stream(llm, messages, tools):
tokens.append(tok)
return tokens
# ── tool_results initialised ─────────────────────────────────────────
def test_tool_results_init():
agent = _EchoAgent()
assert agent.tool_results == []
# ── _tool_loop: no tool calls ────────────────────────────────────────
@pytest.mark.asyncio
async def test_tool_loop_no_tools():
agent = _EchoAgent()
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Hello!"))
result = await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
assert result == "Hello!"
assert agent.tool_results == []
# ── _tool_loop: with one tool call + result capture ──────────────────
@pytest.mark.asyncio
async def test_tool_loop_captures_tool_results():
agent = _EchoAgent()
# Mock execute_on_client to return structured data via the tool
raw_result = {"rows": [{"id": "t-1", "title": "Fix bug", "status": "todo"}]}
async def fake_executor(payload: dict) -> dict:
return raw_result
# AIMessage with a tool call, then a final answer
tool_call_msg = _make_ai_message(
tool_calls=[{"name": "list_tasks", "args": {}, "id": "call-1", "type": "tool_call"}]
)
final_msg = _make_ai_message("Here are your tasks.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
llm.ainvoke = AsyncMock(return_value=final_msg)
mock_tool = _make_tool("list_tasks", "- Fix bug (todo)")
from app.core.ws_context import set_client_executor, clear_client_executor
set_client_executor(fake_executor)
try:
# Patch the tool to actually call execute_on_client
async def tool_side_effect(args: dict) -> str:
from app.core.ws_context import execute_on_client
res = await execute_on_client(action="select", table="tasks")
rows = res.get("rows", [])
return "\n".join(r["title"] for r in rows)
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
result = await agent._tool_loop(
llm, [HumanMessage(content="list my tasks")], [mock_tool]
)
finally:
clear_client_executor()
assert result == "Here are your tasks."
assert len(agent.tool_results) == 1
assert agent.tool_results[0] == raw_result
# ── _tool_loop: tool_results reset on each call ──────────────────────
@pytest.mark.asyncio
async def test_tool_loop_resets_tool_results():
agent = _EchoAgent()
agent.tool_results = [{"stale": True}] # pre-populated from a previous call
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Done."))
await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
assert agent.tool_results == []
# ── _tool_loop: unknown tool name ────────────────────────────────────
@pytest.mark.asyncio
async def test_tool_loop_unknown_tool():
agent = _EchoAgent()
# No known tools — model still calls a non-existent one; loop handles gracefully
tool_call_msg = _make_ai_message(
tool_calls=[{"name": "nonexistent", "args": {}, "id": "c1", "type": "tool_call"}]
)
final_msg = _make_ai_message("Handled.")
mock_tool = _make_tool("known", "ok") # a different tool, not "nonexistent"
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool])
assert result == "Handled."
# ── _tool_loop: max_iter exhaustion ──────────────────────────────────
@pytest.mark.asyncio
async def test_tool_loop_max_iter():
agent = _EchoAgent()
always_tool = _make_ai_message(
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
)
fallback = _make_ai_message("Fallback.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
# Returns tool_call_msg on every iteration
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
llm.ainvoke = AsyncMock(return_value=fallback)
mock_tool = _make_tool("t", "ok")
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool], max_iter=2)
assert result == "Fallback."
assert llm_with_tools.ainvoke.call_count == 2
# ── _tool_loop_stream: no tool calls — yields tokens ─────────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_no_tools_yields_tokens():
agent = _EchoAgent()
# No tools → llm used directly; ainvoke returns no tool calls → stream is used
no_tool_msg = _make_ai_message("irrelevant")
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
async def fake_astream(msgs):
for tok in ["Hello", " ", "world"]:
c = MagicMock()
c.content = tok
yield c
llm.astream = fake_astream
tokens = await _collect_stream(agent, llm, [HumanMessage(content="hi")], [])
assert tokens == ["Hello", " ", "world"]
assert agent.tool_results == []
# ── _tool_loop_stream: one tool call then streaming final ─────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_with_tool_call():
agent = _EchoAgent()
raw_result = {"row": {"id": "t-2", "title": "Deploy", "status": "in_progress"}}
async def fake_executor(payload: dict) -> dict:
return raw_result
tool_call_msg = _make_ai_message(
tool_calls=[{"name": "get_task", "args": {"id": "t-2"}, "id": "c1", "type": "tool_call"}]
)
# After tools run, ainvoke returns no more tool calls
no_more_tools_msg = _make_ai_message("Task found.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
async def fake_astream(msgs):
for tok in ["Task", " ", "found."]:
c = MagicMock()
c.content = tok
yield c
llm.astream = fake_astream
async def tool_side_effect(args: dict) -> str:
from app.core.ws_context import execute_on_client
res = await execute_on_client(action="select", table="tasks", filters={"id": args.get("id")})
return res.get("row", {}).get("title", "")
mock_tool = _make_tool("get_task", "Deploy")
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
from app.core.ws_context import set_client_executor, clear_client_executor
set_client_executor(fake_executor)
try:
tokens = await _collect_stream(
agent, llm, [HumanMessage(content="get task t-2")], [mock_tool]
)
finally:
clear_client_executor()
assert tokens == ["Task", " ", "found."]
assert len(agent.tool_results) == 1
assert agent.tool_results[0] == raw_result
# ── _tool_loop_stream: tool_results reset on each call ───────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_resets_tool_results():
agent = _EchoAgent()
agent.tool_results = [{"old": True}]
no_tool_msg = _make_ai_message("")
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
async def fake_astream(msgs):
c = MagicMock()
c.content = "ok"
yield c
llm.astream = fake_astream
await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
assert agent.tool_results == []
# ── _tool_loop_stream: empty chunk content is skipped ────────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_skips_empty_chunks():
agent = _EchoAgent()
no_tool_msg = _make_ai_message("")
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
async def fake_astream(msgs):
for tok in ["", "hello", "", " world", ""]:
c = MagicMock()
c.content = tok
yield c
llm.astream = fake_astream
tokens = await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
assert tokens == ["hello", " world"]
# ── _tool_loop_stream: max_iter exhaustion falls back to stream ───────
@pytest.mark.asyncio
async def test_tool_loop_stream_max_iter():
agent = _EchoAgent()
always_tool = _make_ai_message(
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
)
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
async def fake_astream(msgs):
c = MagicMock()
c.content = "fallback"
yield c
llm.astream = fake_astream
mock_tool = _make_tool("t", "ok")
tokens = await _collect_stream(
agent, llm, [HumanMessage(content="x")], [mock_tool],
)
assert tokens == ["fallback"]
assert llm_with_tools.ainvoke.call_count == 5 # exhausted default max_iter
# ── _tool_loop_stream: multiple tool results captured ────────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_multiple_tool_results():
agent = _EchoAgent()
call_results = [
{"rows": [{"id": "t-1"}]},
{"rows": [{"id": "t-2"}]},
]
call_iter = iter(call_results)
async def fake_executor(payload: dict) -> dict:
return next(call_iter)
# Two tool calls in one iteration
tool_call_msg = _make_ai_message(
tool_calls=[
{"name": "tool_a", "args": {}, "id": "c1", "type": "tool_call"},
{"name": "tool_b", "args": {}, "id": "c2", "type": "tool_call"},
]
)
no_more_tools_msg = _make_ai_message("Done.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
async def fake_astream(msgs):
c = MagicMock()
c.content = "Done."
yield c
llm.astream = fake_astream
async def tool_side_effect(args: dict) -> str:
from app.core.ws_context import execute_on_client
res = await execute_on_client(action="select", table="tasks")
return str(res)
tool_a = _make_tool("tool_a", "")
tool_a.ainvoke = AsyncMock(side_effect=tool_side_effect)
tool_b = _make_tool("tool_b", "")
tool_b.ainvoke = AsyncMock(side_effect=tool_side_effect)
from app.core.ws_context import set_client_executor, clear_client_executor
set_client_executor(fake_executor)
try:
tokens = await _collect_stream(
agent, llm, [HumanMessage(content="x")], [tool_a, tool_b]
)
finally:
clear_client_executor()
assert tokens == ["Done."]
assert len(agent.tool_results) == 2
assert agent.tool_results[0] == {"rows": [{"id": "t-1"}]}
assert agent.tool_results[1] == {"rows": [{"id": "t-2"}]}