Files
api/app/core/orchestrator.py
2026-03-10 09:11:24 +01:00

211 lines
7.0 KiB
Python

"""Orchestrator — LLM-based intent router and agent pipeline."""
from __future__ import annotations
import json
from typing import Any, AsyncGenerator
from langchain_core.messages import HumanMessage, SystemMessage
from app.core.agent_registry import AgentRegistry, ChatAgent
from app.core.llm import get_router_llm
from app.core.agent_registry import registry as _default_registry
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
# Agent name returned whenever classification cannot yield a registered agent.
_FALLBACK_AGENT = "task_agent"

# System prompt for intent classification; ``{agents}`` is substituted via
# ``str.format`` before the prompt is sent.
_CLASSIFY_SYSTEM = "\n".join(
    (
        "You are an intent classifier. Given the user message and context, "
        "decide which agent to route to.",
        "Available agents: {agents}",
        "Respond with just the agent name, nothing else.",
    )
)

# Human prompt for the pipeline synthesis call; ``{results}`` and
# ``{message}`` are substituted via ``str.format``.
_SYNTHESIZE_HUMAN = "\n".join(
    (
        "Combine the following agent results into one coherent response.",
        "",
        "Agent results:",
        "{results}",
        "",
        "Original message: {message}",
    )
)
def _make_llm():
    """Return the LLM client obtained from ``get_router_llm``.

    Single indirection point through which this module acquires its LLM.
    """
    router_llm = get_router_llm()
    return router_llm
async def classify_intent(
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry,
) -> str:
    """Ask the router LLM which registered agent should handle *message*.

    The system prompt lists the agents returned by ``reg.list_agents()``; the
    human prompt carries the message plus a truncated JSON rendering of
    *context*. The model's reply is normalised and matched against the
    registered agent names.

    Args:
        message: Raw user message to classify.
        context: Request context; only the first 500 characters of its JSON
            form are sent to the model.
        reg: Registry whose ``list_agents()`` entries expose a ``"name"`` key.

    Returns:
        The name of a registered agent, or the fallback agent name when the
        registry is empty or the reply matches no registered agent.
    """
    agents = reg.list_agents()
    if not agents:
        return _FALLBACK_AGENT

    system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
    # Truncate context to keep the classification prompt short.
    # ``default=str`` is deliberate: the summary is only prompt text, so any
    # value json cannot encode natively (e.g. datetimes) is stringified
    # instead of aborting classification with a TypeError.
    context_summary = json.dumps(context, default=str)[:500]
    human = f"Message: {message}\nContext summary: {context_summary}"

    llm = _make_llm()
    response = await llm.ainvoke(
        [SystemMessage(content=system), HumanMessage(content=human)]
    )

    known = {a["name"] for a in agents}
    reply = str(response.content).strip().lower()
    if reply in known:
        return reply

    # The model may wrap the name in quotes/backticks or end it with a period,
    # and registered names may not be all lower-case. Retry with the wrapper
    # characters removed and a case-insensitive lookup, returning the name
    # exactly as registered (first registration wins on a case collision).
    cleaned = reply.strip(" \t\r\n\"'`.")
    by_folded: dict[str, str] = {}
    for entry in agents:
        by_folded.setdefault(str(entry["name"]).lower(), entry["name"])
    return by_folded.get(cleaned, _FALLBACK_AGENT)
async def route_single(
    agent_name: str,
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry,
) -> ChatResponse:
    """Invoke one agent through the registry and return its reply.

    The value returned by ``reg.call_agent`` becomes the ``response`` field of
    the resulting ``ChatResponse``.
    """
    reply = await reg.call_agent(agent_name, message, context)
    wrapped = ChatResponse(response=reply)
    return wrapped
async def route_pipeline(
    agent_names: list[str],
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry,
) -> ChatResponse:
    """Run *agent_names* in order, then merge their outputs with one LLM call.

    Each agent is called with a copy of *context* extended by a
    ``previous_results`` list holding the outputs of the agents that ran
    before it. Afterwards the outputs are labelled with their agent names,
    inserted into the synthesis prompt, and the LLM's answer is returned as a
    ``ChatResponse``.
    """
    outputs: list[str] = []
    for name in agent_names:
        # Each agent gets its own snapshot so later appends do not leak into
        # a context dict an earlier agent may still hold.
        step_context = dict(context)
        step_context["previous_results"] = outputs.copy()
        outputs.append(await reg.call_agent(name, message, step_context))

    labelled = [
        f"[{name}]: {output}" for name, output in zip(agent_names, outputs)
    ]
    prompt = _SYNTHESIZE_HUMAN.format(
        results="\n\n".join(labelled), message=message
    )
    synthesis = await _make_llm().ainvoke([HumanMessage(content=prompt)])
    return ChatResponse(response=str(synthesis.content))
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
    """Create a one-step ``ExecutionPlan`` for *agent_name*.

    The template ID ``tpl_<agent_name>_default`` is looked up in the template
    registry. When it is present the plan gets an LLM step referencing that
    ID; otherwise it gets a plain ``handle`` step. Both steps carry the user
    message as their only argument.
    """
    # Imported at call time rather than module level, as in the original;
    # presumably to avoid an import cycle — confirm before hoisting.
    from app.core.execution_plan import ExecutionPlanBuilder, template_registry

    payload = {"message": message}
    default_template = f"tpl_{agent_name}_default"
    plan = ExecutionPlanBuilder(agent_name)
    if not template_registry.has(default_template):
        plan.add_step("handle", payload)
    else:
        plan.add_llm_step(default_template, payload)
    return plan.build()
async def orchestrate(
    request: ChatRequest,
    reg: AgentRegistry | None = None,
) -> ChatResponse | ExecutionPlan:
    """Classify the request and either execute it or return a plan.

    The request context is dumped to a dict and passed, with the message, to
    ``classify_intent`` to pick an agent. Then:

    * ``execution_mode == 'direct'`` — the agent is called and its reply is
      returned as a ``ChatResponse``.
    * any other mode — an ``ExecutionPlan`` for the agent is built and
      returned without being executed.

    When *reg* is ``None`` the module's default registry is used.
    """
    registry = _default_registry if reg is None else reg
    ctx = request.context.model_dump()
    target = await classify_intent(request.message, ctx, registry)

    if request.execution_mode != "direct":
        # Plan mode: hand the plan back, do not run the agent.
        return _build_plan(target, request.message)
    return await route_single(target, request.message, ctx, registry)
async def orchestrate_v3(
    user_id: str,
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry | None = None,
) -> tuple[str, ChatAgent]:
    """Pick an agent for *message* and return ``(agent_name, agent)``.

    Only classification and registry lookup happen here; running the agent
    (``handle()``, ``handle_stream()`` or ``_tool_loop_stream()``) is left to
    the caller. *user_id* is accepted but not used by this function. When
    *reg* is ``None`` the module's default registry is used.
    """
    registry = reg if reg is not None else _default_registry
    chosen = await classify_intent(message, context, registry)
    agent = registry.get(chosen)
    return chosen, agent
async def orchestrate_v3_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry | None = None,
    agent_holder: list | None = None,
) -> AsyncGenerator[tuple[str, str], None]:
    """Stream ``(agent_name, token)`` pairs for *message*.

    The very first pair is ``(agent_name, "")`` so consumers (e.g.
    FloatingFormatter) learn which agent was selected before any text
    arrives. Every following pair carries one token from the agent's
    ``handle_stream``.

    If *agent_holder* is a list, the agent instance is appended to it before
    streaming starts, letting the caller inspect the agent (for example its
    ``tool_results``) once the stream is exhausted. *user_id* is accepted but
    not used by this function. When *reg* is ``None`` the module's default
    registry is used.
    """
    registry = reg if reg is not None else _default_registry
    chosen = await classify_intent(message, context, registry)
    agent = registry.get(chosen)

    if agent_holder is not None:
        agent_holder.append(agent)

    # Routing signal: agent name with an empty token.
    yield (chosen, "")

    async for piece in agent.handle_stream(message, context):
        yield (chosen, piece)
async def orchestrate_stream(
    request: ChatRequest,
    reg: AgentRegistry | None = None,
) -> AsyncGenerator[str, None]:
    """Yield the selected agent's reply as plain-text chunks.

    This is not token-level streaming: the agent is classified and called
    first, its complete reply is awaited, and only then is the text emitted
    in consecutive slices of at most 50 characters. An empty reply yields
    nothing.

    Framing is left to the consumer; per the original notes, the WebSocket
    handler in ``app/api/routes/chat.py`` wraps each chunk in a
    ``text_chunk`` frame and sends the ``final`` frame after the generator
    finishes. When *reg* is ``None`` the module's default registry is used.
    """
    registry = _default_registry if reg is None else reg
    ctx = request.context.model_dump()
    target = await classify_intent(request.message, ctx, registry)
    full_text = await registry.call_agent(target, request.message, ctx)

    width = 50
    offset = 0
    total = len(full_text)
    while offset < total:
        yield full_text[offset : offset + width]
        offset += width