"""Agent Registry — base classes and singleton registry for chat agents."""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any


class BaseAgent(ABC):
    """Common base for all agents.

    Stores the per-request state handed to an agent: the user id, a
    ``shared_memory`` dict and a ``vector_store_context`` list.  Subclasses
    must implement :meth:`get_name` and :meth:`get_description`.
    """

    def __init__(
        self,
        user_id: str = "",
        shared_memory: dict[str, Any] | None = None,
        vector_store_context: list[str] | None = None,
    ) -> None:
        self.user_id = user_id
        # Compare against None rather than truthiness: an *empty* container
        # passed by the caller must be kept as-is (same object), otherwise
        # anything the agent writes into it never reaches the caller.
        self.shared_memory: dict[str, Any] = (
            {} if shared_memory is None else shared_memory
        )
        self.vector_store_context: list[str] = (
            [] if vector_store_context is None else vector_store_context
        )

    @abstractmethod
    def get_name(self) -> str:
        """Return the name under which the agent is registered."""
        ...

    @abstractmethod
    def get_description(self) -> str:
        """Return a description of the agent."""
        ...

    @property
    def skills(self) -> list[str]:
        """Override in subclasses to advertise capabilities."""
        return []


async def _run_tool_calls(
    tool_calls: list[dict[str, Any]],
    tool_map: dict[str, Any],
    messages: list[Any],
) -> None:
    """Execute each requested tool call and append its result to *messages*.

    Every entry of *tool_calls* is indexed by ``"name"``, ``"args"`` and
    ``"id"``.  A name missing from *tool_map* does not raise; an
    ``Unknown tool: <name>`` message is sent back to the model instead.
    Results are stringified and wrapped in a ``ToolMessage`` carrying the
    originating call id.
    """
    from langchain_core.messages import ToolMessage

    for call in tool_calls:
        tool_fn = tool_map.get(call["name"])
        if tool_fn is None:
            result = f"Unknown tool: {call['name']}"
        else:
            result = await tool_fn.ainvoke(call["args"])
        messages.append(ToolMessage(content=str(result), tool_call_id=call["id"]))


class ChatAgent(BaseAgent):
    """Base class for LLM-powered chat agents."""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results.
        self.tool_results: list[dict] = []

    @abstractmethod
    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Process a user query and return a text response."""
        ...

    @abstractmethod
    def get_tools(self) -> list[Any]:
        """Return LangChain tool definitions available to this agent."""
        ...

    async def _tool_loop(
        self,
        llm: Any,
        messages: list[Any],
        tools: list[Any],
        max_iter: int = 5,
    ) -> str:
        """Shared tool-calling loop.

        Binds *tools* to *llm*, invokes iteratively until the model stops
        requesting tool calls or *max_iter* is reached, and returns the final
        text response.  *messages* is mutated in place: every model response
        and every tool result is appended to it.

        Captures raw execute_on_client results in ``self.tool_results``
        (the collector list registered for the duration of the call).
        """
        from langchain_core.messages import AIMessage

        from app.core.ws_context import (
            clear_tool_result_collector,
            set_tool_result_collector,
        )

        collector: list[dict] = []
        set_tool_result_collector(collector)
        try:
            llm_with_tools = llm.bind_tools(tools) if tools else llm
            # The tool set does not change between iterations; build the
            # name lookup once instead of once per model turn.
            tool_map = {t.name: t for t in (tools or [])}

            for _ in range(max_iter):
                response: AIMessage = await llm_with_tools.ainvoke(messages)
                messages.append(response)

                if not response.tool_calls:
                    return str(response.content)

                await _run_tool_calls(response.tool_calls, tool_map, messages)

            # Exhausted iterations — ask model for a final answer without tools
            response = await llm.ainvoke(messages)
            return str(response.content)
        finally:
            # Runs on success and on error alike, so the collector is always
            # unregistered and whatever was gathered so far is kept.
            clear_tool_result_collector()
            self.tool_results = collector

    async def _tool_loop_stream(
        self,
        llm: Any,
        messages: list[Any],
        tools: list[Any],
        max_iter: int = 5,
    ) -> AsyncGenerator[str, None]:
        """Streaming variant of ``_tool_loop``.

        Behaves identically for tool-calling iterations (uses ainvoke to
        parse tool calls).  For the final response — when the model produces
        no further tool calls, or *max_iter* is exhausted — switches to
        ``llm.astream()`` (the model without tools bound) and yields text
        tokens.

        Captures raw execute_on_client results in ``self.tool_results``
        (the collector list registered for the duration of the call).
        """
        from langchain_core.messages import AIMessage

        from app.core.ws_context import (
            clear_tool_result_collector,
            set_tool_result_collector,
        )

        collector: list[dict] = []
        set_tool_result_collector(collector)
        try:
            llm_with_tools = llm.bind_tools(tools) if tools else llm
            # The tool set does not change between iterations; build the
            # name lookup once instead of once per model turn.
            tool_map = {t.name: t for t in (tools or [])}

            for _ in range(max_iter):
                response: AIMessage = await llm_with_tools.ainvoke(messages)

                if not response.tool_calls:
                    # Stream the final answer — don't keep the ainvoke result.
                    # The response is deliberately NOT appended to messages,
                    # so the streamed call regenerates it from the same input.
                    async for chunk in llm.astream(messages):
                        if chunk.content:
                            yield str(chunk.content)
                    return

                messages.append(response)
                await _run_tool_calls(response.tool_calls, tool_map, messages)

            # Exhausted iterations — stream a final answer without tools
            async for chunk in llm.astream(messages):
                if chunk.content:
                    yield str(chunk.content)
        finally:
            # Also runs when the consumer closes the generator early.
            clear_tool_result_collector()
            self.tool_results = collector


class AgentRegistry:
    """Singleton registry for ChatAgent subclasses.

    ``AgentRegistry()`` always returns the same object.  The registry maps
    an agent's name (as reported by ``get_name()``) to its class and hands
    out a fresh instance on every lookup.
    """

    _instance: AgentRegistry | None = None
    # name -> agent class; created exactly once, in __new__.
    _agents: dict[str, type[ChatAgent]]

    def __new__(cls) -> AgentRegistry:
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._agents = {}
            cls._instance = instance
        return cls._instance

    def __init__(self) -> None:
        # Python calls __init__ on every ``AgentRegistry()`` expression, even
        # when __new__ hands back the existing singleton.  State therefore
        # lives in __new__; resetting ``_agents`` here would silently drop
        # every registered agent each time the class is called.
        pass

    # ── public API ───────────────────────────────────────────────────

    def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
        """Class decorator — registers an agent by its name.

        The class is instantiated once without arguments to read
        ``get_name()``.  A later registration under the same name replaces
        the earlier one.  Returns *agent_class* unchanged.
        """
        name = agent_class().get_name()
        self._agents[name] = agent_class
        return agent_class

    def get(self, name: str) -> ChatAgent:
        """Return a fresh instance of the named agent.

        Raises:
            KeyError: if no agent is registered under *name*.
        """
        agent_class = self._agents.get(name)
        if agent_class is None:
            raise KeyError(f"Agent not found: {name}")
        return agent_class()

    def list_agents(self) -> list[dict[str, str]]:
        """Return ``[{name, description}]`` for the orchestrator prompt."""
        result: list[dict[str, str]] = []
        for agent_class in self._agents.values():
            # One throwaway instance per class supplies both fields.
            inst = agent_class()
            result.append(
                {"name": inst.get_name(), "description": inst.get_description()}
            )
        return result

    async def call_agent(
        self, name: str, query: str, context: dict[str, Any]
    ) -> str:
        """Instantiate the named agent and call its ``handle`` method.

        Raises:
            KeyError: if no agent is registered under *name*.
        """
        agent = self.get(name)
        return await agent.handle(query, context)


# Module-level singleton
registry = AgentRegistry()