diff --git a/BACKEND_PLAN.md b/BACKEND_PLAN.md index c2d01ce..be8be32 100644 --- a/BACKEND_PLAN.md +++ b/BACKEND_PLAN.md @@ -112,8 +112,8 @@ adiuva-api/ - `UserProfile`: `id: str`, `email: str`, `tier: BillingTier` - **Outcome:** All request/response models defined and validated. -### Step 3 — Agent Registry + base classes -- [ ] `app/core/agent_registry.py`: +### Step 3 — Agent Registry + base classes ✅ +- [x] `app/core/agent_registry.py`: - `BaseAgent(ABC)`: - `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]` - Abstract `get_name() -> str`, `get_description() -> str` @@ -127,7 +127,7 @@ adiuva-api/ - `get(name) -> ChatAgent` - `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt - `async call_agent(name, query, context) -> str` — for inter-agent calls -- [ ] Unit tests: register, get, list, call_agent with mock +- [x] Unit tests: register, get, list, call_agent with mock - **Outcome:** Pluggable agent framework. ### Step 4 — Orchestrator diff --git a/app/core/agent_registry.py b/app/core/agent_registry.py new file mode 100644 index 0000000..1037c14 --- /dev/null +++ b/app/core/agent_registry.py @@ -0,0 +1,137 @@ +"""Agent Registry — base classes and singleton registry for chat agents.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + + +class BaseAgent(ABC): + """Common base for all agents.""" + + def __init__( + self, + user_id: str = "", + shared_memory: dict[str, Any] | None = None, + vector_store_context: list[str] | None = None, + ) -> None: + self.user_id = user_id + self.shared_memory: dict[str, Any] = shared_memory or {} + self.vector_store_context: list[str] = vector_store_context or [] + + @abstractmethod + def get_name(self) -> str: ... + + @abstractmethod + def get_description(self) -> str: ... + + @property + def skills(self) -> list[str]: + """Override in subclasses to advertise capabilities.""" + return [] + + +class ChatAgent(BaseAgent): + """Base class for LLM-powered chat agents.""" + + @abstractmethod + async def handle(self, query: str, context: dict[str, Any]) -> str: + """Process a user query and return a text response.""" + ... + + @abstractmethod + def get_tools(self) -> list[Any]: + """Return LangChain tool definitions available to this agent.""" + ... + + async def _tool_loop( + self, + llm: Any, + messages: list[Any], + tools: list[Any], + max_iter: int = 5, + ) -> str: + """Shared tool-calling loop. + + Binds *tools* to *llm*, invokes iteratively until the model stops + requesting tool calls or *max_iter* is reached, and returns the + final text response. + """ + from langchain_core.messages import AIMessage, ToolMessage + + llm_with_tools = llm.bind_tools(tools) if tools else llm + + for _ in range(max_iter): + response: AIMessage = await llm_with_tools.ainvoke(messages) + messages.append(response) + + if not response.tool_calls: + return str(response.content) + + # Execute each requested tool call + tool_map = {t.name: t for t in tools} + for call in response.tool_calls: + tool_fn = tool_map.get(call["name"]) + if tool_fn is None: + result = f"Unknown tool: {call['name']}" + else: + result = await tool_fn.ainvoke(call["args"]) + messages.append( + ToolMessage(content=str(result), tool_call_id=call["id"]) + ) + + # Exhausted iterations — ask model for a final answer without tools + response = await llm.ainvoke(messages) + return str(response.content) + + +class AgentRegistry: + """Singleton registry for ChatAgent subclasses.""" + + _instance: AgentRegistry | None = None + + def __init__(self) -> None: + self._agents: dict[str, type[ChatAgent]] = {} + + def __new__(cls) -> AgentRegistry: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._agents = {} + return cls._instance + + # ── public API ─────────────────────────────────────────────────── + + def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]: + """Class decorator — registers an agent by its name.""" + instance = agent_class() + name = instance.get_name() + self._agents[name] = agent_class + return agent_class + + def get(self, name: str) -> ChatAgent: + """Return a fresh instance of the named agent.""" + cls = self._agents.get(name) + if cls is None: + raise KeyError(f"Agent not found: {name}") + return cls() + + def list_agents(self) -> list[dict[str, str]]: + """Return ``[{name, description}]`` for the orchestrator prompt.""" + result: list[dict[str, str]] = [] + for cls in self._agents.values(): + inst = cls() + result.append( + {"name": inst.get_name(), "description": inst.get_description()} + ) + return result + + async def call_agent( + self, name: str, query: str, context: dict[str, Any] + ) -> str: + """Instantiate the named agent and call its ``handle`` method.""" + agent = self.get(name) + return await agent.handle(query, context) + + +# Module-level singleton +registry = AgentRegistry() diff --git a/tests/test_agent_registry.py b/tests/test_agent_registry.py new file mode 100644 index 0000000..9fd9381 --- /dev/null +++ b/tests/test_agent_registry.py @@ -0,0 +1,214 @@ +"""Unit tests for the agent registry, base classes, and tool loop.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.core.agent_registry import AgentRegistry, ChatAgent + + +# ── Helpers ────────────────────────────────────────────────────────── + +class _StubAgent(ChatAgent): + """Minimal concrete agent for testing.""" + + def get_name(self) -> str: + return "stub" + + def get_description(self) -> str: + return "A stub agent for tests" + + def get_tools(self) -> list[Any]: + return [] + + async def handle(self, query: str, context: dict[str, Any]) -> str: + return f"echo: {query}" + + +class _AnotherAgent(ChatAgent): + def get_name(self) -> str: + return "another" + + def get_description(self) -> str: + return "Another stub" + + def get_tools(self) -> list[Any]: + return [] + + async def handle(self, query: str, context: dict[str, Any]) -> str: + return "another" + + +# ── Fixtures ───────────────────────────────────────────────────────── + +@pytest.fixture(autouse=True) +def _fresh_registry(): + """Reset the singleton between tests.""" + AgentRegistry._instance = None + yield + AgentRegistry._instance = None + + +@pytest.fixture() +def reg() -> AgentRegistry: + return AgentRegistry() + + +# ── Tests ──────────────────────────────────────────────────────────── + +class TestRegisterAndGet: + def test_register_decorator(self, reg: AgentRegistry) -> None: + reg.register(_StubAgent) + agent = reg.get("stub") + assert isinstance(agent, _StubAgent) + + def test_get_unknown_raises(self, reg: AgentRegistry) -> None: + with pytest.raises(KeyError, match="not found"): + reg.get("nonexistent") + + def test_register_multiple(self, reg: AgentRegistry) -> None: + reg.register(_StubAgent) + reg.register(_AnotherAgent) + assert reg.get("stub").get_name() == "stub" + assert reg.get("another").get_name() == "another" + + +class TestListAgents: + def test_empty(self, reg: AgentRegistry) -> None: + assert reg.list_agents() == [] + + def test_list_after_register(self, reg: AgentRegistry) -> None: + reg.register(_StubAgent) + agents = reg.list_agents() + assert len(agents) == 1 + assert agents[0] == {"name": "stub", "description": "A stub agent for tests"} + + def test_list_multiple(self, reg: AgentRegistry) -> None: + reg.register(_StubAgent) + reg.register(_AnotherAgent) + names = {a["name"] for a in reg.list_agents()} + assert names == {"stub", "another"} + + +class TestCallAgent: + @pytest.mark.asyncio + async def test_call_agent(self, reg: AgentRegistry) -> None: + reg.register(_StubAgent) + result = await reg.call_agent("stub", "hello", {}) + assert result == "echo: hello" + + @pytest.mark.asyncio + async def test_call_unknown_raises(self, reg: AgentRegistry) -> None: + with pytest.raises(KeyError): + await reg.call_agent("nope", "hi", {}) + + +class TestSingleton: + def test_singleton_identity(self) -> None: + a = AgentRegistry() + b = AgentRegistry() + assert a is b + + +class TestToolLoop: + @pytest.mark.asyncio + async def test_no_tool_calls(self) -> None: + """When the LLM responds without tool calls, return content directly.""" + agent = _StubAgent() + + ai_msg = MagicMock() + ai_msg.content = "final answer" + ai_msg.tool_calls = [] + + llm = AsyncMock() + llm.bind_tools = MagicMock(return_value=llm) + llm.ainvoke = AsyncMock(return_value=ai_msg) + + result = await agent._tool_loop(llm, [], []) + assert result == "final answer" + + @pytest.mark.asyncio + async def test_tool_call_then_answer(self) -> None: + """LLM requests one tool call, gets result, then answers.""" + agent = _StubAgent() + + # First response: tool call + tool_call_msg = MagicMock() + tool_call_msg.content = "" + tool_call_msg.tool_calls = [ + {"id": "call_1", "name": "my_tool", "args": {"x": 1}} + ] + + # Second response: final answer + final_msg = MagicMock() + final_msg.content = "done" + final_msg.tool_calls = [] + + llm = AsyncMock() + llm.bind_tools = MagicMock(return_value=llm) + llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg]) + + # Mock tool + tool = AsyncMock() + tool.name = "my_tool" + tool.ainvoke = AsyncMock(return_value="tool_result") + + result = await agent._tool_loop(llm, [], [tool]) + assert result == "done" + tool.ainvoke.assert_called_once_with({"x": 1}) + + @pytest.mark.asyncio + async def test_unknown_tool_handled(self) -> None: + """Unknown tool names produce an error message instead of crashing.""" + agent = _StubAgent() + + tool_call_msg = MagicMock() + tool_call_msg.content = "" + tool_call_msg.tool_calls = [ + {"id": "call_1", "name": "missing", "args": {}} + ] + + final_msg = MagicMock() + final_msg.content = "recovered" + final_msg.tool_calls = [] + + llm = AsyncMock() + llm.bind_tools = MagicMock(return_value=llm) + llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg]) + + result = await agent._tool_loop(llm, [], []) + assert result == "recovered" + + @pytest.mark.asyncio + async def test_max_iter_reached(self) -> None: + """When max iterations are exhausted, a final no-tools call is made.""" + agent = _StubAgent() + + # Every response requests a tool call + loop_msg = MagicMock() + loop_msg.content = "" + loop_msg.tool_calls = [ + {"id": "call_x", "name": "t", "args": {}} + ] + + final_msg = MagicMock() + final_msg.content = "gave up" + final_msg.tool_calls = [] + + tool = AsyncMock() + tool.name = "t" + tool.ainvoke = AsyncMock(return_value="ok") + + llm_with_tools = AsyncMock() + llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg) + + llm = AsyncMock() + llm.bind_tools = MagicMock(return_value=llm_with_tools) + llm.ainvoke = AsyncMock(return_value=final_msg) + + result = await agent._tool_loop(llm, [], [tool], max_iter=2) + assert result == "gave up" + assert llm_with_tools.ainvoke.call_count == 2