feat: enhance agent configuration and model management with per-agent overrides

This commit is contained in:
Roberto Musso
2026-04-10 08:45:14 +02:00
parent 7253f6fe72
commit 3cf067faea
9 changed files with 106 additions and 22 deletions

View File

@@ -43,10 +43,9 @@ from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.config.settings import settings
from app.core.device_manager import DeviceConnectionManager
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
from app.core.llm import get_llm
from app.core.llm import get_agent_llm, model_for_agent
from app.core.preprocessors import detect_content_type, preprocess
from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor
from app.db import async_session
@@ -74,13 +73,13 @@ _MAX_PROCESSING_STEPS: int = 12
_MAX_SCAN_DEPTH: int = 5
# ── Data-type to tool mapping ─────────────────────────────────────────────
# NOTE: "projects" is intentionally excluded — project creation/assignment is
# handled in code by the runner, never delegated to the Step 2 LLM.
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
"tasks": TASK_TOOLS,
"notes": NOTE_TOOLS,
"timelines": TIMELINE_TOOLS,
"timelineEvents": TIMELINE_TOOLS,
"projects": PROJECT_TOOLS,
}
# ── V2: Unified processing prompt (hot-swappable via Langfuse "unified_processing") ──
@@ -238,7 +237,7 @@ async def _run_agent_with_tools(
run is appended to it (used by the caller to count ``create_*`` calls).
"""
lf = get_langfuse()
llm = get_llm()
llm = get_agent_llm(agent_name)
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
@@ -264,7 +263,7 @@ async def _run_agent_with_tools(
lf.start_as_current_observation(
as_type="generation",
name=f"{agent_name}-llm",
model=settings.LLM_MODEL,
model=model_for_agent(agent_name),
prompt=langfuse_prompt,
input=messages,
)
@@ -696,6 +695,12 @@ async def run_local_agent(
)
items_created += file_created
# Refresh project list when a project was created so
# subsequent files see it in the prompt context.
if "create_project" in file_tool_calls:
projects = await _fetch_projects()
projects_block = _format_projects(projects)
logger.info(
"agent_runner: run=%s file=%r created=%d result=%s",
run_id, file_path, file_created, result_text[:200],

View File

@@ -17,8 +17,7 @@ from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback
from app.core.llm import get_llm
from app.config.settings import settings
from app.core.llm import get_agent_llm, model_for_agent
from app.core.memory_middleware import MemoryMiddleware
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
from app.db import async_session
@@ -537,7 +536,7 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
}
try:
llm = get_llm()
llm = get_agent_llm("classifier")
classifier_messages = [
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_PROMPT),
HumanMessage(
@@ -555,7 +554,7 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
with lf.start_as_current_observation(
as_type="generation",
name="floating-classifier",
model=settings.LLM_MODEL,
model=model_for_agent("classifier"),
prompt=classifier_prompt_obj,
input=classifier_messages,
) as gen:
@@ -592,7 +591,7 @@ async def _run_single_agent(
) -> str:
trace_id = _trace_id_from_context(context)
lf = get_langfuse()
llm = get_llm()
llm = get_agent_llm(agent_name)
tools = _all_tools_for_user(user_id, trace_id)
model_context = _context_for_model(context)
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
@@ -628,7 +627,7 @@ async def _run_single_agent(
lf.start_as_current_observation(
as_type="generation",
name=f"{agent_name}-llm",
model=settings.LLM_MODEL,
model=model_for_agent(agent_name),
prompt=langfuse_prompt,
input=messages,
)
@@ -715,7 +714,7 @@ async def _run_single_agent_stream(
) -> AsyncGenerator[tuple[str, Any], None]:
trace_id = _trace_id_from_context(context)
lf = get_langfuse()
llm = get_llm()
llm = get_agent_llm(agent_name)
tools = _all_tools_for_user(user_id, trace_id)
model_context = _context_for_model(context)
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
@@ -753,7 +752,7 @@ async def _run_single_agent_stream(
lf.start_as_current_observation(
as_type="generation",
name=f"{agent_name}-llm",
model=settings.LLM_MODEL,
model=model_for_agent(agent_name),
prompt=langfuse_prompt,
input=messages,
)

View File

@@ -19,6 +19,7 @@ from __future__ import annotations
import os
import warnings
from collections.abc import Callable
from openai import AsyncOpenAI
import litellm
@@ -95,6 +96,35 @@ def get_llm(
)
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
}
def model_for_agent(agent_name: str) -> str:
"""Return the resolved model string for *agent_name* (for Langfuse tracking)."""
return _AGENT_MODEL_SETTINGS.get(agent_name, lambda: settings.LLM_MODEL)()
def get_agent_llm(
agent_name: str,
*,
temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM:
"""Return an LLM configured for *agent_name*, respecting per-agent overrides.
Falls back to ``settings.LLM_MODEL`` for unknown agent names or when the
per-agent override is left empty in ``.env``.
"""
model = model_for_agent(agent_name)
return get_llm(model=model, temperature=temperature)
async def embed(text: str) -> list[float]:
"""Return an embedding vector for *text*.