"""Tests for v3 orchestrator functions (Step 3)."""

from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.core.agent_registry import AgentRegistry, ChatAgent
from app.core.orchestrator import orchestrate_v3, orchestrate_v3_stream


# ── Minimal agent for testing ─────────────────────────────────────────

class _FixedAgent(ChatAgent):
    """Deterministic agent that answers with a fixed list of tokens.

    ``handle`` returns the tokens joined into one string; ``handle_stream``
    yields them one at a time, so tests can exercise both code paths.
    """

    def __init__(
        self,
        name: str = "_fixed",
        tokens: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._name = name
        # Compare against None explicitly: ``tokens or default`` would replace
        # an intentionally empty list with the default tokens. Copy the list so
        # later mutation of the caller's list cannot change this agent.
        self._tokens = list(tokens) if tokens is not None else ["Hello", " world"]

    def get_name(self) -> str:
        return self._name

    def get_description(self) -> str:
        return "Fixed agent for tests"

    def get_tools(self) -> list[Any]:
        return []

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        return "".join(self._tokens)

    async def handle_stream(
        self, query: str, context: dict[str, Any]
    ) -> AsyncIterator[str]:
        for tok in self._tokens:
            yield tok


# ── Test helpers ──────────────────────────────────────────────────────

def _make_registry(agent_name: str, agent: ChatAgent) -> MagicMock:
    """Return a registry mock that lists *agent_name* and resolves it to *agent*."""
    reg = MagicMock(spec=AgentRegistry)
    reg.list_agents.return_value = [{"name": agent_name, "description": "test"}]
    reg.get.return_value = agent
    return reg


def _patch_classify(agent_name: str):
    """Patch the orchestrator's ``classify_intent`` to always return *agent_name*.

    Because an explicit replacement object is passed to ``patch``, the context
    manager yields that ``AsyncMock``, so callers can use ``as mock_classify``.
    """
    return patch(
        "app.core.orchestrator.classify_intent",
        AsyncMock(return_value=agent_name),
    )


async def _collect(gen: AsyncIterator[tuple[str, str]]) -> list[tuple[str, str]]:
    """Drain an async generator of ``(agent_name, token)`` pairs into a list."""
    return [item async for item in gen]


# ── orchestrate_v3 ────────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_orchestrate_v3_returns_agent_name_and_instance():
    agent = _FixedAgent("task_agent")
    reg = _make_registry("task_agent", agent)
    with _patch_classify("task_agent"):
        name, inst = await orchestrate_v3(
            user_id="u-1", message="fix a bug", context={}, reg=reg
        )
    assert name == "task_agent"
    assert inst is agent


@pytest.mark.asyncio
async def test_orchestrate_v3_classify_called_with_message_and_context():
    agent = _FixedAgent("note_agent")
    reg = _make_registry("note_agent", agent)
    ctx = {"some": "context"}
    with _patch_classify("note_agent") as mock_classify:
        await orchestrate_v3(user_id="u-1", message="take a note", context=ctx, reg=reg)
    mock_classify.assert_awaited_once()
    call_args = mock_classify.call_args
    assert call_args[0][0] == "take a note"
    assert call_args[0][1] == ctx


@pytest.mark.asyncio
async def test_orchestrate_v3_uses_default_registry_when_none():
    agent = _FixedAgent("task_agent")
    with _patch_classify("task_agent"), patch(
        "app.core.orchestrator._default_registry"
    ) as mock_reg:
        mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
        mock_reg.get.return_value = agent
        name, inst = await orchestrate_v3(user_id="u-1", message="hi", context={})
    assert name == "task_agent"
    assert inst is agent


@pytest.mark.asyncio
async def test_orchestrate_v3_get_called_with_agent_name():
    agent = _FixedAgent("timeline_agent")
    reg = _make_registry("timeline_agent", agent)
    with _patch_classify("timeline_agent"):
        await orchestrate_v3(user_id="u-2", message="schedule", context={}, reg=reg)
    reg.get.assert_called_once_with("timeline_agent")


# ── orchestrate_v3_stream ─────────────────────────────────────────────

@pytest.mark.asyncio
async def test_orchestrate_v3_stream_first_yield_is_domain_signal():
    agent = _FixedAgent("task_agent", tokens=["token1"])
    reg = _make_registry("task_agent", agent)
    with _patch_classify("task_agent"):
        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
        results = await _collect(gen)
    # First item must be (agent_name, "") — domain signal
    assert results[0] == ("task_agent", "")


@pytest.mark.asyncio
async def test_orchestrate_v3_stream_yields_agent_name_with_tokens():
    agent = _FixedAgent("task_agent", tokens=["Hello", " ", "world"])
    reg = _make_registry("task_agent", agent)
    with _patch_classify("task_agent"):
        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
        results = await _collect(gen)
    # All items are (agent_name, token) pairs
    assert all(name == "task_agent" for name, _ in results)
    tokens = [tok for _, tok in results]
    assert tokens[0] == ""  # domain signal
    assert tokens[1:] == ["Hello", " ", "world"]


@pytest.mark.asyncio
async def test_orchestrate_v3_stream_different_agent():
    agent = _FixedAgent("note_agent", tokens=["note"])
    reg = _make_registry("note_agent", agent)
    with _patch_classify("note_agent"):
        gen = orchestrate_v3_stream(user_id="u-2", message="take note", context={}, reg=reg)
        results = await _collect(gen)
    assert results[0] == ("note_agent", "")
    assert ("note_agent", "note") in results


@pytest.mark.asyncio
async def test_orchestrate_v3_stream_uses_default_registry_when_none():
    agent = _FixedAgent("task_agent", tokens=["x"])
    with _patch_classify("task_agent"), patch(
        "app.core.orchestrator._default_registry"
    ) as mock_reg:
        mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
        mock_reg.get.return_value = agent
        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={})
        results = await _collect(gen)
    assert results[0][0] == "task_agent"


@pytest.mark.asyncio
async def test_orchestrate_v3_stream_empty_token_list():
    """Agent with no tokens still emits the domain signal."""
    # _FixedAgent honours an explicit empty list, so no custom subclass is needed.
    agent = _FixedAgent("task_agent", tokens=[])
    reg = _make_registry("task_agent", agent)
    with _patch_classify("task_agent"):
        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
        results = await _collect(gen)
    assert results == [("task_agent", "")]  # only domain signal


@pytest.mark.asyncio
async def test_orchestrate_v3_stream_full_text_correct():
    """Concatenating all non-domain tokens reconstructs the full response."""
    tokens = ["The", " ", "task", " ", "is", " ", "done."]
    agent = _FixedAgent("task_agent", tokens=tokens)
    reg = _make_registry("task_agent", agent)
    with _patch_classify("task_agent"):
        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
        results = await _collect(gen)
    text = "".join(tok for _, tok in results[1:])  # skip domain signal
    assert text == "The task is done."


# ── handle_stream default implementation ─────────────────────────────

@pytest.mark.asyncio
async def test_handle_stream_default_yields_full_response():
    """Default handle_stream yields handle() result as a single chunk."""

    class _SimpleAgent(ChatAgent):
        def get_name(self) -> str:
            return "_simple"

        def get_description(self) -> str:
            return ""

        def get_tools(self) -> list[Any]:
            return []

        async def handle(self, query: str, context: dict[str, Any]) -> str:
            return "simple response"

    agent = _SimpleAgent()
    tokens = [tok async for tok in agent.handle_stream("q", {})]
    assert tokens == ["simple response"]


@pytest.mark.asyncio
async def test_handle_stream_override_used_by_stream():
    """_FixedAgent.handle_stream override yields individual tokens."""
    agent = _FixedAgent("t", tokens=["a", "b", "c"])
    tokens = [tok async for tok in agent.handle_stream("q", {})]
    assert tokens == ["a", "b", "c"]