215 lines
6.7 KiB
Python
215 lines
6.7 KiB
Python
"""Unit tests for the agent registry, base classes, and tool loop."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from app.core.agent_registry import AgentRegistry, ChatAgent
|
|
|
|
|
|
# ── Helpers ──────────────────────────────────────────────────────────
|
|
|
|
class _StubAgent(ChatAgent):
    """Bare-bones ``ChatAgent`` implementation used as a test double.

    It reports the name ``"stub"``, exposes no tools, and answers every
    query by echoing it back.
    """

    def get_name(self) -> str:
        """Return the fixed name ``"stub"``."""
        return "stub"

    def get_description(self) -> str:
        """Return the fixed human-readable description of this agent."""
        return "A stub agent for tests"

    def get_tools(self) -> list[Any]:
        """Return an empty tool list; this agent has no tools."""
        return []

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Echo ``query`` back with an ``echo: `` prefix; ``context`` is unused."""
        reply = f"echo: {query}"
        return reply
|
|
|
|
|
|
class _AnotherAgent(ChatAgent):
    """Second concrete ``ChatAgent``, used when a test needs two distinct agents."""

    def get_name(self) -> str:
        """Return the fixed name ``"another"``."""
        return "another"

    def get_description(self) -> str:
        """Return the fixed description of this agent."""
        return "Another stub"

    def get_tools(self) -> list[Any]:
        """Return an empty tool list; this agent has no tools."""
        return []

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Return the constant ``"another"``; both arguments are ignored."""
        return "another"
|
|
|
|
|
|
# ── Fixtures ─────────────────────────────────────────────────────────
|
|
|
|
@pytest.fixture(autouse=True)
def _fresh_registry():
    """Clear ``AgentRegistry._instance`` before and after every test.

    Declared ``autouse`` so each test starts without a previously cached
    registry instance and leaves none behind.
    """

    def _clear() -> None:
        AgentRegistry._instance = None

    _clear()
    yield
    _clear()
|
|
|
|
|
|
@pytest.fixture()
def reg() -> AgentRegistry:
    """Provide the ``AgentRegistry`` instance the test operates on."""
    registry = AgentRegistry()
    return registry
|
|
|
|
|
|
# ── Tests ────────────────────────────────────────────────────────────
|
|
|
|
class TestRegisterAndGet:
    """Registering agent classes and fetching them back by name."""

    def test_register_decorator(self, reg: AgentRegistry) -> None:
        """A registered class can be fetched by name as an instance of that class."""
        reg.register(_StubAgent)
        assert isinstance(reg.get("stub"), _StubAgent)

    def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
        """Looking up an unregistered name raises ``KeyError`` mentioning "not found"."""
        with pytest.raises(KeyError, match="not found"):
            reg.get("nonexistent")

    def test_register_multiple(self, reg: AgentRegistry) -> None:
        """Two registered agents are each retrievable under their own name."""
        for agent_cls in (_StubAgent, _AnotherAgent):
            reg.register(agent_cls)

        for expected_name in ("stub", "another"):
            assert reg.get(expected_name).get_name() == expected_name
|
|
|
|
|
|
class TestListAgents:
    """Contents of ``AgentRegistry.list_agents`` for zero, one, and two agents."""

    def test_empty(self, reg: AgentRegistry) -> None:
        """With nothing registered the listing equals an empty list."""
        listed = reg.list_agents()
        assert listed == []

    def test_list_after_register(self, reg: AgentRegistry) -> None:
        """One registered agent yields exactly one name/description entry."""
        reg.register(_StubAgent)
        expected_entry = {"name": "stub", "description": "A stub agent for tests"}

        listed = reg.list_agents()

        assert len(listed) == 1
        assert listed[0] == expected_entry

    def test_list_multiple(self, reg: AgentRegistry) -> None:
        """Both registered names appear in the listing, in any order."""
        reg.register(_StubAgent)
        reg.register(_AnotherAgent)

        seen_names = set()
        for entry in reg.list_agents():
            seen_names.add(entry["name"])

        assert seen_names == {"stub", "another"}
|
|
|
|
|
|
class TestCallAgent:
    """Dispatching a query to an agent through ``AgentRegistry.call_agent``."""

    @pytest.mark.asyncio
    async def test_call_agent(self, reg: AgentRegistry) -> None:
        """Calling a registered agent returns what its ``handle`` produced."""
        reg.register(_StubAgent)
        reply = await reg.call_agent("stub", "hello", {})
        assert reply == "echo: hello"

    @pytest.mark.asyncio
    async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
        """Calling an unregistered agent name raises ``KeyError``."""
        with pytest.raises(KeyError):
            await reg.call_agent("nope", "hi", {})
|
|
|
|
|
|
class TestSingleton:
    """``AgentRegistry`` construction returns one shared object."""

    def test_singleton_identity(self) -> None:
        """Two consecutive constructor calls return the very same object."""
        first = AgentRegistry()
        second = AgentRegistry()
        assert first is second
|
|
|
|
|
|
class TestToolLoop:
    """Drive ``_tool_loop`` of a stub agent with mocked LLM and tool objects."""

    # ── Mock builders ────────────────────────────────────────────────

    @staticmethod
    def _message(content: str, tool_calls: list[dict[str, Any]]) -> MagicMock:
        """Build a fake LLM response exposing ``content`` and ``tool_calls``."""
        message = MagicMock()
        message.content = content
        message.tool_calls = tool_calls
        return message

    @staticmethod
    def _self_binding_llm(**ainvoke_config: Any) -> AsyncMock:
        """Build a fake LLM whose ``bind_tools`` returns the LLM itself.

        ``ainvoke_config`` is forwarded to the ``AsyncMock`` installed as
        ``ainvoke`` (e.g. ``return_value=...`` or ``side_effect=[...]``).
        """
        fake_llm = AsyncMock()
        fake_llm.bind_tools = MagicMock(return_value=fake_llm)
        fake_llm.ainvoke = AsyncMock(**ainvoke_config)
        return fake_llm

    @staticmethod
    def _fake_tool(tool_name: str, output: str) -> AsyncMock:
        """Build a fake tool with the given ``name`` whose ``ainvoke`` returns ``output``."""
        fake_tool = AsyncMock()
        # Assigned after construction: the ``name`` constructor argument of a
        # mock names the mock itself rather than creating a ``name`` attribute.
        fake_tool.name = tool_name
        fake_tool.ainvoke = AsyncMock(return_value=output)
        return fake_tool

    # ── Tests ────────────────────────────────────────────────────────

    @pytest.mark.asyncio
    async def test_no_tool_calls(self) -> None:
        """A response with no tool calls has its content returned as the result."""
        agent = _StubAgent()
        llm = self._self_binding_llm(
            return_value=self._message("final answer", []),
        )

        outcome = await agent._tool_loop(llm, [], [])

        assert outcome == "final answer"

    @pytest.mark.asyncio
    async def test_tool_call_then_answer(self) -> None:
        """One requested tool call is executed, then the follow-up answer is returned."""
        agent = _StubAgent()

        # Scripted conversation: first a tool request, then the final answer.
        wants_tool = self._message(
            "",
            [
                {"id": "call_1", "name": "my_tool", "args": {"x": 1}}
            ],
        )
        answer = self._message("done", [])
        llm = self._self_binding_llm(side_effect=[wants_tool, answer])

        my_tool = self._fake_tool("my_tool", "tool_result")

        outcome = await agent._tool_loop(llm, [], [my_tool])

        assert outcome == "done"
        my_tool.ainvoke.assert_called_once_with({"x": 1})

    @pytest.mark.asyncio
    async def test_unknown_tool_handled(self) -> None:
        """A request for a tool that was not supplied does not crash the loop."""
        agent = _StubAgent()

        wants_missing_tool = self._message(
            "",
            [
                {"id": "call_1", "name": "missing", "args": {}}
            ],
        )
        answer = self._message("recovered", [])
        llm = self._self_binding_llm(side_effect=[wants_missing_tool, answer])

        # No tools are passed in, so "missing" cannot be resolved.
        outcome = await agent._tool_loop(llm, [], [])

        assert outcome == "recovered"

    @pytest.mark.asyncio
    async def test_max_iter_reached(self) -> None:
        """After ``max_iter`` tool rounds the answer comes from the unbound LLM."""
        agent = _StubAgent()

        # The tool-bound LLM never stops asking for another tool call.
        keeps_asking = self._message(
            "",
            [
                {"id": "call_x", "name": "t", "args": {}}
            ],
        )
        answer = self._message("gave up", [])
        looping_tool = self._fake_tool("t", "ok")

        bound_llm = AsyncMock()
        bound_llm.ainvoke = AsyncMock(return_value=keeps_asking)

        # Unlike the other tests, binding tools yields a *separate* mock so
        # calls with and without tools can be told apart.
        llm = AsyncMock()
        llm.bind_tools = MagicMock(return_value=bound_llm)
        llm.ainvoke = AsyncMock(return_value=answer)

        outcome = await agent._tool_loop(llm, [], [looping_tool], max_iter=2)

        assert outcome == "gave up"
        assert bound_llm.ainvoke.call_count == 2
|