step 3 complete: pluggable agent framework
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -112,8 +112,8 @@ adiuva-api/
|
||||
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
|
||||
- **Outcome:** All request/response models defined and validated.
|
||||
|
||||
### Step 3 — Agent Registry + base classes
|
||||
- [ ] `app/core/agent_registry.py`:
|
||||
### Step 3 — Agent Registry + base classes ✅
|
||||
- [x] `app/core/agent_registry.py`:
|
||||
- `BaseAgent(ABC)`:
|
||||
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
|
||||
- Abstract `get_name() -> str`, `get_description() -> str`
|
||||
@@ -127,7 +127,7 @@ adiuva-api/
|
||||
- `get(name) -> ChatAgent`
|
||||
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
|
||||
- `async call_agent(name, query, context) -> str` — for inter-agent calls
|
||||
- [ ] Unit tests: register, get, list, call_agent with mock
|
||||
- [x] Unit tests: register, get, list, call_agent with mock
|
||||
- **Outcome:** Pluggable agent framework.
|
||||
|
||||
### Step 4 — Orchestrator
|
||||
|
||||
137
app/core/agent_registry.py
Normal file
137
app/core/agent_registry.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Agent Registry — base classes and singleton registry for chat agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
"""Common base for all agents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str = "",
|
||||
shared_memory: dict[str, Any] | None = None,
|
||||
vector_store_context: list[str] | None = None,
|
||||
) -> None:
|
||||
self.user_id = user_id
|
||||
self.shared_memory: dict[str, Any] = shared_memory or {}
|
||||
self.vector_store_context: list[str] = vector_store_context or []
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_description(self) -> str: ...
|
||||
|
||||
@property
|
||||
def skills(self) -> list[str]:
|
||||
"""Override in subclasses to advertise capabilities."""
|
||||
return []
|
||||
|
||||
|
||||
class ChatAgent(BaseAgent):
|
||||
"""Base class for LLM-powered chat agents."""
|
||||
|
||||
@abstractmethod
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
"""Process a user query and return a text response."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_tools(self) -> list[Any]:
|
||||
"""Return LangChain tool definitions available to this agent."""
|
||||
...
|
||||
|
||||
async def _tool_loop(
|
||||
self,
|
||||
llm: Any,
|
||||
messages: list[Any],
|
||||
tools: list[Any],
|
||||
max_iter: int = 5,
|
||||
) -> str:
|
||||
"""Shared tool-calling loop.
|
||||
|
||||
Binds *tools* to *llm*, invokes iteratively until the model stops
|
||||
requesting tool calls or *max_iter* is reached, and returns the
|
||||
final text response.
|
||||
"""
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
||||
|
||||
for _ in range(max_iter):
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
messages.append(response)
|
||||
|
||||
if not response.tool_calls:
|
||||
return str(response.content)
|
||||
|
||||
# Execute each requested tool call
|
||||
tool_map = {t.name: t for t in tools}
|
||||
for call in response.tool_calls:
|
||||
tool_fn = tool_map.get(call["name"])
|
||||
if tool_fn is None:
|
||||
result = f"Unknown tool: {call['name']}"
|
||||
else:
|
||||
result = await tool_fn.ainvoke(call["args"])
|
||||
messages.append(
|
||||
ToolMessage(content=str(result), tool_call_id=call["id"])
|
||||
)
|
||||
|
||||
# Exhausted iterations — ask model for a final answer without tools
|
||||
response = await llm.ainvoke(messages)
|
||||
return str(response.content)
|
||||
|
||||
|
||||
class AgentRegistry:
|
||||
"""Singleton registry for ChatAgent subclasses."""
|
||||
|
||||
_instance: AgentRegistry | None = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._agents: dict[str, type[ChatAgent]] = {}
|
||||
|
||||
def __new__(cls) -> AgentRegistry:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._agents = {}
|
||||
return cls._instance
|
||||
|
||||
# ── public API ───────────────────────────────────────────────────
|
||||
|
||||
def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
|
||||
"""Class decorator — registers an agent by its name."""
|
||||
instance = agent_class()
|
||||
name = instance.get_name()
|
||||
self._agents[name] = agent_class
|
||||
return agent_class
|
||||
|
||||
def get(self, name: str) -> ChatAgent:
|
||||
"""Return a fresh instance of the named agent."""
|
||||
cls = self._agents.get(name)
|
||||
if cls is None:
|
||||
raise KeyError(f"Agent not found: {name}")
|
||||
return cls()
|
||||
|
||||
def list_agents(self) -> list[dict[str, str]]:
|
||||
"""Return ``[{name, description}]`` for the orchestrator prompt."""
|
||||
result: list[dict[str, str]] = []
|
||||
for cls in self._agents.values():
|
||||
inst = cls()
|
||||
result.append(
|
||||
{"name": inst.get_name(), "description": inst.get_description()}
|
||||
)
|
||||
return result
|
||||
|
||||
async def call_agent(
|
||||
self, name: str, query: str, context: dict[str, Any]
|
||||
) -> str:
|
||||
"""Instantiate the named agent and call its ``handle`` method."""
|
||||
agent = self.get(name)
|
||||
return await agent.handle(query, context)
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
registry = AgentRegistry()
|
||||
214
tests/test_agent_registry.py
Normal file
214
tests/test_agent_registry.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Unit tests for the agent registry, base classes, and tool loop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
class _StubAgent(ChatAgent):
|
||||
"""Minimal concrete agent for testing."""
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "stub"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "A stub agent for tests"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
return f"echo: {query}"
|
||||
|
||||
|
||||
class _AnotherAgent(ChatAgent):
|
||||
def get_name(self) -> str:
|
||||
return "another"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "Another stub"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
return "another"
|
||||
|
||||
|
||||
# ── Fixtures ─────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _fresh_registry():
|
||||
"""Reset the singleton between tests."""
|
||||
AgentRegistry._instance = None
|
||||
yield
|
||||
AgentRegistry._instance = None
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def reg() -> AgentRegistry:
|
||||
return AgentRegistry()
|
||||
|
||||
|
||||
# ── Tests ────────────────────────────────────────────────────────────
|
||||
|
||||
class TestRegisterAndGet:
|
||||
def test_register_decorator(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
agent = reg.get("stub")
|
||||
assert isinstance(agent, _StubAgent)
|
||||
|
||||
def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
|
||||
with pytest.raises(KeyError, match="not found"):
|
||||
reg.get("nonexistent")
|
||||
|
||||
def test_register_multiple(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
reg.register(_AnotherAgent)
|
||||
assert reg.get("stub").get_name() == "stub"
|
||||
assert reg.get("another").get_name() == "another"
|
||||
|
||||
|
||||
class TestListAgents:
|
||||
def test_empty(self, reg: AgentRegistry) -> None:
|
||||
assert reg.list_agents() == []
|
||||
|
||||
def test_list_after_register(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
agents = reg.list_agents()
|
||||
assert len(agents) == 1
|
||||
assert agents[0] == {"name": "stub", "description": "A stub agent for tests"}
|
||||
|
||||
def test_list_multiple(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
reg.register(_AnotherAgent)
|
||||
names = {a["name"] for a in reg.list_agents()}
|
||||
assert names == {"stub", "another"}
|
||||
|
||||
|
||||
class TestCallAgent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_agent(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
result = await reg.call_agent("stub", "hello", {})
|
||||
assert result == "echo: hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
|
||||
with pytest.raises(KeyError):
|
||||
await reg.call_agent("nope", "hi", {})
|
||||
|
||||
|
||||
class TestSingleton:
|
||||
def test_singleton_identity(self) -> None:
|
||||
a = AgentRegistry()
|
||||
b = AgentRegistry()
|
||||
assert a is b
|
||||
|
||||
|
||||
class TestToolLoop:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tool_calls(self) -> None:
|
||||
"""When the LLM responds without tool calls, return content directly."""
|
||||
agent = _StubAgent()
|
||||
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.content = "final answer"
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm)
|
||||
llm.ainvoke = AsyncMock(return_value=ai_msg)
|
||||
|
||||
result = await agent._tool_loop(llm, [], [])
|
||||
assert result == "final answer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_then_answer(self) -> None:
|
||||
"""LLM requests one tool call, gets result, then answers."""
|
||||
agent = _StubAgent()
|
||||
|
||||
# First response: tool call
|
||||
tool_call_msg = MagicMock()
|
||||
tool_call_msg.content = ""
|
||||
tool_call_msg.tool_calls = [
|
||||
{"id": "call_1", "name": "my_tool", "args": {"x": 1}}
|
||||
]
|
||||
|
||||
# Second response: final answer
|
||||
final_msg = MagicMock()
|
||||
final_msg.content = "done"
|
||||
final_msg.tool_calls = []
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm)
|
||||
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||
|
||||
# Mock tool
|
||||
tool = AsyncMock()
|
||||
tool.name = "my_tool"
|
||||
tool.ainvoke = AsyncMock(return_value="tool_result")
|
||||
|
||||
result = await agent._tool_loop(llm, [], [tool])
|
||||
assert result == "done"
|
||||
tool.ainvoke.assert_called_once_with({"x": 1})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_tool_handled(self) -> None:
|
||||
"""Unknown tool names produce an error message instead of crashing."""
|
||||
agent = _StubAgent()
|
||||
|
||||
tool_call_msg = MagicMock()
|
||||
tool_call_msg.content = ""
|
||||
tool_call_msg.tool_calls = [
|
||||
{"id": "call_1", "name": "missing", "args": {}}
|
||||
]
|
||||
|
||||
final_msg = MagicMock()
|
||||
final_msg.content = "recovered"
|
||||
final_msg.tool_calls = []
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm)
|
||||
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||
|
||||
result = await agent._tool_loop(llm, [], [])
|
||||
assert result == "recovered"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_iter_reached(self) -> None:
|
||||
"""When max iterations are exhausted, a final no-tools call is made."""
|
||||
agent = _StubAgent()
|
||||
|
||||
# Every response requests a tool call
|
||||
loop_msg = MagicMock()
|
||||
loop_msg.content = ""
|
||||
loop_msg.tool_calls = [
|
||||
{"id": "call_x", "name": "t", "args": {}}
|
||||
]
|
||||
|
||||
final_msg = MagicMock()
|
||||
final_msg.content = "gave up"
|
||||
final_msg.tool_calls = []
|
||||
|
||||
tool = AsyncMock()
|
||||
tool.name = "t"
|
||||
tool.ainvoke = AsyncMock(return_value="ok")
|
||||
|
||||
llm_with_tools = AsyncMock()
|
||||
llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg)
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
||||
|
||||
result = await agent._tool_loop(llm, [], [tool], max_iter=2)
|
||||
assert result == "gave up"
|
||||
assert llm_with_tools.ainvoke.call_count == 2
|
||||
Reference in New Issue
Block a user