# Commit notes:
# - ChatAgent.__init__: adds tool_results: list[dict] = []
# - _tool_loop: wraps execution in a result collector; populates
#   self.tool_results with raw execute_on_client dicts after each run
# - _tool_loop_stream: streaming variant — uses ainvoke for tool-call
#   iterations, llm.astream() for the final answer; same result capture
# - ws_context.py: adds _tool_result_collector ContextVar + set/clear
#   helpers; execute_on_client appends to collector when set
#
# Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
# (file info: 417 lines, 13 KiB, Python)
"""Tests for ChatAgent streaming and tool result capture (Step 2)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from typing import Any
|
|
|
|
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
|
|
|
from app.core.agent_registry import ChatAgent, registry
|
|
|
|
|
|
# ── Minimal concrete agent for testing ───────────────────────────────
|
|
|
|
|
|
class _EchoAgent(ChatAgent):
    """Smallest concrete ``ChatAgent`` used by these tests.

    It exposes no tools of its own (tests hand tools to ``_tool_loop`` /
    ``_tool_loop_stream`` explicitly), and ``handle`` simply echoes the query.
    """

    def get_name(self) -> str:
        """Return the name this test agent reports."""
        agent_name = "_echo"
        return agent_name

    def get_description(self) -> str:
        """Return a short human-readable description."""
        summary = "Echo agent for tests"
        return summary

    def get_tools(self) -> list[Any]:
        """Return no tools; each test supplies its own tool list to the loop."""
        return list()

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Return *query* unchanged; *context* is ignored."""
        return query
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────
|
|
|
|
|
|
def _make_ai_message(content: str = "", tool_calls: list | None = None) -> AIMessage:
    """Build an ``AIMessage`` carrying *content* and an explicit ``tool_calls`` list.

    A falsy *tool_calls* (``None`` or ``[]``) becomes a fresh empty list, so the
    code under test can always iterate ``msg.tool_calls``. A non-empty list is
    attached as-is, not copied.
    """
    message = AIMessage(content=content)
    message.tool_calls = tool_calls or []
    return message
|
|
|
|
|
|
def _make_tool(name: str, return_value: Any) -> MagicMock:
    """Return a mock tool called *name* whose ``ainvoke`` resolves to *return_value*.

    ``ainvoke`` is an ``AsyncMock``. Tests that need real behaviour replace it
    with their own ``AsyncMock(side_effect=...)``.
    """
    tool = MagicMock(ainvoke=AsyncMock(return_value=return_value))
    # ``name`` is a reserved Mock constructor argument (it names the mock
    # itself), so the ``.name`` attribute must be assigned after construction.
    tool.name = name
    return tool
|
|
|
|
|
|
def _make_stream_chunks(tokens: list[str]) -> list[MagicMock]:
    """Wrap each token in its own mock chunk exposing it as ``.content``.

    Produces the same kind of objects the fake ``astream`` generators in this
    module yield, one per token, in input order.
    """
    return [MagicMock(content=token) for token in tokens]
|
|
|
|
|
|
async def _collect_stream(agent: ChatAgent, llm: Any, messages: list, tools: list) -> list[str]:
    """Drain ``agent._tool_loop_stream(llm, messages, tools)`` into a list.

    Every yielded token is kept, in order, so a test can check the whole
    stream with a single equality assertion.
    """
    stream = agent._tool_loop_stream(llm, messages, tools)
    return [token async for token in stream]
|
|
|
|
|
|
# ── tool_results initialised ─────────────────────────────────────────
|
|
|
|
|
|
def test_tool_results_init():
    """A freshly constructed agent starts with an empty ``tool_results`` list."""
    fresh = _EchoAgent()
    assert fresh.tool_results == []
|
|
|
|
|
|
# ── _tool_loop: no tool calls ────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_loop_no_tools():
    """With no tools and a plain model reply, the loop returns the reply text
    and captures no tool results."""
    agent = _EchoAgent()
    reply = _make_ai_message("Hello!")

    llm = AsyncMock()
    llm.ainvoke = AsyncMock(return_value=reply)

    answer = await agent._tool_loop(llm, [HumanMessage(content="hi")], [])

    assert answer == "Hello!"
    assert agent.tool_results == []
|
|
|
|
|
|
# ── _tool_loop: with one tool call + result capture ──────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_loop_captures_tool_results():
    """A tool that goes through ``execute_on_client`` has the executor's raw
    dict recorded in ``agent.tool_results``."""
    from app.core.ws_context import clear_client_executor, set_client_executor

    agent = _EchoAgent()

    # Raw payload the fake client executor returns to execute_on_client.
    raw_result = {"rows": [{"id": "t-1", "title": "Fix bug", "status": "todo"}]}

    async def fake_executor(payload: dict) -> dict:
        return raw_result

    # Turn 1 asks for ``list_tasks``; turn 2 is the final answer.
    first_turn = _make_ai_message(
        tool_calls=[{"name": "list_tasks", "args": {}, "id": "call-1", "type": "tool_call"}]
    )
    final_turn = _make_ai_message("Here are your tasks.")

    llm_with_tools = MagicMock()
    llm_with_tools.ainvoke = AsyncMock(side_effect=[first_turn, final_turn])
    llm = MagicMock()
    llm.bind_tools = MagicMock(return_value=llm_with_tools)
    llm.ainvoke = AsyncMock(return_value=final_turn)

    async def list_tasks_via_client(args: dict) -> str:
        # Route through execute_on_client so the result collector sees the call.
        from app.core.ws_context import execute_on_client

        response = await execute_on_client(action="select", table="tasks")
        return "\n".join(row["title"] for row in response.get("rows", []))

    list_tasks = _make_tool("list_tasks", "- Fix bug (todo)")
    list_tasks.ainvoke = AsyncMock(side_effect=list_tasks_via_client)

    set_client_executor(fake_executor)
    try:
        result = await agent._tool_loop(
            llm, [HumanMessage(content="list my tasks")], [list_tasks]
        )
    finally:
        clear_client_executor()

    assert result == "Here are your tasks."
    assert agent.tool_results == [raw_result]
|
|
|
|
|
|
# ── _tool_loop: tool_results reset on each call ──────────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_loop_resets_tool_results():
    """Results left over from an earlier run are gone after ``_tool_loop`` runs again."""
    agent = _EchoAgent()
    agent.tool_results = [{"stale": True}]  # simulate leftovers from a previous call

    llm = AsyncMock()
    llm.ainvoke = AsyncMock(return_value=_make_ai_message("Done."))

    await agent._tool_loop(llm, [HumanMessage(content="hi")], [])

    assert agent.tool_results == []
|
|
|
|
|
|
# ── _tool_loop: unknown tool name ────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_loop_unknown_tool():
    """A call to a tool that isn't provided must not crash the loop.

    The model asks for ``nonexistent`` while only ``known`` is supplied. The
    loop should reach the model's next (final) answer without invoking the
    unrelated tool and without capturing any client results.
    """
    agent = _EchoAgent()

    # Only "known" is available, but the model requests "nonexistent".
    bogus_call = _make_ai_message(
        tool_calls=[{"name": "nonexistent", "args": {}, "id": "c1", "type": "tool_call"}]
    )
    final_msg = _make_ai_message("Handled.")

    known_tool = _make_tool("known", "ok")
    llm = MagicMock()
    llm_with_tools = MagicMock()
    llm.bind_tools = MagicMock(return_value=llm_with_tools)
    llm_with_tools.ainvoke = AsyncMock(side_effect=[bogus_call, final_msg])

    result = await agent._tool_loop(llm, [HumanMessage(content="x")], [known_tool])

    assert result == "Handled."
    # An unmatched tool call must not fall through to some other tool.
    known_tool.ainvoke.assert_not_awaited()
    # Nothing reached execute_on_client, so nothing should have been captured.
    assert agent.tool_results == []
|
|
|
|
|
|
# ── _tool_loop: max_iter exhaustion ──────────────────────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_loop_max_iter():
    """If the model keeps requesting tools, the loop stops after ``max_iter``
    bound-model turns. The test expects the plain ``llm.ainvoke`` reply
    ("Fallback.") as the result."""
    agent = _EchoAgent()
    looping_call = _make_ai_message(
        tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
    )

    llm_with_tools = MagicMock()
    # Every bound turn asks for the tool again, so only max_iter can end the loop.
    llm_with_tools.ainvoke = AsyncMock(return_value=looping_call)
    llm = MagicMock()
    llm.bind_tools = MagicMock(return_value=llm_with_tools)
    llm.ainvoke = AsyncMock(return_value=_make_ai_message("Fallback."))

    result = await agent._tool_loop(
        llm, [HumanMessage(content="x")], [_make_tool("t", "ok")], max_iter=2
    )

    assert result == "Fallback."
    assert llm_with_tools.ainvoke.call_count == 2
|
|
|
|
|
|
# ── _tool_loop_stream: no tool calls — yields tokens ─────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_loop_stream_no_tools_yields_tokens():
    """Without tools, the final answer comes out token-by-token from ``llm.astream``."""
    agent = _EchoAgent()

    # No tools, so the raw llm is used. Its ainvoke reports no tool calls,
    # which sends the loop to astream for the answer.
    llm = AsyncMock()
    llm.ainvoke = AsyncMock(return_value=_make_ai_message("irrelevant"))

    async def fake_astream(msgs):
        for chunk in _make_stream_chunks(["Hello", " ", "world"]):
            yield chunk

    llm.astream = fake_astream

    tokens = await _collect_stream(agent, llm, [HumanMessage(content="hi")], [])

    assert tokens == ["Hello", " ", "world"]
    assert agent.tool_results == []
|
|
|
|
|
|
# ── _tool_loop_stream: one tool call then streaming final ─────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_loop_stream_with_tool_call():
    """One tool round-trip, then a streamed answer; the tool's raw client
    result ends up in ``agent.tool_results``."""
    from app.core.ws_context import clear_client_executor, set_client_executor

    agent = _EchoAgent()
    raw_result = {"row": {"id": "t-2", "title": "Deploy", "status": "in_progress"}}

    async def fake_executor(payload: dict) -> dict:
        return raw_result

    async def get_task_via_client(args: dict) -> str:
        # Route through execute_on_client so the result collector sees the call.
        from app.core.ws_context import execute_on_client

        response = await execute_on_client(
            action="select", table="tasks", filters={"id": args.get("id")}
        )
        return response.get("row", {}).get("title", "")

    get_task = _make_tool("get_task", "Deploy")
    get_task.ainvoke = AsyncMock(side_effect=get_task_via_client)

    # Turn 1 requests get_task; turn 2 has no tool calls, so the answer is streamed.
    request_turn = _make_ai_message(
        tool_calls=[{"name": "get_task", "args": {"id": "t-2"}, "id": "c1", "type": "tool_call"}]
    )
    llm_with_tools = MagicMock()
    llm_with_tools.ainvoke = AsyncMock(
        side_effect=[request_turn, _make_ai_message("Task found.")]
    )
    llm = MagicMock()
    llm.bind_tools = MagicMock(return_value=llm_with_tools)

    async def fake_astream(msgs):
        for chunk in _make_stream_chunks(["Task", " ", "found."]):
            yield chunk

    llm.astream = fake_astream

    set_client_executor(fake_executor)
    try:
        tokens = await _collect_stream(
            agent, llm, [HumanMessage(content="get task t-2")], [get_task]
        )
    finally:
        clear_client_executor()

    assert tokens == ["Task", " ", "found."]
    assert agent.tool_results == [raw_result]
|
|
|
|
|
|
# ── _tool_loop_stream: tool_results reset on each call ───────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_loop_stream_resets_tool_results():
    """Leftover ``tool_results`` are gone after a new streaming run."""
    agent = _EchoAgent()
    agent.tool_results = [{"old": True}]  # pretend a previous run left data behind

    llm = AsyncMock()
    llm.ainvoke = AsyncMock(return_value=_make_ai_message(""))

    async def fake_astream(msgs):
        for chunk in _make_stream_chunks(["ok"]):
            yield chunk

    llm.astream = fake_astream

    await _collect_stream(agent, llm, [HumanMessage(content="x")], [])

    assert agent.tool_results == []
|
|
|
|
|
|
# ── _tool_loop_stream: empty chunk content is skipped ────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_loop_stream_skips_empty_chunks():
    """Chunks whose ``content`` is the empty string are not yielded to the caller."""
    agent = _EchoAgent()

    llm = AsyncMock()
    llm.ainvoke = AsyncMock(return_value=_make_ai_message(""))

    async def fake_astream(msgs):
        for chunk in _make_stream_chunks(["", "hello", "", " world", ""]):
            yield chunk

    llm.astream = fake_astream

    tokens = await _collect_stream(agent, llm, [HumanMessage(content="x")], [])

    assert tokens == ["hello", " world"]
|
|
|
|
|
|
# ── _tool_loop_stream: max_iter exhaustion falls back to stream ───────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_loop_stream_max_iter():
    """A model that requests a tool on every turn is cut off after the default
    ``max_iter`` bound-model turns. The test then expects the streamed
    ``llm.astream`` output ("fallback") as the only token.
    """
    import inspect

    agent = _EchoAgent()

    # Read the cap from the signature instead of hard-coding it, so tuning
    # the default in ChatAgent doesn't break this test for the wrong reason.
    default_max_iter = (
        inspect.signature(agent._tool_loop_stream).parameters["max_iter"].default
    )
    assert isinstance(default_max_iter, int) and default_max_iter > 0, (
        "expected _tool_loop_stream to declare a positive integer max_iter default"
    )

    always_tool = _make_ai_message(
        tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
    )

    llm = MagicMock()
    llm_with_tools = MagicMock()
    llm.bind_tools = MagicMock(return_value=llm_with_tools)
    # Every bound turn asks for the tool again, so only max_iter can end the loop.
    llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)

    async def fake_astream(msgs):
        c = MagicMock()
        c.content = "fallback"
        yield c

    llm.astream = fake_astream
    mock_tool = _make_tool("t", "ok")

    tokens = await _collect_stream(
        agent, llm, [HumanMessage(content="x")], [mock_tool],
    )

    assert tokens == ["fallback"]
    assert llm_with_tools.ainvoke.call_count == default_max_iter
|
|
|
|
|
|
# ── _tool_loop_stream: multiple tool results captured ────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_loop_stream_multiple_tool_results():
    """Two tool calls in one model turn each contribute one raw result, in
    the order the client executor produced them."""
    from app.core.ws_context import clear_client_executor, set_client_executor

    agent = _EchoAgent()

    pending = iter([
        {"rows": [{"id": "t-1"}]},
        {"rows": [{"id": "t-2"}]},
    ])

    async def fake_executor(payload: dict) -> dict:
        # Each client round-trip consumes the next canned response.
        return next(pending)

    async def select_tasks(args: dict) -> str:
        from app.core.ws_context import execute_on_client

        return str(await execute_on_client(action="select", table="tasks"))

    tool_a = _make_tool("tool_a", "")
    tool_a.ainvoke = AsyncMock(side_effect=select_tasks)
    tool_b = _make_tool("tool_b", "")
    tool_b.ainvoke = AsyncMock(side_effect=select_tasks)

    # A single model turn requests both tools; the following turn requests none.
    both_tools_turn = _make_ai_message(
        tool_calls=[
            {"name": "tool_a", "args": {}, "id": "c1", "type": "tool_call"},
            {"name": "tool_b", "args": {}, "id": "c2", "type": "tool_call"},
        ]
    )
    llm_with_tools = MagicMock()
    llm_with_tools.ainvoke = AsyncMock(
        side_effect=[both_tools_turn, _make_ai_message("Done.")]
    )
    llm = MagicMock()
    llm.bind_tools = MagicMock(return_value=llm_with_tools)

    async def fake_astream(msgs):
        for chunk in _make_stream_chunks(["Done."]):
            yield chunk

    llm.astream = fake_astream

    set_client_executor(fake_executor)
    try:
        tokens = await _collect_stream(
            agent, llm, [HumanMessage(content="x")], [tool_a, tool_b]
        )
    finally:
        clear_client_executor()

    assert tokens == ["Done."]
    assert agent.tool_results == [
        {"rows": [{"id": "t-1"}]},
        {"rows": [{"id": "t-2"}]},
    ]
|