Refactor LLM instantiation across agents and orchestrator

- Replaced direct instantiation of ChatOpenAI with a centralized get_llm function in CheckpointAgent, NoteAgent, ProjectAgent, and TaskAgent.
- Introduced a new llm.py module to handle LLM model instantiation and API key management.
- Updated settings.py to include LLM_MODEL and LLM_ROUTER_MODEL configurations.
- Modified orchestrator.py to use get_router_llm for intent classification.
- Updated requirements.txt to include litellm for LLM management.
- Adjusted tests to mock get_llm instead of ChatOpenAI directly.
This commit is contained in:
2026-03-03 15:46:44 +01:00
parent 480e7ac5bd
commit 8bfce9da00
11 changed files with 830 additions and 50 deletions

View File

@@ -7,10 +7,9 @@ from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from app.config.settings import settings
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
_SYSTEM_PROMPT = (
"You are a project checkpoint assistant. Checkpoints are milestone dates that\n"
@@ -112,7 +111,7 @@ class CheckpointAgent(ChatAgent):
return [list_checkpoints, create_checkpoint, update_checkpoint, delete_checkpoint]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(

View File

@@ -7,10 +7,9 @@ from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from app.config.settings import settings
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n"
@@ -113,7 +112,7 @@ class NoteAgent(ChatAgent):
return [list_notes, get_note, create_note, update_note, delete_note]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(

View File

@@ -7,10 +7,9 @@ from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from app.config.settings import settings
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
_SYSTEM_PROMPT = (
"You are a project management assistant. You help users create, find,\n"
@@ -148,7 +147,7 @@ class ProjectAgent(ChatAgent):
]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(

View File

@@ -7,10 +7,9 @@ from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from app.config.settings import settings
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
_SYSTEM_PROMPT = (
"You are a task management assistant for a project workspace.\n"
@@ -219,7 +218,7 @@ class TaskAgent(ChatAgent):
]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(

View File

@@ -24,6 +24,9 @@ class Settings(BaseSettings):
OPENAI_API_KEY: str = ""
LLM_MODEL: str = "gpt-4o"
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
ENV: Literal["dev", "prod"] = "dev"

68
app/core/llm.py Normal file
View File

@@ -0,0 +1,68 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
instead of directly constructing a provider-specific class. The model string
follows the `LiteLLM model naming convention
<https://docs.litellm.ai/docs/providers>`_:
* OpenAI: ``gpt-4o``, ``gpt-4o-mini``
* Anthropic: ``anthropic/claude-3.5-sonnet``
* Google: ``gemini/gemini-pro``
* Ollama: ``ollama/llama3``
* Bedrock: ``bedrock/anthropic.claude-v2``
Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
— no code changes required.
"""
from __future__ import annotations
from langchain_openai import ChatOpenAI
from litellm import get_supported_openai_params # noqa: F401 validates install
from app.config.settings import settings
def _api_key_for_model(model: str) -> str | None:
"""Return the most appropriate API key for the given LiteLLM model string."""
if model.startswith("anthropic/"):
return getattr(settings, "ANTHROPIC_API_KEY", None) or None
if model.startswith("gemini/") or model.startswith("google/"):
return getattr(settings, "GOOGLE_API_KEY", None) or None
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
return settings.OPENAI_API_KEY or None
def get_llm(
*,
model: str | None = None,
temperature: float = 0,
) -> ChatOpenAI:
"""Return a LangChain chat model backed by LiteLLM.
LiteLLM exposes an OpenAI-compatible API, so we use ``ChatOpenAI`` pointed
at the LiteLLM proxy endpoint. In practice, ``litellm`` patches the
``openai`` client transparently when the model string contains a provider
prefix (``anthropic/…``, ``gemini/…``, etc.).
Parameters
----------
model:
LiteLLM model identifier. Defaults to ``settings.LLM_MODEL``.
temperature:
Sampling temperature. ``0`` = deterministic.
"""
model = model or settings.LLM_MODEL
return ChatOpenAI(
model=model,
temperature=temperature,
api_key=_api_key_for_model(model),
)
def get_router_llm(
*,
temperature: float = 0,
) -> ChatOpenAI:
"""Return the lighter model used for intent classification / routing."""
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)

View File

@@ -6,10 +6,9 @@ import json
from typing import Any, AsyncGenerator
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from app.config.settings import settings
from app.core.agent_registry import AgentRegistry
from app.core.llm import get_router_llm
from app.core.agent_registry import registry as _default_registry
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
@@ -29,8 +28,8 @@ _SYNTHESIZE_HUMAN = (
)
def _make_llm(model: str = "gpt-4o-mini") -> ChatOpenAI:
return ChatOpenAI(model=model, temperature=0, api_key=settings.OPENAI_API_KEY)
def _make_llm():
return get_router_llm()
async def classify_intent(