Refactor LLM instantiation across agents and orchestrator
- Replaced direct instantiation of ChatOpenAI with a centralized get_llm function in CheckpointAgent, NoteAgent, ProjectAgent, and TaskAgent. - Introduced a new llm.py module to handle LLM model instantiation and API key management. - Updated settings.py to include LLM_MODEL and LLM_ROUTER_MODEL configurations. - Modified orchestrator.py to use get_router_llm for intent classification. - Updated requirements.txt to include litellm for LLM management. - Adjusted tests to mock get_llm instead of ChatOpenAI directly.
This commit is contained in:
@@ -102,21 +102,21 @@ class TestTaskAgent:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_returns_string(self) -> None:
|
||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Task created.")
|
||||
result = await TaskAgent().handle("create a task", {})
|
||||
assert isinstance(result, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_no_tool_calls(self) -> None:
|
||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Here are your tasks.")
|
||||
result = await TaskAgent().handle("list my tasks", {})
|
||||
assert result == "Here are your tasks."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_with_create_task_tool_call(self) -> None:
|
||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||
"create_task",
|
||||
{"title": "Buy groceries", "priority": "low"},
|
||||
@@ -127,7 +127,7 @@ class TestTaskAgent:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_empty_context(self) -> None:
|
||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Done.")
|
||||
result = await TaskAgent().handle("help", {})
|
||||
assert isinstance(result, str)
|
||||
@@ -138,7 +138,7 @@ class TestTaskAgent:
|
||||
"user_profile": {"id": "u1", "tier": "pro"},
|
||||
"recent_tasks": [{"id": "t1", "title": "Old task"}],
|
||||
}
|
||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Tasks listed.")
|
||||
result = await TaskAgent().handle("show tasks", context)
|
||||
assert isinstance(result, str)
|
||||
@@ -273,14 +273,14 @@ class TestCheckpointAgent:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_no_tool_calls(self) -> None:
|
||||
with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("No checkpoints found.")
|
||||
result = await CheckpointAgent().handle("list checkpoints", {})
|
||||
assert result == "No checkpoints found."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_with_create_tool_call(self) -> None:
|
||||
with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||
"create_checkpoint",
|
||||
{"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
|
||||
@@ -291,7 +291,7 @@ class TestCheckpointAgent:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_empty_context(self) -> None:
|
||||
with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Done.")
|
||||
result = await CheckpointAgent().handle("show milestones", {})
|
||||
assert isinstance(result, str)
|
||||
@@ -397,14 +397,14 @@ class TestProjectAgent:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_no_tool_calls(self) -> None:
|
||||
with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.project_agent.get_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Project Alpha is active.")
|
||||
result = await ProjectAgent().handle("show my projects", {})
|
||||
assert result == "Project Alpha is active."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_with_create_project_tool_call(self) -> None:
|
||||
with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.project_agent.get_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||
"create_project",
|
||||
{"name": "Pippo"},
|
||||
@@ -415,7 +415,7 @@ class TestProjectAgent:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_empty_context(self) -> None:
|
||||
with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.project_agent.get_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Done.")
|
||||
result = await ProjectAgent().handle("archive old project", {})
|
||||
assert isinstance(result, str)
|
||||
@@ -515,14 +515,14 @@ class TestNoteAgent:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_no_tool_calls(self) -> None:
|
||||
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.note_agent.get_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Note created.")
|
||||
result = await NoteAgent().handle("create a note", {})
|
||||
assert result == "Note created."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_with_create_note_tool_call(self) -> None:
|
||||
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.note_agent.get_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||
"create_note",
|
||||
{"title": "Daily log", "content": "# Today\nAll good."},
|
||||
@@ -533,7 +533,7 @@ class TestNoteAgent:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_empty_context(self) -> None:
|
||||
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.note_agent.get_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Done.")
|
||||
result = await NoteAgent().handle("show notes", {})
|
||||
assert isinstance(result, str)
|
||||
|
||||
@@ -87,21 +87,21 @@ def reg() -> AgentRegistry:
|
||||
class TestClassifyIntent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
result = await classify_intent("add a task", {}, reg)
|
||||
assert result == "task_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("calendar_agent")
|
||||
result = await classify_intent("schedule a meeting", {}, reg)
|
||||
assert result == "calendar_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("nonexistent_agent")
|
||||
result = await classify_intent("do something", {}, reg)
|
||||
assert result == "task_agent"
|
||||
@@ -110,14 +110,14 @@ class TestClassifyIntent:
|
||||
async def test_empty_registry_returns_fallback_without_llm_call(self) -> None:
|
||||
empty_reg = AgentRegistry()
|
||||
# No LLM should be instantiated — early return path
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
result = await classify_intent("anything", {}, empty_reg)
|
||||
mock_cls.assert_not_called()
|
||||
assert result == "task_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm(" task_agent \n")
|
||||
result = await classify_intent("create task", {}, reg)
|
||||
assert result == "task_agent"
|
||||
@@ -154,7 +154,7 @@ class TestRouteSingle:
|
||||
class TestRoutePipeline:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("synthesized result")
|
||||
result = await route_pipeline(
|
||||
["task_agent", "calendar_agent"], "plan my week", {}, reg
|
||||
@@ -163,7 +163,7 @@ class TestRoutePipeline:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("synthesized result")
|
||||
result = await route_pipeline(
|
||||
["task_agent", "calendar_agent"], "plan my week", {}, reg
|
||||
@@ -193,7 +193,7 @@ class TestRoutePipeline:
|
||||
|
||||
reg.register(_CapturingAgent)
|
||||
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("done")
|
||||
await route_pipeline(["task_agent", "capture"], "hi", {}, reg)
|
||||
|
||||
@@ -204,7 +204,7 @@ class TestRoutePipeline:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("single result")
|
||||
result = await route_pipeline(["task_agent"], "one agent", {}, reg)
|
||||
assert result.response == "single result"
|
||||
@@ -218,7 +218,7 @@ class TestOrchestrate:
|
||||
async def test_direct_mode_returns_chat_response(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||
result = await orchestrate(request, reg)
|
||||
@@ -226,7 +226,7 @@ class TestOrchestrate:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||
result = await orchestrate(request, reg)
|
||||
@@ -237,7 +237,7 @@ class TestOrchestrate:
|
||||
async def test_plan_mode_returns_execution_plan(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="plan my tasks", execution_mode="plan")
|
||||
result = await orchestrate(request, reg)
|
||||
@@ -247,7 +247,7 @@ class TestOrchestrate:
|
||||
async def test_plan_mode_agent_matches_classified(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("calendar_agent")
|
||||
request = ChatRequest(
|
||||
message="schedule something", execution_mode="plan"
|
||||
@@ -258,7 +258,7 @@ class TestOrchestrate:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="plan tasks", execution_mode="plan")
|
||||
result = await orchestrate(request, reg)
|
||||
@@ -269,7 +269,7 @@ class TestOrchestrate:
|
||||
async def test_plan_mode_template_id_contains_agent_name(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="plan tasks", execution_mode="plan")
|
||||
result = await orchestrate(request, reg)
|
||||
@@ -281,7 +281,7 @@ class TestOrchestrate:
|
||||
async def test_default_execution_mode_is_direct(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
# execution_mode defaults to "direct"
|
||||
request = ChatRequest(message="help me")
|
||||
@@ -295,7 +295,7 @@ class TestOrchestrate:
|
||||
class TestOrchestrateStream:
|
||||
@pytest.mark.asyncio
|
||||
async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||
@@ -305,7 +305,7 @@ class TestOrchestrateStream:
|
||||
async def test_last_chunk_is_final_json_frame(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||
@@ -319,7 +319,7 @@ class TestOrchestrateStream:
|
||||
async def test_final_frame_response_matches_agent_output(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="create a task", execution_mode="direct")
|
||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||
@@ -331,7 +331,7 @@ class TestOrchestrateStream:
|
||||
async def test_text_chunks_before_final_frame(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(
|
||||
message="x" * 200, execution_mode="direct"
|
||||
|
||||
Reference in New Issue
Block a user