- orchestrate_v3(user_id, message, context): classifies intent and returns (agent_name, agent_instance); the caller drives execution.
- orchestrate_v3_stream(user_id, message, context): yields (agent_name, token) pairs; the first yield is always (agent_name, "") as a domain-detection signal.
- ChatAgent.handle_stream(): the default implementation yields the handle() result as one chunk; subclasses override it for true token-level streaming.
- Fix stale test_orchestrator.py assertions that expected a JSON final frame that orchestrate_stream never emitted.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
218 lines
7.5 KiB
Python
218 lines
7.5 KiB
Python
"""Agent Registry — base classes and singleton registry for chat agents."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import AsyncGenerator
|
|
from typing import Any
|
|
|
|
|
|
class BaseAgent(ABC):
    """Common base for all agents.

    Holds the user id, the shared-memory mapping and the vector-store
    context passed at construction time, and declares the identity hooks
    (``get_name`` / ``get_description``) that concrete agents must implement.
    """

    def __init__(
        self,
        user_id: str = "",
        shared_memory: dict[str, Any] | None = None,
        vector_store_context: list[str] | None = None,
    ) -> None:
        """Store the constructor arguments on the instance.

        Args:
            user_id: Identifier of the user; defaults to the empty string.
            shared_memory: Mapping stored as ``self.shared_memory``. The very
                same object is kept (no copy); a new dict is created only
                when ``None`` is passed.
            vector_store_context: List stored as ``self.vector_store_context``.
                The very same object is kept; a new list is created only
                when ``None`` is passed.
        """
        self.user_id = user_id
        # Test against None instead of truthiness: an empty dict/list supplied
        # by the caller must stay the caller's object, otherwise anything the
        # agent writes into it would land in a private throw-away container.
        self.shared_memory: dict[str, Any] = (
            {} if shared_memory is None else shared_memory
        )
        self.vector_store_context: list[str] = (
            [] if vector_store_context is None else vector_store_context
        )

    @abstractmethod
    def get_name(self) -> str:
        """Return the agent's name."""
        ...

    @abstractmethod
    def get_description(self) -> str:
        """Return a description of the agent."""
        ...

    @property
    def skills(self) -> list[str]:
        """Capabilities advertised by the agent.

        The base implementation returns an empty list; override in
        subclasses to advertise capabilities.
        """
        return []
|
|
|
|
|
|
class ChatAgent(BaseAgent):
    """Base class for LLM-powered chat agents.

    Subclasses implement ``handle`` and ``get_tools``. The two tool-loop
    helpers drive an LLM through repeated tool-calling rounds and leave the
    contents of the registered result collector in ``self.tool_results``.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Forward all keyword arguments to ``BaseAgent.__init__``."""
        super().__init__(**kwargs)
        # Replaced at the end of every _tool_loop / _tool_loop_stream run with
        # the list that was registered as the tool-result collector
        # (raw execute_on_client results).
        self.tool_results: list[dict] = []

    @abstractmethod
    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Process a user query and return the text response."""
        ...

    async def handle_stream(
        self, query: str, context: dict[str, Any]
    ) -> AsyncGenerator[str, None]:
        """Streaming counterpart of ``handle``.

        The default implementation awaits ``handle`` and yields its complete
        result as a single chunk. Subclasses override this for true
        token-level streaming via ``_tool_loop_stream``.
        """
        full_text = await self.handle(query, context)
        yield full_text

    @abstractmethod
    def get_tools(self) -> list[Any]:
        """Return the LangChain tool definitions available to this agent."""
        ...

    @staticmethod
    async def _run_tool_calls(
        tool_calls: list[Any],
        tools: list[Any],
        messages: list[Any],
    ) -> None:
        """Execute *tool_calls* in order, appending one ToolMessage per call.

        Tools are looked up by their ``name`` attribute. A call naming a tool
        that is not in *tools* is answered with an ``Unknown tool`` message
        instead of raising, so the model can see the failure.
        """
        from langchain_core.messages import ToolMessage

        by_name = {tool.name: tool for tool in tools}
        for call in tool_calls:
            tool = by_name.get(call["name"])
            if tool is not None:
                outcome = await tool.ainvoke(call["args"])
            else:
                outcome = f"Unknown tool: {call['name']}"
            messages.append(
                ToolMessage(content=str(outcome), tool_call_id=call["id"])
            )

    async def _tool_loop(
        self,
        llm: Any,
        messages: list[Any],
        tools: list[Any],
        max_iter: int = 5,
    ) -> str:
        """Run the shared tool-calling loop and return the final text.

        *tools* are bound to *llm* (when any are given) and the model is
        invoked up to *max_iter* times. Each reply is appended to *messages*
        (the caller's list is mutated); a reply without tool calls ends the
        loop and its content is returned. If every round requested tools,
        the unbound *llm* is invoked once more for a final answer.

        Whatever was added to the registered collector during the run is
        stored in ``self.tool_results``, also when an exception propagates.
        """
        from langchain_core.messages import AIMessage

        from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector

        captured: list[dict] = []
        set_tool_result_collector(captured)
        try:
            model = llm.bind_tools(tools) if tools else llm

            for _ in range(max_iter):
                reply: AIMessage = await model.ainvoke(messages)
                messages.append(reply)

                if reply.tool_calls:
                    # Feed the tool outputs back and ask the model again.
                    await self._run_tool_calls(reply.tool_calls, tools, messages)
                    continue

                return str(reply.content)

            # Exhausted iterations — ask model for a final answer without tools
            closing = await llm.ainvoke(messages)
            return str(closing.content)
        finally:
            clear_tool_result_collector()
            self.tool_results = captured

    async def _tool_loop_stream(
        self,
        llm: Any,
        messages: list[Any],
        tools: list[Any],
        max_iter: int = 5,
    ) -> AsyncGenerator[str, None]:
        """Streaming variant of ``_tool_loop``.

        Tool-calling rounds work as in ``_tool_loop`` (``ainvoke`` is used so
        tool calls can be parsed). As soon as a reply contains no tool calls,
        that reply is discarded and the answer is produced again through
        ``llm.astream()`` on the unbound model, yielding each non-empty chunk
        as text. The same streaming fallback runs when *max_iter* rounds all
        requested tools.

        Whatever was added to the registered collector during the run is
        stored in ``self.tool_results`` when the generator finishes or is
        closed.
        """
        from langchain_core.messages import AIMessage

        from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector

        captured: list[dict] = []
        set_tool_result_collector(captured)
        try:
            model = llm.bind_tools(tools) if tools else llm

            for _ in range(max_iter):
                reply: AIMessage = await model.ainvoke(messages)

                if reply.tool_calls:
                    messages.append(reply)
                    await self._run_tool_calls(reply.tool_calls, tools, messages)
                    continue

                # Stream the final answer — don't keep the ainvoke result.
                async for piece in llm.astream(messages):
                    if piece.content:
                        yield str(piece.content)
                return

            # Exhausted iterations — stream a final answer without tools
            async for piece in llm.astream(messages):
                if piece.content:
                    yield str(piece.content)
        finally:
            clear_tool_result_collector()
            self.tool_results = captured
|
|
|
|
|
|
class AgentRegistry:
    """Singleton registry for ChatAgent subclasses.

    Maps the name reported by an agent's ``get_name()`` to the agent class.
    Every ``AgentRegistry()`` call returns the same object.
    """

    _instance: AgentRegistry | None = None

    def __new__(cls) -> AgentRegistry:
        """Create the shared instance on first use and return it afterwards."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._agents = {}
        return cls._instance

    def __init__(self) -> None:
        """Initialise the name → class mapping without discarding entries.

        Python calls ``__init__`` on the object returned by ``__new__`` for
        every ``AgentRegistry()`` call, so an unconditional reset here would
        wipe all registered agents each time the singleton is fetched.
        """
        if not hasattr(self, "_agents"):
            self._agents: dict[str, type[ChatAgent]] = {}

    # ── public API ───────────────────────────────────────────────────

    def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
        """Class decorator — registers an agent by its name.

        The class is instantiated once with no arguments to read
        ``get_name()``; a later registration under the same name replaces
        the earlier one. Returns *agent_class* unchanged.
        """
        name = agent_class().get_name()
        self._agents[name] = agent_class
        return agent_class

    def get(self, name: str) -> ChatAgent:
        """Return a fresh instance of the named agent.

        Raises:
            KeyError: if no agent is registered under *name*.
        """
        agent_class = self._agents.get(name)
        if agent_class is None:
            raise KeyError(f"Agent not found: {name}")
        return agent_class()

    def list_agents(self) -> list[dict[str, str]]:
        """Return ``[{name, description}]`` for the orchestrator prompt.

        Each registered class is instantiated once, in registration order,
        to read its name and description.
        """
        instances = (agent_class() for agent_class in self._agents.values())
        return [
            {"name": inst.get_name(), "description": inst.get_description()}
            for inst in instances
        ]

    async def call_agent(
        self, name: str, query: str, context: dict[str, Any]
    ) -> str:
        """Instantiate the named agent and await its ``handle`` method.

        Raises:
            KeyError: if no agent is registered under *name*.
        """
        agent = self.get(name)
        return await agent.handle(query, context)
|
|
|
|
|
|
# Module-level singleton
|
|
# AgentRegistry.__new__ hands back one shared instance, so this name and any
# later AgentRegistry() call refer to the same registry object.
registry = AgentRegistry()
|