Files
api/app/core/agent_registry.py
2026-03-02 00:03:42 +01:00

138 lines
4.3 KiB
Python

"""Agent Registry — base classes and singleton registry for chat agents."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
class BaseAgent(ABC):
    """Common base for all agents.

    Stores the state handed to an agent at construction time (user id,
    shared-memory mapping, vector-store context) and declares the
    ``get_name`` / ``get_description`` interface every agent must implement.
    """

    def __init__(
        self,
        user_id: str = "",
        shared_memory: dict[str, Any] | None = None,
        vector_store_context: list[str] | None = None,
    ) -> None:
        """Store the agent's construction-time state.

        Args:
            user_id: Identifier of the user; defaults to an empty string.
            shared_memory: Mapping kept by reference on the agent. A fresh
                dict is created only when ``None`` is passed.
            vector_store_context: List of context strings kept by reference.
                A fresh list is created only when ``None`` is passed.
        """
        self.user_id = user_id
        # Test against None, not truthiness: an *empty* dict/list supplied by
        # the caller must be kept as-is so that later writes remain visible
        # to whoever else holds that same object.
        self.shared_memory: dict[str, Any] = (
            {} if shared_memory is None else shared_memory
        )
        self.vector_store_context: list[str] = (
            [] if vector_store_context is None else vector_store_context
        )

    @abstractmethod
    def get_name(self) -> str:
        """Return the agent's name."""
        ...

    @abstractmethod
    def get_description(self) -> str:
        """Return a description of the agent."""
        ...

    @property
    def skills(self) -> list[str]:
        """Override in subclasses to advertise capabilities.

        The base implementation returns an empty list.
        """
        return []
class ChatAgent(BaseAgent):
    """Base class for LLM-powered chat agents.

    Subclasses implement :meth:`handle` and :meth:`get_tools`; the shared
    :meth:`_tool_loop` helper drives an LLM through repeated tool calls.
    """

    @abstractmethod
    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Process a user query and return a text response."""
        ...

    @abstractmethod
    def get_tools(self) -> list[Any]:
        """Return LangChain tool definitions available to this agent."""
        ...

    async def _tool_loop(
        self,
        llm: Any,
        messages: list[Any],
        tools: list[Any],
        max_iter: int = 5,
    ) -> str:
        """Shared tool-calling loop.

        Binds *tools* to *llm*, invokes iteratively until the model stops
        requesting tool calls or *max_iter* is reached, and returns the
        final text response.

        Args:
            llm: Chat model exposing ``bind_tools`` and ``ainvoke``.
            messages: Conversation so far. Mutated in place: every model
                response and every tool result is appended to it.
            tools: Tool objects exposing ``name`` and ``ainvoke``. When empty
                (or ``None``) the model is invoked without tool binding.
            max_iter: Maximum number of tool-enabled model invocations.

        Returns:
            The content of the final model response, converted to ``str``.
        """
        from langchain_core.messages import AIMessage, ToolMessage

        llm_with_tools = llm.bind_tools(tools) if tools else llm
        # The name -> tool lookup does not change between iterations, so it
        # is built once up front. ``tools or ()`` keeps a falsy ``tools``
        # (already accepted by the binding guard above) from raising here.
        tool_map = {tool.name: tool for tool in tools or ()}

        for _ in range(max_iter):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            messages.append(response)
            if not response.tool_calls:
                return str(response.content)

            # Execute each requested tool call and feed the result back.
            for call in response.tool_calls:
                tool_fn = tool_map.get(call["name"])
                if tool_fn is None:
                    # Report the bad name to the model instead of failing.
                    result = f"Unknown tool: {call['name']}"
                else:
                    result = await tool_fn.ainvoke(call["args"])
                messages.append(
                    ToolMessage(content=str(result), tool_call_id=call["id"])
                )

        # Exhausted iterations — ask model for a final answer without tools
        response = await llm.ainvoke(messages)
        return str(response.content)
class AgentRegistry:
    """Singleton registry for ChatAgent subclasses.

    Every ``AgentRegistry()`` call returns the same instance. Agent classes
    are stored under the name reported by ``get_name()`` and instantiated on
    demand with no constructor arguments.
    """

    _instance: AgentRegistry | None = None
    _agents: dict[str, type[ChatAgent]]

    def __new__(cls) -> AgentRegistry:
        """Create the single instance on first use and return it afterwards."""
        if cls._instance is None:
            instance = super().__new__(cls)
            # State is initialised here, exactly once, not in __init__.
            instance._agents = {}
            cls._instance = instance
        return cls._instance

    def __init__(self) -> None:
        """Do nothing.

        Python runs ``__init__`` on every ``AgentRegistry()`` call, even when
        ``__new__`` hands back the existing singleton. Resetting ``_agents``
        here would therefore drop every registered agent each time the
        registry is constructed again, so initialisation lives in ``__new__``.
        """

    # ── public API ───────────────────────────────────────────────────
    def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
        """Class decorator — registers an agent by its name.

        The class is instantiated once with no arguments to read its
        ``get_name()``; a later registration under the same name replaces
        the earlier one.

        Returns:
            ``agent_class`` unchanged, so this can be used as a decorator.
        """
        name = agent_class().get_name()
        self._agents[name] = agent_class
        return agent_class

    def get(self, name: str) -> ChatAgent:
        """Return a fresh instance of the named agent.

        Raises:
            KeyError: If no agent is registered under *name*.
        """
        agent_cls = self._agents.get(name)
        if agent_cls is None:
            raise KeyError(f"Agent not found: {name}")
        return agent_cls()

    def list_agents(self) -> list[dict[str, str]]:
        """Return ``[{name, description}]`` for the orchestrator prompt.

        Each registered class is instantiated once to read its name and
        description; entries follow registration order.
        """
        listing: list[dict[str, str]] = []
        for agent_cls in self._agents.values():
            agent = agent_cls()
            listing.append(
                {
                    "name": agent.get_name(),
                    "description": agent.get_description(),
                }
            )
        return listing

    async def call_agent(
        self, name: str, query: str, context: dict[str, Any]
    ) -> str:
        """Instantiate the named agent and call its ``handle`` method.

        Raises:
            KeyError: If no agent is registered under *name*.
        """
        agent = self.get(name)
        return await agent.handle(query, context)
# Module-level singleton: the shared AgentRegistry instance, created when
# this module is imported.
registry = AgentRegistry()