step 4 complete: intelligent routing with single-agent and pipeline modes
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
348
tests/test_orchestrator.py
Normal file
348
tests/test_orchestrator.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""Integration tests for the orchestrator module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
||||
from app.core.orchestrator import (
|
||||
classify_intent,
|
||||
orchestrate,
|
||||
orchestrate_stream,
|
||||
route_pipeline,
|
||||
route_single,
|
||||
)
|
||||
from app.schemas import ChatContext, ChatRequest, ChatResponse, ExecutionPlan
|
||||
|
||||
|
||||
# ── Stub agents ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _TaskAgent(ChatAgent):
|
||||
def get_name(self) -> str:
|
||||
return "task_agent"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "Manages tasks: create, update, list, suggest"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
return f"task: {query}"
|
||||
|
||||
|
||||
class _CalendarAgent(ChatAgent):
|
||||
def get_name(self) -> str:
|
||||
return "calendar_agent"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "Calendar management: events, conflicts, scheduling"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
return f"calendar: {query}"
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_llm(response_text: str) -> MagicMock:
|
||||
"""Return a mock LLM that always produces *response_text*."""
|
||||
msg = MagicMock()
|
||||
msg.content = response_text
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(return_value=msg)
|
||||
return llm
|
||||
|
||||
|
||||
# ── Fixtures ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _fresh_registry():
|
||||
"""Reset the AgentRegistry singleton between tests."""
|
||||
AgentRegistry._instance = None
|
||||
yield
|
||||
AgentRegistry._instance = None
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def reg() -> AgentRegistry:
|
||||
r = AgentRegistry()
|
||||
r.register(_TaskAgent)
|
||||
r.register(_CalendarAgent)
|
||||
return r
|
||||
|
||||
|
||||
# ── classify_intent ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestClassifyIntent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
result = await classify_intent("add a task", {}, reg)
|
||||
assert result == "task_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("calendar_agent")
|
||||
result = await classify_intent("schedule a meeting", {}, reg)
|
||||
assert result == "calendar_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("nonexistent_agent")
|
||||
result = await classify_intent("do something", {}, reg)
|
||||
assert result == "task_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_registry_returns_fallback_without_llm_call(self) -> None:
|
||||
empty_reg = AgentRegistry()
|
||||
# No LLM should be instantiated — early return path
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
result = await classify_intent("anything", {}, empty_reg)
|
||||
mock_cls.assert_not_called()
|
||||
assert result == "task_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm(" task_agent \n")
|
||||
result = await classify_intent("create task", {}, reg)
|
||||
assert result == "task_agent"
|
||||
|
||||
|
||||
# ── route_single ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRouteSingle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
|
||||
result = await route_single("task_agent", "create a task", {}, reg)
|
||||
assert isinstance(result, ChatResponse)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_contains_agent_output(self, reg: AgentRegistry) -> None:
|
||||
result = await route_single("task_agent", "create a task", {}, reg)
|
||||
assert result.response == "task: create a task"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_agent_raises_key_error(self, reg: AgentRegistry) -> None:
|
||||
with pytest.raises(KeyError):
|
||||
await route_single("nonexistent", "hello", {}, reg)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_actions_default_empty(self, reg: AgentRegistry) -> None:
|
||||
result = await route_single("task_agent", "hi", {}, reg)
|
||||
assert result.actions == []
|
||||
|
||||
|
||||
# ── route_pipeline ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRoutePipeline:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("synthesized result")
|
||||
result = await route_pipeline(
|
||||
["task_agent", "calendar_agent"], "plan my week", {}, reg
|
||||
)
|
||||
assert isinstance(result, ChatResponse)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("synthesized result")
|
||||
result = await route_pipeline(
|
||||
["task_agent", "calendar_agent"], "plan my week", {}, reg
|
||||
)
|
||||
assert result.response == "synthesized result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_previous_results_to_subsequent_agents(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
"""Each agent after the first should receive prior outputs in context."""
|
||||
received_contexts: list[dict[str, Any]] = []
|
||||
|
||||
class _CapturingAgent(ChatAgent):
|
||||
def get_name(self) -> str:
|
||||
return "capture"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "captures context for testing"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
received_contexts.append(dict(context))
|
||||
return "captured"
|
||||
|
||||
reg.register(_CapturingAgent)
|
||||
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("done")
|
||||
await route_pipeline(["task_agent", "capture"], "hi", {}, reg)
|
||||
|
||||
# The second agent (capture) must have received previous results
|
||||
assert len(received_contexts) == 1
|
||||
assert "previous_results" in received_contexts[0]
|
||||
assert received_contexts[0]["previous_results"] == ["task: hi"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("single result")
|
||||
result = await route_pipeline(["task_agent"], "one agent", {}, reg)
|
||||
assert result.response == "single result"
|
||||
|
||||
|
||||
# ── orchestrate ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestOrchestrate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_mode_returns_chat_response(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||
result = await orchestrate(request, reg)
|
||||
assert isinstance(result, ChatResponse)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||
result = await orchestrate(request, reg)
|
||||
assert isinstance(result, ChatResponse)
|
||||
assert result.response == "task: add a task"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_mode_returns_execution_plan(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="plan my tasks", execution_mode="plan")
|
||||
result = await orchestrate(request, reg)
|
||||
assert isinstance(result, ExecutionPlan)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_mode_agent_matches_classified(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("calendar_agent")
|
||||
request = ChatRequest(
|
||||
message="schedule something", execution_mode="plan"
|
||||
)
|
||||
result = await orchestrate(request, reg)
|
||||
assert isinstance(result, ExecutionPlan)
|
||||
assert result.agent == "calendar_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="plan tasks", execution_mode="plan")
|
||||
result = await orchestrate(request, reg)
|
||||
assert isinstance(result, ExecutionPlan)
|
||||
assert len(result.steps) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_mode_template_id_contains_agent_name(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="plan tasks", execution_mode="plan")
|
||||
result = await orchestrate(request, reg)
|
||||
assert isinstance(result, ExecutionPlan)
|
||||
assert result.steps[0].prompt_template is not None
|
||||
assert "task_agent" in result.steps[0].prompt_template
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_execution_mode_is_direct(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
# execution_mode defaults to "direct"
|
||||
request = ChatRequest(message="help me")
|
||||
result = await orchestrate(request, reg)
|
||||
assert isinstance(result, ChatResponse)
|
||||
|
||||
|
||||
# ── orchestrate_stream ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestOrchestrateStream:
|
||||
@pytest.mark.asyncio
|
||||
async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||
assert len(chunks) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_last_chunk_is_final_json_frame(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||
|
||||
last = json.loads(chunks[-1])
|
||||
assert last["done"] is True
|
||||
assert "response" in last
|
||||
assert "actions" in last
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_final_frame_response_matches_agent_output(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="create a task", execution_mode="direct")
|
||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||
|
||||
final = json.loads(chunks[-1])
|
||||
assert final["response"] == "task: create a task"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_chunks_before_final_frame(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(
|
||||
message="x" * 200, execution_mode="direct"
|
||||
) # long enough to produce multiple chunks
|
||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||
|
||||
# All but the last chunk should be plain text (not valid final JSON)
|
||||
non_final = chunks[:-1]
|
||||
for chunk in non_final:
|
||||
try:
|
||||
parsed = json.loads(chunk)
|
||||
assert parsed.get("done") is not True
|
||||
except json.JSONDecodeError:
|
||||
pass # plain text chunk — expected
|
||||
Reference in New Issue
Block a user