step-2: add agent streaming and tool result capture (agent_registry.py)
- ChatAgent.__init__: adds tool_results: list[dict] = [] - _tool_loop: wraps execution in a result collector; populates self.tool_results with raw execute_on_client dicts after each run - _tool_loop_stream: streaming variant — uses ainvoke for tool-call iterations, llm.astream() for the final answer; same result capture - ws_context.py: adds _tool_result_collector ContextVar + set/clear helpers; execute_on_client appends to collector when set Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -7,6 +7,18 @@
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## General Rules
|
||||||
|
|
||||||
|
**Code Cleanup**: As you implement each step, remove any code that becomes unused or obsolete. This includes:
|
||||||
|
- Old functions/methods that are superseded by new ones
|
||||||
|
- Deprecated imports or modules
|
||||||
|
- Dead code paths
|
||||||
|
- Old test files no longer needed
|
||||||
|
|
||||||
|
This keeps the codebase clean and prevents confusion. When removing code, note it in the commit message if significant.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Decisions Log
|
## Decisions Log
|
||||||
|
|
||||||
| Topic | Decision |
|
| Topic | Decision |
|
||||||
@@ -74,7 +86,7 @@ pytest tests/test_agent_streaming.py
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Status**:
|
**Status**:
|
||||||
- [ ] Step 2 complete
|
- [x] Step 2 complete
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
**Commit**: After tests pass, commit with:
|
||||||
```
|
```
|
||||||
@@ -222,8 +234,9 @@ git commit -m "step-4: add output formatting layer (output_formatter.py)"
|
|||||||
- Each request gets a `request_id` (UUID) for frame correlation.
|
- Each request gets a `request_id` (UUID) for frame correlation.
|
||||||
- Concurrent requests from same client are supported (each runs as an async task).
|
- Concurrent requests from same client are supported (each runs as an async task).
|
||||||
- `app/api/routes/chat.py`:
|
- `app/api/routes/chat.py`:
|
||||||
- Remove `chat_stream` WS endpoint.
|
- Remove `chat_stream` WS endpoint and any related helper functions that were only used by it.
|
||||||
- Keep `POST /chat` endpoint unchanged (REST fallback).
|
- Keep `POST /chat` endpoint unchanged (REST fallback).
|
||||||
|
- Clean up any unused imports.
|
||||||
- `app/main.py`:
|
- `app/main.py`:
|
||||||
- No change needed (device_ws router already registered).
|
- No change needed (device_ws router already registered).
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
@@ -34,6 +35,11 @@ class BaseAgent(ABC):
|
|||||||
class ChatAgent(BaseAgent):
|
class ChatAgent(BaseAgent):
|
||||||
"""Base class for LLM-powered chat agents."""
|
"""Base class for LLM-powered chat agents."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
# Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results.
|
||||||
|
self.tool_results: list[dict] = []
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
"""Process a user query and return a text response."""
|
"""Process a user query and return a text response."""
|
||||||
@@ -55,10 +61,16 @@ class ChatAgent(BaseAgent):
|
|||||||
|
|
||||||
Binds *tools* to *llm*, invokes iteratively until the model stops
|
Binds *tools* to *llm*, invokes iteratively until the model stops
|
||||||
requesting tool calls or *max_iter* is reached, and returns the
|
requesting tool calls or *max_iter* is reached, and returns the
|
||||||
final text response.
|
final text response. Captures raw execute_on_client results in
|
||||||
|
``self.tool_results``.
|
||||||
"""
|
"""
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
||||||
|
|
||||||
|
collector: list[dict] = []
|
||||||
|
set_tool_result_collector(collector)
|
||||||
|
try:
|
||||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
||||||
|
|
||||||
for _ in range(max_iter):
|
for _ in range(max_iter):
|
||||||
@@ -83,6 +95,64 @@ class ChatAgent(BaseAgent):
|
|||||||
# Exhausted iterations — ask model for a final answer without tools
|
# Exhausted iterations — ask model for a final answer without tools
|
||||||
response = await llm.ainvoke(messages)
|
response = await llm.ainvoke(messages)
|
||||||
return str(response.content)
|
return str(response.content)
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
self.tool_results = collector
|
||||||
|
|
||||||
|
async def _tool_loop_stream(
|
||||||
|
self,
|
||||||
|
llm: Any,
|
||||||
|
messages: list[Any],
|
||||||
|
tools: list[Any],
|
||||||
|
max_iter: int = 5,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Streaming variant of ``_tool_loop``.
|
||||||
|
|
||||||
|
Behaves identically for tool-calling iterations (uses ainvoke to parse
|
||||||
|
tool calls). For the final response — when the model produces no further
|
||||||
|
tool calls — switches to ``llm.astream()`` and yields text tokens.
|
||||||
|
Captures raw execute_on_client results in ``self.tool_results``.
|
||||||
|
"""
|
||||||
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
||||||
|
|
||||||
|
collector: list[dict] = []
|
||||||
|
set_tool_result_collector(collector)
|
||||||
|
try:
|
||||||
|
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
||||||
|
|
||||||
|
for _ in range(max_iter):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
# Stream the final answer — don't keep the ainvoke result.
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
if chunk.content:
|
||||||
|
yield str(chunk.content)
|
||||||
|
return
|
||||||
|
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
# Execute each requested tool call
|
||||||
|
tool_map = {t.name: t for t in tools}
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_fn = tool_map.get(call["name"])
|
||||||
|
if tool_fn is None:
|
||||||
|
result = f"Unknown tool: {call['name']}"
|
||||||
|
else:
|
||||||
|
result = await tool_fn.ainvoke(call["args"])
|
||||||
|
messages.append(
|
||||||
|
ToolMessage(content=str(result), tool_call_id=call["id"])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Exhausted iterations — stream a final answer without tools
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
if chunk.content:
|
||||||
|
yield str(chunk.content)
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
self.tool_results = collector
|
||||||
|
|
||||||
|
|
||||||
class AgentRegistry:
|
class AgentRegistry:
|
||||||
|
|||||||
@@ -17,6 +17,22 @@ _client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = Cont
|
|||||||
"_client_executor"
|
"_client_executor"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Optional collector that captures raw execute_on_client results.
|
||||||
|
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
|
||||||
|
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||||
|
"_tool_result_collector", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_tool_result_collector(lst: list[dict]) -> None:
|
||||||
|
"""Register *lst* as the collector for this async context."""
|
||||||
|
_tool_result_collector.set(lst)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_tool_result_collector() -> None:
|
||||||
|
"""Clear the collector (best-effort)."""
|
||||||
|
_tool_result_collector.set(None)
|
||||||
|
|
||||||
|
|
||||||
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
|
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
|
||||||
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
|
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
|
||||||
@@ -65,4 +81,8 @@ async def execute_on_client(
|
|||||||
if limit is not None:
|
if limit is not None:
|
||||||
payload["limit"] = limit
|
payload["limit"] = limit
|
||||||
|
|
||||||
return await callback(payload)
|
result = await callback(payload)
|
||||||
|
collector = _tool_result_collector.get(None)
|
||||||
|
if collector is not None:
|
||||||
|
collector.append(result)
|
||||||
|
return result
|
||||||
|
|||||||
416
tests/test_agent_streaming.py
Normal file
416
tests/test_agent_streaming.py
Normal file
@@ -0,0 +1,416 @@
|
|||||||
|
"""Tests for ChatAgent streaming and tool result capture (Step 2)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.core.agent_registry import ChatAgent, registry
|
||||||
|
|
||||||
|
|
||||||
|
# ── Minimal concrete agent for testing ───────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _EchoAgent(ChatAgent):
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "_echo"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "Echo agent for tests"
|
||||||
|
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ai_message(content: str = "", tool_calls: list | None = None) -> AIMessage:
|
||||||
|
msg = AIMessage(content=content)
|
||||||
|
if tool_calls:
|
||||||
|
msg.tool_calls = tool_calls
|
||||||
|
else:
|
||||||
|
msg.tool_calls = []
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tool(name: str, return_value: Any) -> MagicMock:
|
||||||
|
t = MagicMock()
|
||||||
|
t.name = name
|
||||||
|
t.ainvoke = AsyncMock(return_value=return_value)
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def _make_stream_chunks(tokens: list[str]) -> list[MagicMock]:
|
||||||
|
chunks = []
|
||||||
|
for tok in tokens:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = tok
|
||||||
|
chunks.append(c)
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
async def _collect_stream(agent: ChatAgent, llm: Any, messages: list, tools: list) -> list[str]:
|
||||||
|
tokens: list[str] = []
|
||||||
|
async for tok in agent._tool_loop_stream(llm, messages, tools):
|
||||||
|
tokens.append(tok)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
# ── tool_results initialised ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_results_init():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
assert agent.tool_results == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop: no tool calls ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_no_tools():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Hello!"))
|
||||||
|
|
||||||
|
result = await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
|
||||||
|
assert result == "Hello!"
|
||||||
|
assert agent.tool_results == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop: with one tool call + result capture ──────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_captures_tool_results():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
|
||||||
|
# Mock execute_on_client to return structured data via the tool
|
||||||
|
raw_result = {"rows": [{"id": "t-1", "title": "Fix bug", "status": "todo"}]}
|
||||||
|
|
||||||
|
async def fake_executor(payload: dict) -> dict:
|
||||||
|
return raw_result
|
||||||
|
|
||||||
|
# AIMessage with a tool call, then a final answer
|
||||||
|
tool_call_msg = _make_ai_message(
|
||||||
|
tool_calls=[{"name": "list_tasks", "args": {}, "id": "call-1", "type": "tool_call"}]
|
||||||
|
)
|
||||||
|
final_msg = _make_ai_message("Here are your tasks.")
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm_with_tools = MagicMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||||
|
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||||
|
llm.ainvoke = AsyncMock(return_value=final_msg)
|
||||||
|
|
||||||
|
mock_tool = _make_tool("list_tasks", "- Fix bug (todo)")
|
||||||
|
|
||||||
|
from app.core.ws_context import set_client_executor, clear_client_executor
|
||||||
|
set_client_executor(fake_executor)
|
||||||
|
try:
|
||||||
|
# Patch the tool to actually call execute_on_client
|
||||||
|
async def tool_side_effect(args: dict) -> str:
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
res = await execute_on_client(action="select", table="tasks")
|
||||||
|
rows = res.get("rows", [])
|
||||||
|
return "\n".join(r["title"] for r in rows)
|
||||||
|
|
||||||
|
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
||||||
|
|
||||||
|
result = await agent._tool_loop(
|
||||||
|
llm, [HumanMessage(content="list my tasks")], [mock_tool]
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
assert result == "Here are your tasks."
|
||||||
|
assert len(agent.tool_results) == 1
|
||||||
|
assert agent.tool_results[0] == raw_result
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop: tool_results reset on each call ──────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_resets_tool_results():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
agent.tool_results = [{"stale": True}] # pre-populated from a previous call
|
||||||
|
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Done."))
|
||||||
|
|
||||||
|
await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
|
||||||
|
assert agent.tool_results == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop: unknown tool name ────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_unknown_tool():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
|
||||||
|
# No known tools — model still calls a non-existent one; loop handles gracefully
|
||||||
|
tool_call_msg = _make_ai_message(
|
||||||
|
tool_calls=[{"name": "nonexistent", "args": {}, "id": "c1", "type": "tool_call"}]
|
||||||
|
)
|
||||||
|
final_msg = _make_ai_message("Handled.")
|
||||||
|
|
||||||
|
mock_tool = _make_tool("known", "ok") # a different tool, not "nonexistent"
|
||||||
|
llm = MagicMock()
|
||||||
|
llm_with_tools = MagicMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||||
|
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||||
|
|
||||||
|
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool])
|
||||||
|
assert result == "Handled."
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop: max_iter exhaustion ──────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_max_iter():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
|
||||||
|
always_tool = _make_ai_message(
|
||||||
|
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
|
||||||
|
)
|
||||||
|
fallback = _make_ai_message("Fallback.")
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm_with_tools = MagicMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||||
|
# Returns tool_call_msg on every iteration
|
||||||
|
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
|
||||||
|
llm.ainvoke = AsyncMock(return_value=fallback)
|
||||||
|
|
||||||
|
mock_tool = _make_tool("t", "ok")
|
||||||
|
|
||||||
|
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool], max_iter=2)
|
||||||
|
assert result == "Fallback."
|
||||||
|
assert llm_with_tools.ainvoke.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop_stream: no tool calls — yields tokens ─────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_stream_no_tools_yields_tokens():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
|
||||||
|
# No tools → llm used directly; ainvoke returns no tool calls → stream is used
|
||||||
|
no_tool_msg = _make_ai_message("irrelevant")
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
||||||
|
|
||||||
|
async def fake_astream(msgs):
|
||||||
|
for tok in ["Hello", " ", "world"]:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = tok
|
||||||
|
yield c
|
||||||
|
|
||||||
|
llm.astream = fake_astream
|
||||||
|
|
||||||
|
tokens = await _collect_stream(agent, llm, [HumanMessage(content="hi")], [])
|
||||||
|
assert tokens == ["Hello", " ", "world"]
|
||||||
|
assert agent.tool_results == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop_stream: one tool call then streaming final ─────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_stream_with_tool_call():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
|
||||||
|
raw_result = {"row": {"id": "t-2", "title": "Deploy", "status": "in_progress"}}
|
||||||
|
|
||||||
|
async def fake_executor(payload: dict) -> dict:
|
||||||
|
return raw_result
|
||||||
|
|
||||||
|
tool_call_msg = _make_ai_message(
|
||||||
|
tool_calls=[{"name": "get_task", "args": {"id": "t-2"}, "id": "c1", "type": "tool_call"}]
|
||||||
|
)
|
||||||
|
# After tools run, ainvoke returns no more tool calls
|
||||||
|
no_more_tools_msg = _make_ai_message("Task found.")
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm_with_tools = MagicMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||||
|
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
|
||||||
|
|
||||||
|
async def fake_astream(msgs):
|
||||||
|
for tok in ["Task", " ", "found."]:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = tok
|
||||||
|
yield c
|
||||||
|
|
||||||
|
llm.astream = fake_astream
|
||||||
|
|
||||||
|
async def tool_side_effect(args: dict) -> str:
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
res = await execute_on_client(action="select", table="tasks", filters={"id": args.get("id")})
|
||||||
|
return res.get("row", {}).get("title", "")
|
||||||
|
|
||||||
|
mock_tool = _make_tool("get_task", "Deploy")
|
||||||
|
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
||||||
|
|
||||||
|
from app.core.ws_context import set_client_executor, clear_client_executor
|
||||||
|
set_client_executor(fake_executor)
|
||||||
|
try:
|
||||||
|
tokens = await _collect_stream(
|
||||||
|
agent, llm, [HumanMessage(content="get task t-2")], [mock_tool]
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
assert tokens == ["Task", " ", "found."]
|
||||||
|
assert len(agent.tool_results) == 1
|
||||||
|
assert agent.tool_results[0] == raw_result
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop_stream: tool_results reset on each call ───────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_stream_resets_tool_results():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
agent.tool_results = [{"old": True}]
|
||||||
|
|
||||||
|
no_tool_msg = _make_ai_message("")
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
||||||
|
|
||||||
|
async def fake_astream(msgs):
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = "ok"
|
||||||
|
yield c
|
||||||
|
|
||||||
|
llm.astream = fake_astream
|
||||||
|
|
||||||
|
await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
|
||||||
|
assert agent.tool_results == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop_stream: empty chunk content is skipped ────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_stream_skips_empty_chunks():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
no_tool_msg = _make_ai_message("")
|
||||||
|
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
||||||
|
|
||||||
|
async def fake_astream(msgs):
|
||||||
|
for tok in ["", "hello", "", " world", ""]:
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = tok
|
||||||
|
yield c
|
||||||
|
|
||||||
|
llm.astream = fake_astream
|
||||||
|
|
||||||
|
tokens = await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
|
||||||
|
assert tokens == ["hello", " world"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop_stream: max_iter exhaustion falls back to stream ───────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_stream_max_iter():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
|
||||||
|
always_tool = _make_ai_message(
|
||||||
|
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm_with_tools = MagicMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||||
|
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
|
||||||
|
|
||||||
|
async def fake_astream(msgs):
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = "fallback"
|
||||||
|
yield c
|
||||||
|
|
||||||
|
llm.astream = fake_astream
|
||||||
|
mock_tool = _make_tool("t", "ok")
|
||||||
|
|
||||||
|
tokens = await _collect_stream(
|
||||||
|
agent, llm, [HumanMessage(content="x")], [mock_tool],
|
||||||
|
)
|
||||||
|
assert tokens == ["fallback"]
|
||||||
|
assert llm_with_tools.ainvoke.call_count == 5 # exhausted default max_iter
|
||||||
|
|
||||||
|
|
||||||
|
# ── _tool_loop_stream: multiple tool results captured ────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_loop_stream_multiple_tool_results():
|
||||||
|
agent = _EchoAgent()
|
||||||
|
|
||||||
|
call_results = [
|
||||||
|
{"rows": [{"id": "t-1"}]},
|
||||||
|
{"rows": [{"id": "t-2"}]},
|
||||||
|
]
|
||||||
|
call_iter = iter(call_results)
|
||||||
|
|
||||||
|
async def fake_executor(payload: dict) -> dict:
|
||||||
|
return next(call_iter)
|
||||||
|
|
||||||
|
# Two tool calls in one iteration
|
||||||
|
tool_call_msg = _make_ai_message(
|
||||||
|
tool_calls=[
|
||||||
|
{"name": "tool_a", "args": {}, "id": "c1", "type": "tool_call"},
|
||||||
|
{"name": "tool_b", "args": {}, "id": "c2", "type": "tool_call"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
no_more_tools_msg = _make_ai_message("Done.")
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm_with_tools = MagicMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||||
|
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
|
||||||
|
|
||||||
|
async def fake_astream(msgs):
|
||||||
|
c = MagicMock()
|
||||||
|
c.content = "Done."
|
||||||
|
yield c
|
||||||
|
|
||||||
|
llm.astream = fake_astream
|
||||||
|
|
||||||
|
async def tool_side_effect(args: dict) -> str:
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
res = await execute_on_client(action="select", table="tasks")
|
||||||
|
return str(res)
|
||||||
|
|
||||||
|
tool_a = _make_tool("tool_a", "")
|
||||||
|
tool_a.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
||||||
|
tool_b = _make_tool("tool_b", "")
|
||||||
|
tool_b.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
||||||
|
|
||||||
|
from app.core.ws_context import set_client_executor, clear_client_executor
|
||||||
|
set_client_executor(fake_executor)
|
||||||
|
try:
|
||||||
|
tokens = await _collect_stream(
|
||||||
|
agent, llm, [HumanMessage(content="x")], [tool_a, tool_b]
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
assert tokens == ["Done."]
|
||||||
|
assert len(agent.tool_results) == 2
|
||||||
|
assert agent.tool_results[0] == {"rows": [{"id": "t-1"}]}
|
||||||
|
assert agent.tool_results[1] == {"rows": [{"id": "t-2"}]}
|
||||||
Reference in New Issue
Block a user