From 2c082759343ffaae7197e79273046e80829a2042 Mon Sep 17 00:00:00 2001
From: roberto
Date: Sun, 8 Mar 2026 21:42:46 +0100
Subject: [PATCH] step-3: add router refactor with streaming support (orchestrator.py)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- orchestrate_v3(user_id, message, context): classifies intent, returns
  (agent_name, agent_instance) — caller drives execution
- orchestrate_v3_stream(user_id, message, context): yields
  (agent_name, token) pairs; first yield is always (agent_name, "") as a
  domain-detection signal
- ChatAgent.handle_stream(): default implementation yields handle()
  result as one chunk; subclasses override for true token-level streaming
- Fix stale test_orchestrator.py assertions that expected a JSON final
  frame that orchestrate_stream never emitted

Co-Authored-By: Claude Sonnet 4.6
---
 V3_MIGRATION_PLAN.md          |   2 +-
 app/core/agent_registry.py    |  10 ++
 app/core/orchestrator.py      |  40 +++++-
 tests/test_orchestrator.py    |  15 +--
 tests/test_orchestrator_v3.py | 236 ++++++++++++++++++++++++++++++++++
 5 files changed, 293 insertions(+), 10 deletions(-)
 create mode 100644 tests/test_orchestrator_v3.py

diff --git a/V3_MIGRATION_PLAN.md b/V3_MIGRATION_PLAN.md
index d5da12e..090923f 100644
--- a/V3_MIGRATION_PLAN.md
+++ b/V3_MIGRATION_PLAN.md
@@ -119,7 +119,7 @@ pytest tests/test_orchestrator_v3.py
 ```
 
 **Status**:
-- [ ] Step 3 complete
+- [x] Step 3 complete
 
 **Commit**: After tests pass, commit with:
 ```
diff --git a/app/core/agent_registry.py b/app/core/agent_registry.py
index 323e4ea..9a4930d 100644
--- a/app/core/agent_registry.py
+++ b/app/core/agent_registry.py
@@ -45,6 +45,16 @@ class ChatAgent(BaseAgent):
         """Process a user query and return a text response."""
         ...
 
+    async def handle_stream(
+        self, query: str, context: dict[str, Any]
+    ) -> AsyncGenerator[str, None]:
+        """Streaming variant of handle().
+
+        Default: calls handle() and yields the full response as one chunk.
+        Override in subclasses for true token-level streaming via _tool_loop_stream.
+        """
+        yield await self.handle(query, context)
+
     @abstractmethod
     def get_tools(self) -> list[Any]:
         """Return LangChain tool definitions available to this agent."""
diff --git a/app/core/orchestrator.py b/app/core/orchestrator.py
index 982ef30..ca1dbc7 100644
--- a/app/core/orchestrator.py
+++ b/app/core/orchestrator.py
@@ -7,7 +7,7 @@ from typing import Any, AsyncGenerator
 
 from langchain_core.messages import HumanMessage, SystemMessage
 
-from app.core.agent_registry import AgentRegistry
+from app.core.agent_registry import AgentRegistry, ChatAgent
 from app.core.llm import get_router_llm
 from app.core.agent_registry import registry as _default_registry
 from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
@@ -140,6 +140,44 @@ async def orchestrate(
     return _build_plan(agent_name, request.message)
 
 
+async def orchestrate_v3(
+    user_id: str,
+    message: str,
+    context: dict[str, Any],
+    reg: AgentRegistry | None = None,
+) -> tuple[str, ChatAgent]:
+    """v3 orchestration — returns (agent_name, agent_instance); caller drives execution.
+
+    Classifies intent and instantiates the matching agent. The caller is responsible
+    for invoking handle(), handle_stream(), or _tool_loop_stream() as needed.
+    """
+    if reg is None:
+        reg = _default_registry
+    agent_name = await classify_intent(message, context, reg)
+    return agent_name, reg.get(agent_name)
+
+
+async def orchestrate_v3_stream(
+    user_id: str,
+    message: str,
+    context: dict[str, Any],
+    reg: AgentRegistry | None = None,
+) -> AsyncGenerator[tuple[str, str], None]:
+    """v3 streaming orchestration — yields (agent_name, token) pairs.
+
+    The first yield always carries the agent_name with an empty token so that
+    callers (e.g. PopupFormatter) can detect the routing domain before any text
+    tokens arrive.
+    """
+    if reg is None:
+        reg = _default_registry
+    agent_name = await classify_intent(message, context, reg)
+    agent = reg.get(agent_name)
+    yield agent_name, ""  # domain signal — no token yet
+    async for token in agent.handle_stream(message, context):
+        yield agent_name, token
+
+
 async def orchestrate_stream(
     request: ChatRequest,
     reg: AgentRegistry | None = None,
diff --git a/tests/test_orchestrator.py b/tests/test_orchestrator.py
index 107acf8..07576d4 100644
--- a/tests/test_orchestrator.py
+++ b/tests/test_orchestrator.py
@@ -302,7 +302,7 @@ class TestOrchestrateStream:
         assert len(chunks) >= 1
 
     @pytest.mark.asyncio
-    async def test_last_chunk_is_final_json_frame(
+    async def test_all_chunks_are_plain_text(
         self, reg: AgentRegistry
     ) -> None:
         with patch("app.core.orchestrator._make_llm") as mock_cls:
@@ -310,13 +310,12 @@ class TestOrchestrateStream:
             request = ChatRequest(message="add a task", execution_mode="direct")
             chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
 
-        last = json.loads(chunks[-1])
-        assert last["done"] is True
-        assert "response" in last
-        assert "actions" in last
+        # orchestrate_stream yields plain text chunks only — no JSON final frame
+        for chunk in chunks:
+            assert isinstance(chunk, str)
 
     @pytest.mark.asyncio
-    async def test_final_frame_response_matches_agent_output(
+    async def test_concatenated_chunks_equal_full_response(
         self, reg: AgentRegistry
     ) -> None:
         with patch("app.core.orchestrator._make_llm") as mock_cls:
@@ -324,8 +323,8 @@ class TestOrchestrateStream:
             request = ChatRequest(message="create a task", execution_mode="direct")
             chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
 
-        final = json.loads(chunks[-1])
-        assert final["response"] == "task: create a task"
+        full_text = "".join(chunks)
+        assert full_text == "task: create a task"
 
     @pytest.mark.asyncio
     async def test_text_chunks_before_final_frame(
diff --git a/tests/test_orchestrator_v3.py b/tests/test_orchestrator_v3.py
new file mode 100644
index 0000000..cf9197d
--- /dev/null
+++ b/tests/test_orchestrator_v3.py
@@ -0,0 +1,236 @@
+"""Tests for v3 orchestrator functions (Step 3)."""
+
+from __future__ import annotations
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+from typing import Any
+
+from app.core.agent_registry import ChatAgent, AgentRegistry
+from app.core.orchestrator import orchestrate_v3, orchestrate_v3_stream
+
+
+# ── Minimal agent for testing ─────────────────────────────────────────
+
+
+class _FixedAgent(ChatAgent):
+    def __init__(self, name: str = "_fixed", tokens: list[str] | None = None, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self._name = name
+        self._tokens = tokens or ["Hello", " world"]
+
+    def get_name(self) -> str:
+        return self._name
+
+    def get_description(self) -> str:
+        return "Fixed agent for tests"
+
+    def get_tools(self) -> list[Any]:
+        return []
+
+    async def handle(self, query: str, context: dict[str, Any]) -> str:
+        return "".join(self._tokens)
+
+    async def handle_stream(self, query: str, context: dict[str, Any]):
+        for tok in self._tokens:
+            yield tok
+
+
+# ── Mock registry factory ─────────────────────────────────────────────
+
+
+def _make_registry(agent_name: str, agent: ChatAgent) -> MagicMock:
+    reg = MagicMock(spec=AgentRegistry)
+    reg.list_agents.return_value = [{"name": agent_name, "description": "test"}]
+    reg.get.return_value = agent
+    return reg
+
+
+# ── orchestrate_v3 ────────────────────────────────────────────────────
+
+
+@pytest.mark.asyncio
+async def test_orchestrate_v3_returns_agent_name_and_instance():
+    agent = _FixedAgent("task_agent")
+    reg = _make_registry("task_agent", agent)
+
+    with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
+        name, inst = await orchestrate_v3(
+            user_id="u-1", message="fix a bug", context={}, reg=reg
+        )
+
+    assert name == "task_agent"
+    assert inst is agent
+
+
+@pytest.mark.asyncio
+async def test_orchestrate_v3_classify_called_with_message_and_context():
+    agent = _FixedAgent("note_agent")
+    reg = _make_registry("note_agent", agent)
+    ctx = {"some": "context"}
+
+    with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")) as mock_classify:
+        await orchestrate_v3(user_id="u-1", message="take a note", context=ctx, reg=reg)
+
+    mock_classify.assert_awaited_once()
+    call_args = mock_classify.call_args
+    assert call_args[0][0] == "take a note"
+    assert call_args[0][1] == ctx
+
+
+@pytest.mark.asyncio
+async def test_orchestrate_v3_uses_default_registry_when_none():
+    agent = _FixedAgent("task_agent")
+
+    with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
+         patch("app.core.orchestrator._default_registry") as mock_reg:
+        mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
+        mock_reg.get.return_value = agent
+        name, inst = await orchestrate_v3(user_id="u-1", message="hi", context={})
+
+    assert name == "task_agent"
+    assert inst is agent
+
+
+@pytest.mark.asyncio
+async def test_orchestrate_v3_get_called_with_agent_name():
+    agent = _FixedAgent("checkpoint_agent")
+    reg = _make_registry("checkpoint_agent", agent)
+
+    with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="checkpoint_agent")):
+        await orchestrate_v3(user_id="u-2", message="schedule", context={}, reg=reg)
+
+    reg.get.assert_called_once_with("checkpoint_agent")
+
+
+# ── orchestrate_v3_stream ─────────────────────────────────────────────
+
+
+async def _collect(gen) -> list[tuple[str, str]]:
+    results: list[tuple[str, str]] = []
+    async for item in gen:
+        results.append(item)
+    return results
+
+
+@pytest.mark.asyncio
+async def test_orchestrate_v3_stream_first_yield_is_domain_signal():
+    agent = _FixedAgent("task_agent", tokens=["token1"])
+    reg = _make_registry("task_agent", agent)
+
+    with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
+        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
+        results = await _collect(gen)
+
+    # First item must be (agent_name, "") — domain signal
+    assert results[0] == ("task_agent", "")
+
+
+@pytest.mark.asyncio
+async def test_orchestrate_v3_stream_yields_agent_name_with_tokens():
+    agent = _FixedAgent("task_agent", tokens=["Hello", " ", "world"])
+    reg = _make_registry("task_agent", agent)
+
+    with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
+        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
+        results = await _collect(gen)
+
+    # All items are (agent_name, token) pairs
+    assert all(name == "task_agent" for name, _ in results)
+    tokens = [tok for _, tok in results]
+    assert tokens[0] == ""  # domain signal
+    assert tokens[1:] == ["Hello", " ", "world"]
+
+
+@pytest.mark.asyncio
+async def test_orchestrate_v3_stream_different_agent():
+    agent = _FixedAgent("note_agent", tokens=["note"])
+    reg = _make_registry("note_agent", agent)
+
+    with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")):
+        gen = orchestrate_v3_stream(user_id="u-2", message="take note", context={}, reg=reg)
+        results = await _collect(gen)
+
+    assert results[0] == ("note_agent", "")
+    assert ("note_agent", "note") in results
+
+
+@pytest.mark.asyncio
+async def test_orchestrate_v3_stream_uses_default_registry_when_none():
+    agent = _FixedAgent("task_agent", tokens=["x"])
+
+    with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
+         patch("app.core.orchestrator._default_registry") as mock_reg:
+        mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
+        mock_reg.get.return_value = agent
+        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={})
+        results = await _collect(gen)
+
+    assert results[0][0] == "task_agent"
+
+
+@pytest.mark.asyncio
+async def test_orchestrate_v3_stream_empty_token_list():
+    """Agent with no tokens still emits the domain signal."""
+
+    class _EmptyAgent(_FixedAgent):
+        async def handle_stream(self, query: str, context: dict[str, Any]):
+            return
+            yield  # makes it a generator
+
+    agent = _EmptyAgent("task_agent", tokens=[])
+    reg = _make_registry("task_agent", agent)
+
+    with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
+        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
+        results = await _collect(gen)
+
+    assert results == [("task_agent", "")]  # only domain signal
+
+
+@pytest.mark.asyncio
+async def test_orchestrate_v3_stream_full_text_correct():
+    """Concatenating all non-domain tokens reconstructs the full response."""
+    tokens = ["The", " ", "task", " ", "is", " ", "done."]
+    agent = _FixedAgent("task_agent", tokens=tokens)
+    reg = _make_registry("task_agent", agent)
+
+    with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
+        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
+        results = await _collect(gen)
+
+    text = "".join(tok for _, tok in results[1:])  # skip domain signal
+    assert text == "The task is done."
+
+
+# ── handle_stream default implementation ─────────────────────────────
+
+
+@pytest.mark.asyncio
+async def test_handle_stream_default_yields_full_response():
+    """Default handle_stream yields handle() result as a single chunk."""
+
+    class _SimpleAgent(ChatAgent):
+        def get_name(self) -> str:
+            return "_simple"
+
+        def get_description(self) -> str:
+            return ""
+
+        def get_tools(self) -> list[Any]:
+            return []
+
+        async def handle(self, query: str, context: dict[str, Any]) -> str:
+            return "simple response"
+
+    agent = _SimpleAgent()
+    tokens = [tok async for tok in agent.handle_stream("q", {})]
+    assert tokens == ["simple response"]
+
+
+@pytest.mark.asyncio
+async def test_handle_stream_override_used_by_stream():
+    """_FixedAgent.handle_stream override yields individual tokens."""
+    agent = _FixedAgent("t", tokens=["a", "b", "c"])
+    tokens = [tok async for tok in agent.handle_stream("q", {})]
+    assert tokens == ["a", "b", "c"]