Files
api/tests/test_orchestrator_v3.py

237 lines
8.6 KiB
Python

"""Tests for v3 orchestrator functions (Step 3)."""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from typing import Any
from app.core.agent_registry import ChatAgent, AgentRegistry
from app.core.orchestrator import orchestrate_v3, orchestrate_v3_stream
# ── Minimal agent for testing ─────────────────────────────────────────
class _FixedAgent(ChatAgent):
    """Deterministic ``ChatAgent`` stub that replays a fixed list of tokens.

    ``handle`` returns the tokens concatenated into one string, while
    ``handle_stream`` yields them one at a time, so tests can exercise both
    the buffered and the streamed code paths with the same data.
    """

    def __init__(
        self,
        name: str = "_fixed",
        tokens: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the stub agent.

        Args:
            name: Value reported by ``get_name``.
            tokens: Tokens to replay. ``None`` selects the default
                ``["Hello", " world"]``; an explicit empty list is honoured
                and makes the agent produce no tokens at all.
            **kwargs: Forwarded unchanged to ``ChatAgent.__init__``.
        """
        super().__init__(**kwargs)
        self._name = name
        # Test against None rather than truthiness so ``tokens=[]`` really
        # means "no tokens", and copy so later mutation of the caller's list
        # cannot change what this agent replays.
        self._tokens = ["Hello", " world"] if tokens is None else list(tokens)

    def get_name(self) -> str:
        """Return the name given at construction."""
        return self._name

    def get_description(self) -> str:
        """Return a fixed human-readable description."""
        return "Fixed agent for tests"

    def get_tools(self) -> list[Any]:
        """Return no tools; the stub never calls any."""
        return []

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Return every configured token joined into a single string."""
        return "".join(self._tokens)

    async def handle_stream(self, query: str, context: dict[str, Any]):
        """Yield the configured tokens one by one, in order."""
        for tok in self._tokens:
            yield tok
# ── Mock registry factory ─────────────────────────────────────────────
def _make_registry(agent_name: str, agent: ChatAgent) -> MagicMock:
    """Build a spec'd ``AgentRegistry`` mock that exposes exactly one agent.

    ``list_agents()`` reports a single entry named *agent_name*, and ``get()``
    returns *agent* whatever name it is asked for.
    """
    listing = [{"name": agent_name, "description": "test"}]
    registry = MagicMock(spec=AgentRegistry)
    registry.list_agents.return_value = listing
    registry.get.return_value = agent
    return registry
# ── orchestrate_v3 ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_orchestrate_v3_returns_agent_name_and_instance():
    """orchestrate_v3 returns the classified name and the registry's agent."""
    expected_agent = _FixedAgent("task_agent")
    registry = _make_registry("task_agent", expected_agent)
    classifier = AsyncMock(return_value="task_agent")

    with patch("app.core.orchestrator.classify_intent", classifier):
        result_name, result_agent = await orchestrate_v3(
            user_id="u-1", message="fix a bug", context={}, reg=registry
        )

    assert result_name == "task_agent"
    assert result_agent is expected_agent
@pytest.mark.asyncio
async def test_orchestrate_v3_classify_called_with_message_and_context():
    """classify_intent is awaited once with the message and context as positional args."""
    registry = _make_registry("note_agent", _FixedAgent("note_agent"))
    ctx = {"some": "context"}
    classifier = AsyncMock(return_value="note_agent")

    with patch("app.core.orchestrator.classify_intent", classifier):
        await orchestrate_v3(
            user_id="u-1", message="take a note", context=ctx, reg=registry
        )

    classifier.assert_awaited_once()
    positional = classifier.call_args.args
    assert positional[0] == "take a note"
    assert positional[1] == ctx
@pytest.mark.asyncio
async def test_orchestrate_v3_uses_default_registry_when_none():
    """When ``reg`` is omitted, orchestrate_v3 falls back to ``_default_registry``."""
    expected_agent = _FixedAgent("task_agent")
    classifier = AsyncMock(return_value="task_agent")

    with patch("app.core.orchestrator.classify_intent", classifier):
        with patch("app.core.orchestrator._default_registry") as default_reg:
            default_reg.list_agents.return_value = [
                {"name": "task_agent", "description": ""}
            ]
            default_reg.get.return_value = expected_agent
            result_name, result_agent = await orchestrate_v3(
                user_id="u-1", message="hi", context={}
            )

    assert result_name == "task_agent"
    assert result_agent is expected_agent
@pytest.mark.asyncio
async def test_orchestrate_v3_get_called_with_agent_name():
    """The classified name is exactly what gets looked up in the registry."""
    registry = _make_registry("timeline_agent", _FixedAgent("timeline_agent"))
    classifier = AsyncMock(return_value="timeline_agent")

    with patch("app.core.orchestrator.classify_intent", classifier):
        await orchestrate_v3(
            user_id="u-2", message="schedule", context={}, reg=registry
        )

    registry.get.assert_called_once_with("timeline_agent")
# ── orchestrate_v3_stream ─────────────────────────────────────────────
async def _collect(gen) -> list[tuple[str, str]]:
results: list[tuple[str, str]] = []
async for item in gen:
results.append(item)
return results
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_first_yield_is_domain_signal():
    """The stream opens with ``(agent_name, "")`` before any real token."""
    registry = _make_registry(
        "task_agent", _FixedAgent("task_agent", tokens=["token1"])
    )
    classifier = AsyncMock(return_value="task_agent")

    # Drain inside the patch: the generator runs lazily on iteration.
    with patch("app.core.orchestrator.classify_intent", classifier):
        stream = orchestrate_v3_stream(
            user_id="u-1", message="hi", context={}, reg=registry
        )
        results = await _collect(stream)

    # Domain signal: the agent name paired with an empty token.
    first_item = results[0]
    assert first_item == ("task_agent", "")
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_yields_agent_name_with_tokens():
    """Every streamed item is tagged with the agent name; tokens keep their order."""
    registry = _make_registry(
        "task_agent", _FixedAgent("task_agent", tokens=["Hello", " ", "world"])
    )
    classifier = AsyncMock(return_value="task_agent")

    with patch("app.core.orchestrator.classify_intent", classifier):
        stream = orchestrate_v3_stream(
            user_id="u-1", message="hi", context={}, reg=registry
        )
        results = await _collect(stream)

    # Each item is an (agent_name, token) pair.
    names = [agent_name for agent_name, _ in results]
    tokens = [token for _, token in results]
    assert all(agent_name == "task_agent" for agent_name in names)
    assert tokens[0] == ""  # domain signal
    assert tokens[1:] == ["Hello", " ", "world"]
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_different_agent():
    """Routing to note_agent tags both the domain signal and its token with that name."""
    registry = _make_registry("note_agent", _FixedAgent("note_agent", tokens=["note"]))
    classifier = AsyncMock(return_value="note_agent")

    with patch("app.core.orchestrator.classify_intent", classifier):
        stream = orchestrate_v3_stream(
            user_id="u-2", message="take note", context={}, reg=registry
        )
        results = await _collect(stream)

    assert results[0] == ("note_agent", "")
    assert ("note_agent", "note") in results
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_uses_default_registry_when_none():
    """When ``reg`` is omitted, the stream resolves agents via ``_default_registry``."""
    expected_agent = _FixedAgent("task_agent", tokens=["x"])
    classifier = AsyncMock(return_value="task_agent")

    with patch("app.core.orchestrator.classify_intent", classifier):
        with patch("app.core.orchestrator._default_registry") as default_reg:
            default_reg.list_agents.return_value = [
                {"name": "task_agent", "description": ""}
            ]
            default_reg.get.return_value = expected_agent
            results = await _collect(
                orchestrate_v3_stream(user_id="u-1", message="hi", context={})
            )

    assert results[0][0] == "task_agent"
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_empty_token_list():
    """An agent that streams nothing still produces the domain signal."""

    class _SilentAgent(_FixedAgent):
        async def handle_stream(self, query: str, context: dict[str, Any]):
            # Looping over an empty tuple keeps this an async generator that
            # finishes immediately without yielding anything.
            for token in ():
                yield token

    registry = _make_registry("task_agent", _SilentAgent("task_agent", tokens=[]))
    classifier = AsyncMock(return_value="task_agent")

    with patch("app.core.orchestrator.classify_intent", classifier):
        stream = orchestrate_v3_stream(
            user_id="u-1", message="hi", context={}, reg=registry
        )
        results = await _collect(stream)

    assert results == [("task_agent", "")]  # only the domain signal
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_full_text_correct():
    """Joining every token after the domain signal rebuilds the full response."""
    pieces = ["The", " ", "task", " ", "is", " ", "done."]
    registry = _make_registry("task_agent", _FixedAgent("task_agent", tokens=pieces))
    classifier = AsyncMock(return_value="task_agent")

    with patch("app.core.orchestrator.classify_intent", classifier):
        stream = orchestrate_v3_stream(
            user_id="u-1", message="hi", context={}, reg=registry
        )
        results = await _collect(stream)

    after_signal = results[1:]  # drop the leading domain signal
    rebuilt = "".join(token for _, token in after_signal)
    assert rebuilt == "The task is done."
# ── handle_stream default implementation ─────────────────────────────
@pytest.mark.asyncio
async def test_handle_stream_default_yields_full_response():
    """Without an override, handle_stream yields handle()'s result as one chunk."""

    class _SimpleAgent(ChatAgent):
        def get_name(self) -> str:
            return "_simple"

        def get_description(self) -> str:
            return ""

        def get_tools(self) -> list[Any]:
            return []

        async def handle(self, query: str, context: dict[str, Any]) -> str:
            return "simple response"

    chunks = [chunk async for chunk in _SimpleAgent().handle_stream("q", {})]
    assert chunks == ["simple response"]
@pytest.mark.asyncio
async def test_handle_stream_override_used_by_stream():
    """_FixedAgent's handle_stream override yields each configured token separately."""
    stream = _FixedAgent("t", tokens=["a", "b", "c"]).handle_stream("q", {})
    chunks = [chunk async for chunk in stream]
    assert chunks == ["a", "b", "c"]