"""Tests for ChatAgent streaming and tool result capture (Step 2).""" from __future__ import annotations import pytest from unittest.mock import AsyncMock, MagicMock, patch from typing import Any from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from app.core.agent_registry import ChatAgent, registry # ── Minimal concrete agent for testing ─────────────────────────────── class _EchoAgent(ChatAgent): def get_name(self) -> str: return "_echo" def get_description(self) -> str: return "Echo agent for tests" def get_tools(self) -> list[Any]: return [] async def handle(self, query: str, context: dict[str, Any]) -> str: return query # ── Helpers ─────────────────────────────────────────────────────────── def _make_ai_message(content: str = "", tool_calls: list | None = None) -> AIMessage: msg = AIMessage(content=content) if tool_calls: msg.tool_calls = tool_calls else: msg.tool_calls = [] return msg def _make_tool(name: str, return_value: Any) -> MagicMock: t = MagicMock() t.name = name t.ainvoke = AsyncMock(return_value=return_value) return t def _make_stream_chunks(tokens: list[str]) -> list[MagicMock]: chunks = [] for tok in tokens: c = MagicMock() c.content = tok chunks.append(c) return chunks async def _collect_stream(agent: ChatAgent, llm: Any, messages: list, tools: list) -> list[str]: tokens: list[str] = [] async for tok in agent._tool_loop_stream(llm, messages, tools): tokens.append(tok) return tokens # ── tool_results initialised ───────────────────────────────────────── def test_tool_results_init(): agent = _EchoAgent() assert agent.tool_results == [] # ── _tool_loop: no tool calls ──────────────────────────────────────── @pytest.mark.asyncio async def test_tool_loop_no_tools(): agent = _EchoAgent() llm = AsyncMock() llm.ainvoke = AsyncMock(return_value=_make_ai_message("Hello!")) result = await agent._tool_loop(llm, [HumanMessage(content="hi")], []) assert result == "Hello!" assert agent.tool_results == [] # ── _tool_loop: with one tool call + result capture ────────────────── @pytest.mark.asyncio async def test_tool_loop_captures_tool_results(): agent = _EchoAgent() # Mock execute_on_client to return structured data via the tool raw_result = {"rows": [{"id": "t-1", "title": "Fix bug", "status": "todo"}]} async def fake_executor(payload: dict) -> dict: return raw_result # AIMessage with a tool call, then a final answer tool_call_msg = _make_ai_message( tool_calls=[{"name": "list_tasks", "args": {}, "id": "call-1", "type": "tool_call"}] ) final_msg = _make_ai_message("Here are your tasks.") llm = MagicMock() llm_with_tools = MagicMock() llm.bind_tools = MagicMock(return_value=llm_with_tools) llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg]) llm.ainvoke = AsyncMock(return_value=final_msg) mock_tool = _make_tool("list_tasks", "- Fix bug (todo)") from app.core.ws_context import set_client_executor, clear_client_executor set_client_executor(fake_executor) try: # Patch the tool to actually call execute_on_client async def tool_side_effect(args: dict) -> str: from app.core.ws_context import execute_on_client res = await execute_on_client(action="select", table="tasks") rows = res.get("rows", []) return "\n".join(r["title"] for r in rows) mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect) result = await agent._tool_loop( llm, [HumanMessage(content="list my tasks")], [mock_tool] ) finally: clear_client_executor() assert result == "Here are your tasks." assert len(agent.tool_results) == 1 assert agent.tool_results[0] == raw_result # ── _tool_loop: tool_results reset on each call ────────────────────── @pytest.mark.asyncio async def test_tool_loop_resets_tool_results(): agent = _EchoAgent() agent.tool_results = [{"stale": True}] # pre-populated from a previous call llm = AsyncMock() llm.ainvoke = AsyncMock(return_value=_make_ai_message("Done.")) await agent._tool_loop(llm, [HumanMessage(content="hi")], []) assert agent.tool_results == [] # ── _tool_loop: unknown tool name ──────────────────────────────────── @pytest.mark.asyncio async def test_tool_loop_unknown_tool(): agent = _EchoAgent() # No known tools — model still calls a non-existent one; loop handles gracefully tool_call_msg = _make_ai_message( tool_calls=[{"name": "nonexistent", "args": {}, "id": "c1", "type": "tool_call"}] ) final_msg = _make_ai_message("Handled.") mock_tool = _make_tool("known", "ok") # a different tool, not "nonexistent" llm = MagicMock() llm_with_tools = MagicMock() llm.bind_tools = MagicMock(return_value=llm_with_tools) llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg]) result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool]) assert result == "Handled." # ── _tool_loop: max_iter exhaustion ────────────────────────────────── @pytest.mark.asyncio async def test_tool_loop_max_iter(): agent = _EchoAgent() always_tool = _make_ai_message( tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}] ) fallback = _make_ai_message("Fallback.") llm = MagicMock() llm_with_tools = MagicMock() llm.bind_tools = MagicMock(return_value=llm_with_tools) # Returns tool_call_msg on every iteration llm_with_tools.ainvoke = AsyncMock(return_value=always_tool) llm.ainvoke = AsyncMock(return_value=fallback) mock_tool = _make_tool("t", "ok") result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool], max_iter=2) assert result == "Fallback." assert llm_with_tools.ainvoke.call_count == 2 # ── _tool_loop_stream: no tool calls — yields tokens ───────────────── @pytest.mark.asyncio async def test_tool_loop_stream_no_tools_yields_tokens(): agent = _EchoAgent() # No tools → llm used directly; ainvoke returns no tool calls → stream is used no_tool_msg = _make_ai_message("irrelevant") llm = AsyncMock() llm.ainvoke = AsyncMock(return_value=no_tool_msg) async def fake_astream(msgs): for tok in ["Hello", " ", "world"]: c = MagicMock() c.content = tok yield c llm.astream = fake_astream tokens = await _collect_stream(agent, llm, [HumanMessage(content="hi")], []) assert tokens == ["Hello", " ", "world"] assert agent.tool_results == [] # ── _tool_loop_stream: one tool call then streaming final ───────────── @pytest.mark.asyncio async def test_tool_loop_stream_with_tool_call(): agent = _EchoAgent() raw_result = {"row": {"id": "t-2", "title": "Deploy", "status": "in_progress"}} async def fake_executor(payload: dict) -> dict: return raw_result tool_call_msg = _make_ai_message( tool_calls=[{"name": "get_task", "args": {"id": "t-2"}, "id": "c1", "type": "tool_call"}] ) # After tools run, ainvoke returns no more tool calls no_more_tools_msg = _make_ai_message("Task found.") llm = MagicMock() llm_with_tools = MagicMock() llm.bind_tools = MagicMock(return_value=llm_with_tools) llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg]) async def fake_astream(msgs): for tok in ["Task", " ", "found."]: c = MagicMock() c.content = tok yield c llm.astream = fake_astream async def tool_side_effect(args: dict) -> str: from app.core.ws_context import execute_on_client res = await execute_on_client(action="select", table="tasks", filters={"id": args.get("id")}) return res.get("row", {}).get("title", "") mock_tool = _make_tool("get_task", "Deploy") mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect) from app.core.ws_context import set_client_executor, clear_client_executor set_client_executor(fake_executor) try: tokens = await _collect_stream( agent, llm, [HumanMessage(content="get task t-2")], [mock_tool] ) finally: clear_client_executor() assert tokens == ["Task", " ", "found."] assert len(agent.tool_results) == 1 assert agent.tool_results[0] == raw_result # ── _tool_loop_stream: tool_results reset on each call ─────────────── @pytest.mark.asyncio async def test_tool_loop_stream_resets_tool_results(): agent = _EchoAgent() agent.tool_results = [{"old": True}] no_tool_msg = _make_ai_message("") llm = AsyncMock() llm.ainvoke = AsyncMock(return_value=no_tool_msg) async def fake_astream(msgs): c = MagicMock() c.content = "ok" yield c llm.astream = fake_astream await _collect_stream(agent, llm, [HumanMessage(content="x")], []) assert agent.tool_results == [] # ── _tool_loop_stream: empty chunk content is skipped ──────────────── @pytest.mark.asyncio async def test_tool_loop_stream_skips_empty_chunks(): agent = _EchoAgent() no_tool_msg = _make_ai_message("") llm = AsyncMock() llm.ainvoke = AsyncMock(return_value=no_tool_msg) async def fake_astream(msgs): for tok in ["", "hello", "", " world", ""]: c = MagicMock() c.content = tok yield c llm.astream = fake_astream tokens = await _collect_stream(agent, llm, [HumanMessage(content="x")], []) assert tokens == ["hello", " world"] # ── _tool_loop_stream: max_iter exhaustion falls back to stream ─────── @pytest.mark.asyncio async def test_tool_loop_stream_max_iter(): agent = _EchoAgent() always_tool = _make_ai_message( tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}] ) llm = MagicMock() llm_with_tools = MagicMock() llm.bind_tools = MagicMock(return_value=llm_with_tools) llm_with_tools.ainvoke = AsyncMock(return_value=always_tool) async def fake_astream(msgs): c = MagicMock() c.content = "fallback" yield c llm.astream = fake_astream mock_tool = _make_tool("t", "ok") tokens = await _collect_stream( agent, llm, [HumanMessage(content="x")], [mock_tool], ) assert tokens == ["fallback"] assert llm_with_tools.ainvoke.call_count == 5 # exhausted default max_iter # ── _tool_loop_stream: multiple tool results captured ──────────────── @pytest.mark.asyncio async def test_tool_loop_stream_multiple_tool_results(): agent = _EchoAgent() call_results = [ {"rows": [{"id": "t-1"}]}, {"rows": [{"id": "t-2"}]}, ] call_iter = iter(call_results) async def fake_executor(payload: dict) -> dict: return next(call_iter) # Two tool calls in one iteration tool_call_msg = _make_ai_message( tool_calls=[ {"name": "tool_a", "args": {}, "id": "c1", "type": "tool_call"}, {"name": "tool_b", "args": {}, "id": "c2", "type": "tool_call"}, ] ) no_more_tools_msg = _make_ai_message("Done.") llm = MagicMock() llm_with_tools = MagicMock() llm.bind_tools = MagicMock(return_value=llm_with_tools) llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg]) async def fake_astream(msgs): c = MagicMock() c.content = "Done." yield c llm.astream = fake_astream async def tool_side_effect(args: dict) -> str: from app.core.ws_context import execute_on_client res = await execute_on_client(action="select", table="tasks") return str(res) tool_a = _make_tool("tool_a", "") tool_a.ainvoke = AsyncMock(side_effect=tool_side_effect) tool_b = _make_tool("tool_b", "") tool_b.ainvoke = AsyncMock(side_effect=tool_side_effect) from app.core.ws_context import set_client_executor, clear_client_executor set_client_executor(fake_executor) try: tokens = await _collect_stream( agent, llm, [HumanMessage(content="x")], [tool_a, tool_b] ) finally: clear_client_executor() assert tokens == ["Done."] assert len(agent.tool_results) == 2 assert agent.tool_results[0] == {"rows": [{"id": "t-1"}]} assert agent.tool_results[1] == {"rows": [{"id": "t-2"}]}