feat: enhance agent configuration and model management with per-agent overrides
This commit is contained in:
@@ -43,10 +43,9 @@ from app.agents.note_agent import NOTE_TOOLS
|
||||
from app.agents.project_agent import PROJECT_TOOLS
|
||||
from app.agents.task_agent import TASK_TOOLS
|
||||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||
from app.config.settings import settings
|
||||
from app.core.device_manager import DeviceConnectionManager
|
||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
|
||||
from app.core.llm import get_llm
|
||||
from app.core.llm import get_agent_llm, model_for_agent
|
||||
from app.core.preprocessors import detect_content_type, preprocess
|
||||
from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor
|
||||
from app.db import async_session
|
||||
@@ -74,13 +73,13 @@ _MAX_PROCESSING_STEPS: int = 12
|
||||
_MAX_SCAN_DEPTH: int = 5
|
||||
|
||||
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
||||
# NOTE: "projects" is intentionally excluded — project creation/assignment is
|
||||
# handled in code by the runner, never delegated to the Step 2 LLM.
|
||||
|
||||
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
|
||||
"tasks": TASK_TOOLS,
|
||||
"notes": NOTE_TOOLS,
|
||||
"timelines": TIMELINE_TOOLS,
|
||||
"timelineEvents": TIMELINE_TOOLS,
|
||||
"projects": PROJECT_TOOLS,
|
||||
}
|
||||
|
||||
# ── V2: Unified processing prompt (hot-swappable via Langfuse "unified_processing") ──
|
||||
@@ -238,7 +237,7 @@ async def _run_agent_with_tools(
|
||||
run is appended to it (used by the caller to count ``create_*`` calls).
|
||||
"""
|
||||
lf = get_langfuse()
|
||||
llm = get_llm()
|
||||
llm = get_agent_llm(agent_name)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
messages: list[Any] = [
|
||||
SystemMessage(content=system_prompt),
|
||||
@@ -264,7 +263,7 @@ async def _run_agent_with_tools(
|
||||
lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name=f"{agent_name}-llm",
|
||||
model=settings.LLM_MODEL,
|
||||
model=model_for_agent(agent_name),
|
||||
prompt=langfuse_prompt,
|
||||
input=messages,
|
||||
)
|
||||
@@ -696,6 +695,12 @@ async def run_local_agent(
|
||||
)
|
||||
items_created += file_created
|
||||
|
||||
# Refresh project list when a project was created so
|
||||
# subsequent files see it in the prompt context.
|
||||
if "create_project" in file_tool_calls:
|
||||
projects = await _fetch_projects()
|
||||
projects_block = _format_projects(projects)
|
||||
|
||||
logger.info(
|
||||
"agent_runner: run=%s file=%r created=%d result=%s",
|
||||
run_id, file_path, file_created, result_text[:200],
|
||||
|
||||
@@ -17,8 +17,7 @@ from app.agents.project_agent import PROJECT_TOOLS
|
||||
from app.agents.task_agent import TASK_TOOLS
|
||||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback
|
||||
from app.core.llm import get_llm
|
||||
from app.config.settings import settings
|
||||
from app.core.llm import get_agent_llm, model_for_agent
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
||||
from app.db import async_session
|
||||
@@ -537,7 +536,7 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
|
||||
}
|
||||
|
||||
try:
|
||||
llm = get_llm()
|
||||
llm = get_agent_llm("classifier")
|
||||
classifier_messages = [
|
||||
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_PROMPT),
|
||||
HumanMessage(
|
||||
@@ -555,7 +554,7 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="floating-classifier",
|
||||
model=settings.LLM_MODEL,
|
||||
model=model_for_agent("classifier"),
|
||||
prompt=classifier_prompt_obj,
|
||||
input=classifier_messages,
|
||||
) as gen:
|
||||
@@ -592,7 +591,7 @@ async def _run_single_agent(
|
||||
) -> str:
|
||||
trace_id = _trace_id_from_context(context)
|
||||
lf = get_langfuse()
|
||||
llm = get_llm()
|
||||
llm = get_agent_llm(agent_name)
|
||||
tools = _all_tools_for_user(user_id, trace_id)
|
||||
model_context = _context_for_model(context)
|
||||
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
||||
@@ -628,7 +627,7 @@ async def _run_single_agent(
|
||||
lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name=f"{agent_name}-llm",
|
||||
model=settings.LLM_MODEL,
|
||||
model=model_for_agent(agent_name),
|
||||
prompt=langfuse_prompt,
|
||||
input=messages,
|
||||
)
|
||||
@@ -715,7 +714,7 @@ async def _run_single_agent_stream(
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
trace_id = _trace_id_from_context(context)
|
||||
lf = get_langfuse()
|
||||
llm = get_llm()
|
||||
llm = get_agent_llm(agent_name)
|
||||
tools = _all_tools_for_user(user_id, trace_id)
|
||||
model_context = _context_for_model(context)
|
||||
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
||||
@@ -753,7 +752,7 @@ async def _run_single_agent_stream(
|
||||
lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name=f"{agent_name}-llm",
|
||||
model=settings.LLM_MODEL,
|
||||
model=model_for_agent(agent_name),
|
||||
prompt=langfuse_prompt,
|
||||
input=messages,
|
||||
)
|
||||
|
||||
@@ -19,6 +19,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
import litellm
|
||||
@@ -95,6 +96,35 @@ def get_llm(
|
||||
)
|
||||
|
||||
|
||||
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
|
||||
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
|
||||
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
|
||||
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
||||
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||
}
|
||||
|
||||
|
||||
def model_for_agent(agent_name: str) -> str:
|
||||
"""Return the resolved model string for *agent_name* (for Langfuse tracking)."""
|
||||
return _AGENT_MODEL_SETTINGS.get(agent_name, lambda: settings.LLM_MODEL)()
|
||||
|
||||
|
||||
def get_agent_llm(
|
||||
agent_name: str,
|
||||
*,
|
||||
temperature: float = 0,
|
||||
) -> ChatOpenAI | ChatLiteLLM:
|
||||
"""Return an LLM configured for *agent_name*, respecting per-agent overrides.
|
||||
|
||||
Falls back to ``settings.LLM_MODEL`` for unknown agent names or when the
|
||||
per-agent override is left empty in ``.env``.
|
||||
"""
|
||||
model = model_for_agent(agent_name)
|
||||
return get_llm(model=model, temperature=temperature)
|
||||
|
||||
|
||||
async def embed(text: str) -> list[float]:
|
||||
"""Return an embedding vector for *text*.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user