step 3 complete: pluggable agent framework
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -112,8 +112,8 @@ adiuva-api/
|
|||||||
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
|
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
|
||||||
- **Outcome:** All request/response models defined and validated.
|
- **Outcome:** All request/response models defined and validated.
|
||||||
|
|
||||||
### Step 3 — Agent Registry + base classes
|
### Step 3 — Agent Registry + base classes ✅
|
||||||
- [ ] `app/core/agent_registry.py`:
|
- [x] `app/core/agent_registry.py`:
|
||||||
- `BaseAgent(ABC)`:
|
- `BaseAgent(ABC)`:
|
||||||
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
|
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
|
||||||
- Abstract `get_name() -> str`, `get_description() -> str`
|
- Abstract `get_name() -> str`, `get_description() -> str`
|
||||||
@@ -127,7 +127,7 @@ adiuva-api/
|
|||||||
- `get(name) -> ChatAgent`
|
- `get(name) -> ChatAgent`
|
||||||
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
|
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
|
||||||
- `async call_agent(name, query, context) -> str` — for inter-agent calls
|
- `async call_agent(name, query, context) -> str` — for inter-agent calls
|
||||||
- [ ] Unit tests: register, get, list, call_agent with mock
|
- [x] Unit tests: register, get, list, call_agent with mock
|
||||||
- **Outcome:** Pluggable agent framework.
|
- **Outcome:** Pluggable agent framework.
|
||||||
|
|
||||||
### Step 4 — Orchestrator
|
### Step 4 — Orchestrator
|
||||||
|
|||||||
137
app/core/agent_registry.py
Normal file
137
app/core/agent_registry.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""Agent Registry — base classes and singleton registry for chat agents."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAgent(ABC):
|
||||||
|
"""Common base for all agents."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_id: str = "",
|
||||||
|
shared_memory: dict[str, Any] | None = None,
|
||||||
|
vector_store_context: list[str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.user_id = user_id
|
||||||
|
self.shared_memory: dict[str, Any] = shared_memory or {}
|
||||||
|
self.vector_store_context: list[str] = vector_store_context or []
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_name(self) -> str: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_description(self) -> str: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def skills(self) -> list[str]:
|
||||||
|
"""Override in subclasses to advertise capabilities."""
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class ChatAgent(BaseAgent):
|
||||||
|
"""Base class for LLM-powered chat agents."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
"""Process a user query and return a text response."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
"""Return LangChain tool definitions available to this agent."""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def _tool_loop(
|
||||||
|
self,
|
||||||
|
llm: Any,
|
||||||
|
messages: list[Any],
|
||||||
|
tools: list[Any],
|
||||||
|
max_iter: int = 5,
|
||||||
|
) -> str:
|
||||||
|
"""Shared tool-calling loop.
|
||||||
|
|
||||||
|
Binds *tools* to *llm*, invokes iteratively until the model stops
|
||||||
|
requesting tool calls or *max_iter* is reached, and returns the
|
||||||
|
final text response.
|
||||||
|
"""
|
||||||
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
|
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
||||||
|
|
||||||
|
for _ in range(max_iter):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return str(response.content)
|
||||||
|
|
||||||
|
# Execute each requested tool call
|
||||||
|
tool_map = {t.name: t for t in tools}
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_fn = tool_map.get(call["name"])
|
||||||
|
if tool_fn is None:
|
||||||
|
result = f"Unknown tool: {call['name']}"
|
||||||
|
else:
|
||||||
|
result = await tool_fn.ainvoke(call["args"])
|
||||||
|
messages.append(
|
||||||
|
ToolMessage(content=str(result), tool_call_id=call["id"])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Exhausted iterations — ask model for a final answer without tools
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
return str(response.content)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRegistry:
|
||||||
|
"""Singleton registry for ChatAgent subclasses."""
|
||||||
|
|
||||||
|
_instance: AgentRegistry | None = None
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._agents: dict[str, type[ChatAgent]] = {}
|
||||||
|
|
||||||
|
def __new__(cls) -> AgentRegistry:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._agents = {}
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
# ── public API ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
|
||||||
|
"""Class decorator — registers an agent by its name."""
|
||||||
|
instance = agent_class()
|
||||||
|
name = instance.get_name()
|
||||||
|
self._agents[name] = agent_class
|
||||||
|
return agent_class
|
||||||
|
|
||||||
|
def get(self, name: str) -> ChatAgent:
|
||||||
|
"""Return a fresh instance of the named agent."""
|
||||||
|
cls = self._agents.get(name)
|
||||||
|
if cls is None:
|
||||||
|
raise KeyError(f"Agent not found: {name}")
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
def list_agents(self) -> list[dict[str, str]]:
|
||||||
|
"""Return ``[{name, description}]`` for the orchestrator prompt."""
|
||||||
|
result: list[dict[str, str]] = []
|
||||||
|
for cls in self._agents.values():
|
||||||
|
inst = cls()
|
||||||
|
result.append(
|
||||||
|
{"name": inst.get_name(), "description": inst.get_description()}
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def call_agent(
|
||||||
|
self, name: str, query: str, context: dict[str, Any]
|
||||||
|
) -> str:
|
||||||
|
"""Instantiate the named agent and call its ``handle`` method."""
|
||||||
|
agent = self.get(name)
|
||||||
|
return await agent.handle(query, context)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
registry = AgentRegistry()
|
||||||
214
tests/test_agent_registry.py
Normal file
214
tests/test_agent_registry.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""Unit tests for the agent registry, base classes, and tool loop."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.agent_registry import AgentRegistry, ChatAgent
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _StubAgent(ChatAgent):
|
||||||
|
"""Minimal concrete agent for testing."""
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "stub"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "A stub agent for tests"
|
||||||
|
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
return f"echo: {query}"
|
||||||
|
|
||||||
|
|
||||||
|
class _AnotherAgent(ChatAgent):
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "another"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "Another stub"
|
||||||
|
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
return "another"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _fresh_registry():
|
||||||
|
"""Reset the singleton between tests."""
|
||||||
|
AgentRegistry._instance = None
|
||||||
|
yield
|
||||||
|
AgentRegistry._instance = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def reg() -> AgentRegistry:
|
||||||
|
return AgentRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRegisterAndGet:
|
||||||
|
def test_register_decorator(self, reg: AgentRegistry) -> None:
|
||||||
|
reg.register(_StubAgent)
|
||||||
|
agent = reg.get("stub")
|
||||||
|
assert isinstance(agent, _StubAgent)
|
||||||
|
|
||||||
|
def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
|
||||||
|
with pytest.raises(KeyError, match="not found"):
|
||||||
|
reg.get("nonexistent")
|
||||||
|
|
||||||
|
def test_register_multiple(self, reg: AgentRegistry) -> None:
|
||||||
|
reg.register(_StubAgent)
|
||||||
|
reg.register(_AnotherAgent)
|
||||||
|
assert reg.get("stub").get_name() == "stub"
|
||||||
|
assert reg.get("another").get_name() == "another"
|
||||||
|
|
||||||
|
|
||||||
|
class TestListAgents:
|
||||||
|
def test_empty(self, reg: AgentRegistry) -> None:
|
||||||
|
assert reg.list_agents() == []
|
||||||
|
|
||||||
|
def test_list_after_register(self, reg: AgentRegistry) -> None:
|
||||||
|
reg.register(_StubAgent)
|
||||||
|
agents = reg.list_agents()
|
||||||
|
assert len(agents) == 1
|
||||||
|
assert agents[0] == {"name": "stub", "description": "A stub agent for tests"}
|
||||||
|
|
||||||
|
def test_list_multiple(self, reg: AgentRegistry) -> None:
|
||||||
|
reg.register(_StubAgent)
|
||||||
|
reg.register(_AnotherAgent)
|
||||||
|
names = {a["name"] for a in reg.list_agents()}
|
||||||
|
assert names == {"stub", "another"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallAgent:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_agent(self, reg: AgentRegistry) -> None:
|
||||||
|
reg.register(_StubAgent)
|
||||||
|
result = await reg.call_agent("stub", "hello", {})
|
||||||
|
assert result == "echo: hello"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
await reg.call_agent("nope", "hi", {})
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingleton:
|
||||||
|
def test_singleton_identity(self) -> None:
|
||||||
|
a = AgentRegistry()
|
||||||
|
b = AgentRegistry()
|
||||||
|
assert a is b
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolLoop:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_tool_calls(self) -> None:
|
||||||
|
"""When the LLM responds without tool calls, return content directly."""
|
||||||
|
agent = _StubAgent()
|
||||||
|
|
||||||
|
ai_msg = MagicMock()
|
||||||
|
ai_msg.content = "final answer"
|
||||||
|
ai_msg.tool_calls = []
|
||||||
|
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm)
|
||||||
|
llm.ainvoke = AsyncMock(return_value=ai_msg)
|
||||||
|
|
||||||
|
result = await agent._tool_loop(llm, [], [])
|
||||||
|
assert result == "final answer"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_call_then_answer(self) -> None:
|
||||||
|
"""LLM requests one tool call, gets result, then answers."""
|
||||||
|
agent = _StubAgent()
|
||||||
|
|
||||||
|
# First response: tool call
|
||||||
|
tool_call_msg = MagicMock()
|
||||||
|
tool_call_msg.content = ""
|
||||||
|
tool_call_msg.tool_calls = [
|
||||||
|
{"id": "call_1", "name": "my_tool", "args": {"x": 1}}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Second response: final answer
|
||||||
|
final_msg = MagicMock()
|
||||||
|
final_msg.content = "done"
|
||||||
|
final_msg.tool_calls = []
|
||||||
|
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm)
|
||||||
|
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||||
|
|
||||||
|
# Mock tool
|
||||||
|
tool = AsyncMock()
|
||||||
|
tool.name = "my_tool"
|
||||||
|
tool.ainvoke = AsyncMock(return_value="tool_result")
|
||||||
|
|
||||||
|
result = await agent._tool_loop(llm, [], [tool])
|
||||||
|
assert result == "done"
|
||||||
|
tool.ainvoke.assert_called_once_with({"x": 1})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unknown_tool_handled(self) -> None:
|
||||||
|
"""Unknown tool names produce an error message instead of crashing."""
|
||||||
|
agent = _StubAgent()
|
||||||
|
|
||||||
|
tool_call_msg = MagicMock()
|
||||||
|
tool_call_msg.content = ""
|
||||||
|
tool_call_msg.tool_calls = [
|
||||||
|
{"id": "call_1", "name": "missing", "args": {}}
|
||||||
|
]
|
||||||
|
|
||||||
|
final_msg = MagicMock()
|
||||||
|
final_msg.content = "recovered"
|
||||||
|
final_msg.tool_calls = []
|
||||||
|
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm)
|
||||||
|
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||||
|
|
||||||
|
result = await agent._tool_loop(llm, [], [])
|
||||||
|
assert result == "recovered"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_iter_reached(self) -> None:
|
||||||
|
"""When max iterations are exhausted, a final no-tools call is made."""
|
||||||
|
agent = _StubAgent()
|
||||||
|
|
||||||
|
# Every response requests a tool call
|
||||||
|
loop_msg = MagicMock()
|
||||||
|
loop_msg.content = ""
|
||||||
|
loop_msg.tool_calls = [
|
||||||
|
{"id": "call_x", "name": "t", "args": {}}
|
||||||
|
]
|
||||||
|
|
||||||
|
final_msg = MagicMock()
|
||||||
|
final_msg.content = "gave up"
|
||||||
|
final_msg.tool_calls = []
|
||||||
|
|
||||||
|
tool = AsyncMock()
|
||||||
|
tool.name = "t"
|
||||||
|
tool.ainvoke = AsyncMock(return_value="ok")
|
||||||
|
|
||||||
|
llm_with_tools = AsyncMock()
|
||||||
|
llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg)
|
||||||
|
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||||
|
llm.ainvoke = AsyncMock(return_value=final_msg)
|
||||||
|
|
||||||
|
result = await agent._tool_loop(llm, [], [tool], max_iter=2)
|
||||||
|
assert result == "gave up"
|
||||||
|
assert llm_with_tools.ainvoke.call_count == 2
|
||||||
Reference in New Issue
Block a user