step-3: add router refactor with streaming support (orchestrator.py)
- orchestrate_v3(user_id, message, context): classifies intent and returns (agent_name, agent_instance); the caller drives execution
- orchestrate_v3_stream(user_id, message, context): yields (agent_name, token) pairs; the first yield is always (agent_name, "") as a domain-detection signal
- ChatAgent.handle_stream(): the default implementation yields the handle() result as one chunk; subclasses override it for true token-level streaming
- Fix stale test_orchestrator.py assertions that expected a JSON final frame that orchestrate_stream never emitted

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -119,7 +119,7 @@ pytest tests/test_orchestrator_v3.py
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Status**:
|
**Status**:
|
||||||
- [ ] Step 3 complete
|
- [x] Step 3 complete
|
||||||
|
|
||||||
**Commit**: After tests pass, commit with:
|
**Commit**: After tests pass, commit with:
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -45,6 +45,16 @@ class ChatAgent(BaseAgent):
|
|||||||
"""Process a user query and return a text response."""
|
"""Process a user query and return a text response."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
async def handle_stream(
    self, query: str, context: dict[str, Any]
) -> AsyncGenerator[str, None]:
    """Stream the answer to *query* as an async sequence of text chunks.

    This base implementation does no incremental streaming: it awaits
    ``handle()`` once and emits the complete result as a single chunk.
    Subclasses that can produce output incrementally should override it
    (e.g. by building on ``_tool_loop_stream``).
    """
    full_response = await self.handle(query, context)
    yield full_response
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_tools(self) -> list[Any]:
|
def get_tools(self) -> list[Any]:
|
||||||
"""Return LangChain tool definitions available to this agent."""
|
"""Return LangChain tool definitions available to this agent."""
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Any, AsyncGenerator
|
|||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
from app.core.agent_registry import AgentRegistry
|
from app.core.agent_registry import AgentRegistry, ChatAgent
|
||||||
from app.core.llm import get_router_llm
|
from app.core.llm import get_router_llm
|
||||||
from app.core.agent_registry import registry as _default_registry
|
from app.core.agent_registry import registry as _default_registry
|
||||||
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
||||||
@@ -140,6 +140,44 @@ async def orchestrate(
|
|||||||
return _build_plan(agent_name, request.message)
|
return _build_plan(agent_name, request.message)
|
||||||
|
|
||||||
|
|
||||||
|
async def orchestrate_v3(
    user_id: str,
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry | None = None,
) -> tuple[str, ChatAgent]:
    """Route *message* to an agent and hand that agent back to the caller.

    Nothing is executed here: the function only awaits ``classify_intent``
    and looks the resulting name up in the registry. Invoking ``handle()``,
    ``handle_stream()`` or ``_tool_loop_stream()`` is left to the caller.

    Args:
        user_id: Part of the v3 call signature; not read in this body.
        message: User text passed to ``classify_intent``.
        context: Context mapping passed to ``classify_intent``.
        reg: Registry to route against; the module-level default registry
            is used when this is ``None``.

    Returns:
        ``(agent_name, agent)`` where ``agent`` is ``reg.get(agent_name)``.
    """
    registry = _default_registry if reg is None else reg
    chosen = await classify_intent(message, context, registry)
    return chosen, registry.get(chosen)
|
||||||
|
|
||||||
|
|
||||||
|
async def orchestrate_v3_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry | None = None,
) -> AsyncGenerator[tuple[str, str], None]:
    """Route *message* and stream the chosen agent's output.

    Every item is an ``(agent_name, token)`` pair. The very first pair has an
    empty token: it exists only so consumers (e.g. PopupFormatter) learn the
    routing domain before any text arrives.

    Args:
        user_id: Part of the v3 call signature; not read in this body.
        message: User text, passed to ``classify_intent`` and to the agent.
        context: Context mapping, passed to ``classify_intent`` and the agent.
        reg: Registry to route against; the module-level default registry
            is used when this is ``None``.
    """
    registry = _default_registry if reg is None else reg
    chosen = await classify_intent(message, context, registry)
    agent = registry.get(chosen)

    # Domain signal: emitted before the agent is asked for any text.
    domain_signal = (chosen, "")
    yield domain_signal

    async for piece in agent.handle_stream(message, context):
        yield chosen, piece
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate_stream(
|
async def orchestrate_stream(
|
||||||
request: ChatRequest,
|
request: ChatRequest,
|
||||||
reg: AgentRegistry | None = None,
|
reg: AgentRegistry | None = None,
|
||||||
|
|||||||
@@ -302,7 +302,7 @@ class TestOrchestrateStream:
|
|||||||
assert len(chunks) >= 1
|
assert len(chunks) >= 1
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_last_chunk_is_final_json_frame(
|
async def test_all_chunks_are_plain_text(
|
||||||
self, reg: AgentRegistry
|
self, reg: AgentRegistry
|
||||||
) -> None:
|
) -> None:
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
@@ -310,13 +310,12 @@ class TestOrchestrateStream:
|
|||||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||||
|
|
||||||
last = json.loads(chunks[-1])
|
# orchestrate_stream yields plain text chunks only — no JSON final frame
|
||||||
assert last["done"] is True
|
for chunk in chunks:
|
||||||
assert "response" in last
|
assert isinstance(chunk, str)
|
||||||
assert "actions" in last
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_final_frame_response_matches_agent_output(
|
async def test_concatenated_chunks_equal_full_response(
|
||||||
self, reg: AgentRegistry
|
self, reg: AgentRegistry
|
||||||
) -> None:
|
) -> None:
|
||||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
@@ -324,8 +323,8 @@ class TestOrchestrateStream:
|
|||||||
request = ChatRequest(message="create a task", execution_mode="direct")
|
request = ChatRequest(message="create a task", execution_mode="direct")
|
||||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||||
|
|
||||||
final = json.loads(chunks[-1])
|
full_text = "".join(chunks)
|
||||||
assert final["response"] == "task: create a task"
|
assert full_text == "task: create a task"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_text_chunks_before_final_frame(
|
async def test_text_chunks_before_final_frame(
|
||||||
|
|||||||
236
tests/test_orchestrator_v3.py
Normal file
236
tests/test_orchestrator_v3.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
"""Tests for v3 orchestrator functions (Step 3)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.agent_registry import ChatAgent, AgentRegistry
|
||||||
|
from app.core.orchestrator import orchestrate_v3, orchestrate_v3_stream
|
||||||
|
|
||||||
|
|
||||||
|
# ── Minimal agent for testing ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _FixedAgent(ChatAgent):
    """Test double whose output is a fixed, caller-supplied token list.

    ``handle()`` returns the tokens joined into one string;
    ``handle_stream()`` yields them one at a time, in order.
    """

    def __init__(
        self,
        name: str = "_fixed",
        tokens: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._name = name
        # Compare against None rather than truthiness so an explicit empty
        # list means "no tokens" instead of silently becoming the default.
        self._tokens = ["Hello", " world"] if tokens is None else tokens

    def get_name(self) -> str:
        """Return the name given at construction time."""
        return self._name

    def get_description(self) -> str:
        """Return a constant description string."""
        return "Fixed agent for tests"

    def get_tools(self) -> list[Any]:
        """Return no tools."""
        return []

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Return all configured tokens concatenated; inputs are ignored."""
        return "".join(self._tokens)

    async def handle_stream(self, query: str, context: dict[str, Any]):
        """Yield the configured tokens one by one; inputs are ignored."""
        for tok in self._tokens:
            yield tok
|
||||||
|
|
||||||
|
|
||||||
|
# ── Mock registry factory ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_registry(agent_name: str, agent: ChatAgent) -> MagicMock:
    """Build an ``AgentRegistry``-spec'd mock that knows a single agent.

    ``get()`` always returns *agent*; ``list_agents()`` reports one entry
    named *agent_name*.
    """
    registry = MagicMock(spec=AgentRegistry)
    registry.get.return_value = agent
    registry.list_agents.return_value = [
        {"name": agent_name, "description": "test"},
    ]
    return registry
|
||||||
|
|
||||||
|
|
||||||
|
# ── orchestrate_v3 ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_orchestrate_v3_returns_agent_name_and_instance():
    """orchestrate_v3 returns the classified name and the registry's agent."""
    expected_agent = _FixedAgent("task_agent")
    registry = _make_registry("task_agent", expected_agent)
    fake_classify = AsyncMock(return_value="task_agent")

    with patch("app.core.orchestrator.classify_intent", fake_classify):
        result = await orchestrate_v3(
            user_id="u-1", message="fix a bug", context={}, reg=registry
        )

    routed_name, routed_agent = result
    assert routed_name == "task_agent"
    assert routed_agent is expected_agent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_orchestrate_v3_classify_called_with_message_and_context():
    """classify_intent is awaited once with the message and context first."""
    registry = _make_registry("note_agent", _FixedAgent("note_agent"))
    context = {"some": "context"}
    fake_classify = AsyncMock(return_value="note_agent")

    with patch("app.core.orchestrator.classify_intent", fake_classify):
        await orchestrate_v3(
            user_id="u-1", message="take a note", context=context, reg=registry
        )

    fake_classify.assert_awaited_once()
    # call_args unpacks as (positional_args, keyword_args).
    positional, _keywords = fake_classify.call_args
    assert positional[0] == "take a note"
    assert positional[1] == context
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_orchestrate_v3_uses_default_registry_when_none():
    """Without ``reg``, the agent comes from the module's default registry."""
    expected_agent = _FixedAgent("task_agent")
    fake_classify = AsyncMock(return_value="task_agent")

    with patch("app.core.orchestrator.classify_intent", fake_classify):
        with patch("app.core.orchestrator._default_registry") as mock_reg:
            mock_reg.get.return_value = expected_agent
            mock_reg.list_agents.return_value = [
                {"name": "task_agent", "description": ""},
            ]
            routed_name, routed_agent = await orchestrate_v3(
                user_id="u-1", message="hi", context={}
            )

    assert routed_name == "task_agent"
    assert routed_agent is expected_agent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_orchestrate_v3_get_called_with_agent_name():
    """The registry is queried exactly once, with the classified name."""
    registry = _make_registry(
        "checkpoint_agent", _FixedAgent("checkpoint_agent")
    )
    fake_classify = AsyncMock(return_value="checkpoint_agent")

    with patch("app.core.orchestrator.classify_intent", fake_classify):
        await orchestrate_v3(
            user_id="u-2", message="schedule", context={}, reg=registry
        )

    registry.get.assert_called_once_with("checkpoint_agent")
|
||||||
|
|
||||||
|
|
||||||
|
# ── orchestrate_v3_stream ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _collect(gen) -> list[tuple[str, str]]:
|
||||||
|
results: list[tuple[str, str]] = []
|
||||||
|
async for item in gen:
|
||||||
|
results.append(item)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_first_yield_is_domain_signal():
    """The stream opens with ``(agent_name, "")`` before any real token."""
    registry = _make_registry(
        "task_agent", _FixedAgent("task_agent", tokens=["token1"])
    )
    fake_classify = AsyncMock(return_value="task_agent")

    with patch("app.core.orchestrator.classify_intent", fake_classify):
        stream = orchestrate_v3_stream(
            user_id="u-1", message="hi", context={}, reg=registry
        )
        emitted = await _collect(stream)

    # The leading pair carries the agent name and an empty token.
    assert emitted[0] == ("task_agent", "")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_yields_agent_name_with_tokens():
    """Every pair is tagged with the agent name; tokens keep their order."""
    registry = _make_registry(
        "task_agent", _FixedAgent("task_agent", tokens=["Hello", " ", "world"])
    )
    fake_classify = AsyncMock(return_value="task_agent")

    with patch("app.core.orchestrator.classify_intent", fake_classify):
        stream = orchestrate_v3_stream(
            user_id="u-1", message="hi", context={}, reg=registry
        )
        emitted = await _collect(stream)

    # Each emitted item is an (agent_name, token) pair.
    assert all(pair[0] == "task_agent" for pair in emitted)
    pieces = [pair[1] for pair in emitted]
    assert pieces[0] == ""  # domain signal
    assert pieces[1:] == ["Hello", " ", "world"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_different_agent():
    """Routing to another agent changes the name attached to every pair."""
    registry = _make_registry(
        "note_agent", _FixedAgent("note_agent", tokens=["note"])
    )
    fake_classify = AsyncMock(return_value="note_agent")

    with patch("app.core.orchestrator.classify_intent", fake_classify):
        stream = orchestrate_v3_stream(
            user_id="u-2", message="take note", context={}, reg=registry
        )
        emitted = await _collect(stream)

    assert emitted[0] == ("note_agent", "")
    assert ("note_agent", "note") in emitted
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_uses_default_registry_when_none():
    """Without ``reg``, the streamed agent comes from the default registry."""
    agent = _FixedAgent("task_agent", tokens=["x"])

    with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
            patch("app.core.orchestrator._default_registry") as mock_reg:
        mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
        mock_reg.get.return_value = agent
        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={})
        results = await _collect(gen)

    assert results[0][0] == "task_agent"
    # The agent name alone comes from the patched classifier, so it cannot
    # show which registry was used. The streamed token and the lookup call
    # can: both exist only if the patched default registry supplied the agent.
    assert results == [("task_agent", ""), ("task_agent", "x")]
    mock_reg.get.assert_called_once_with("task_agent")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_empty_token_list():
    """An agent that streams nothing still produces the domain signal."""

    class _EmptyAgent(_FixedAgent):
        # The override is what keeps the stream empty: it finishes at once,
        # whatever token list the instance was constructed with.
        async def handle_stream(self, query: str, context: dict[str, Any]):
            return
            yield  # unreachable; its presence makes this an async generator

    registry = _make_registry("task_agent", _EmptyAgent("task_agent", tokens=[]))
    fake_classify = AsyncMock(return_value="task_agent")

    with patch("app.core.orchestrator.classify_intent", fake_classify):
        stream = orchestrate_v3_stream(
            user_id="u-1", message="hi", context={}, reg=registry
        )
        emitted = await _collect(stream)

    # Only the domain signal is emitted.
    assert emitted == [("task_agent", "")]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_full_text_correct():
    """Joining every token after the domain signal rebuilds the response."""
    pieces = ["The", " ", "task", " ", "is", " ", "done."]
    registry = _make_registry(
        "task_agent", _FixedAgent("task_agent", tokens=pieces)
    )
    fake_classify = AsyncMock(return_value="task_agent")

    with patch("app.core.orchestrator.classify_intent", fake_classify):
        stream = orchestrate_v3_stream(
            user_id="u-1", message="hi", context={}, reg=registry
        )
        emitted = await _collect(stream)

    # Index 0 is the domain signal, so the text starts at index 1.
    after_signal = emitted[1:]
    rebuilt = "".join(piece for _name, piece in after_signal)
    assert rebuilt == "The task is done."
|
||||||
|
|
||||||
|
|
||||||
|
# ── handle_stream default implementation ─────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_handle_stream_default_yields_full_response():
    """The inherited handle_stream emits handle()'s result as one chunk."""

    class _SimpleAgent(ChatAgent):
        # Deliberately has no handle_stream of its own, so the call below
        # exercises the implementation inherited from ChatAgent.
        def get_name(self) -> str:
            return "_simple"

        def get_description(self) -> str:
            return ""

        def get_tools(self) -> list[Any]:
            return []

        async def handle(self, query: str, context: dict[str, Any]) -> str:
            return "simple response"

    simple = _SimpleAgent()
    chunks = []
    async for chunk in simple.handle_stream("q", {}):
        chunks.append(chunk)

    assert chunks == ["simple response"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_handle_stream_override_used_by_stream():
    """_FixedAgent's own handle_stream emits each configured token separately."""
    fixed = _FixedAgent("t", tokens=["a", "b", "c"])
    chunks = []
    async for chunk in fixed.handle_stream("q", {}):
        chunks.append(chunk)

    assert chunks == ["a", "b", "c"]
|
||||||
Reference in New Issue
Block a user