feat: migrate chat orchestration to deep langgraph workers
This commit is contained in:
@@ -1,14 +1,13 @@
|
||||
"""Agent Registry — base classes and singleton registry for chat agents."""
|
||||
"""Minimal agent base types retained for compatibility with batch runners."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
"""Common base for all agents."""
|
||||
"""Common base for non-chat agents still using the old base contract."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -28,190 +27,4 @@ class BaseAgent(ABC):
|
||||
|
||||
@property
|
||||
def skills(self) -> list[str]:
|
||||
"""Override in subclasses to advertise capabilities."""
|
||||
return []
|
||||
|
||||
|
||||
class ChatAgent(BaseAgent):
    """Base class for LLM-powered chat agents.

    Subclasses implement :meth:`handle` and :meth:`get_tools`. The shared
    ``_tool_loop`` / ``_tool_loop_stream`` helpers drive the tool-calling
    cycle and record the collected tool results in ``self.tool_results``.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results.
        self.tool_results: list[dict] = []

    @abstractmethod
    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Process a user query and return a text response."""
        ...

    async def handle_stream(
        self, query: str, context: dict[str, Any]
    ) -> AsyncGenerator[str, None]:
        """Streaming variant of handle().

        Default: calls handle() and yields the full response as one chunk.
        Override in subclasses for true token-level streaming via _tool_loop_stream.
        """
        yield await self.handle(query, context)

    @abstractmethod
    def get_tools(self) -> list[Any]:
        """Return LangChain tool definitions available to this agent."""
        ...

    @staticmethod
    async def _execute_tool_calls(
        tool_calls: list[Any],
        tool_map: dict[str, Any],
        messages: list[Any],
    ) -> None:
        """Run each requested tool call and append its ``ToolMessage`` to *messages*.

        A call naming a tool that is not in *tool_map* is answered with an
        ``"Unknown tool: ..."`` message instead of raising.
        """
        from langchain_core.messages import ToolMessage

        for call in tool_calls:
            tool_fn = tool_map.get(call["name"])
            if tool_fn is None:
                result = f"Unknown tool: {call['name']}"
            else:
                result = await tool_fn.ainvoke(call["args"])
            messages.append(ToolMessage(content=str(result), tool_call_id=call["id"]))

    async def _tool_loop(
        self,
        llm: Any,
        messages: list[Any],
        tools: list[Any],
        max_iter: int = 5,
    ) -> str:
        """Shared tool-calling loop.

        Binds *tools* to *llm*, invokes iteratively until the model stops
        requesting tool calls or *max_iter* is reached, and returns the
        final text response. *messages* is extended in place. Captures raw
        execute_on_client results in ``self.tool_results``.
        """
        from langchain_core.messages import AIMessage

        from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector

        collector: list[dict] = []
        set_tool_result_collector(collector)
        try:
            llm_with_tools = llm.bind_tools(tools) if tools else llm
            # The name -> tool mapping does not change between iterations; build it once.
            tool_map = {t.name: t for t in tools}

            for _ in range(max_iter):
                response: AIMessage = await llm_with_tools.ainvoke(messages)
                messages.append(response)

                if not response.tool_calls:
                    return str(response.content)

                await self._execute_tool_calls(response.tool_calls, tool_map, messages)

            # Exhausted iterations — ask model for a final answer without tools
            response = await llm.ainvoke(messages)
            return str(response.content)
        finally:
            # Always detach the collector and publish whatever was gathered,
            # even when the loop exits through an exception.
            clear_tool_result_collector()
            self.tool_results = collector

    async def _tool_loop_stream(
        self,
        llm: Any,
        messages: list[Any],
        tools: list[Any],
        max_iter: int = 5,
    ) -> AsyncGenerator[str, None]:
        """Streaming variant of ``_tool_loop``.

        Behaves identically for tool-calling iterations (uses ainvoke to parse
        tool calls). For the final response — when the model produces no further
        tool calls — switches to ``llm.astream()`` and yields text tokens.
        Captures raw execute_on_client results in ``self.tool_results``.
        """
        from langchain_core.messages import AIMessage

        from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector

        collector: list[dict] = []
        set_tool_result_collector(collector)
        try:
            llm_with_tools = llm.bind_tools(tools) if tools else llm
            # The name -> tool mapping does not change between iterations; build it once.
            tool_map = {t.name: t for t in tools}

            for _ in range(max_iter):
                response: AIMessage = await llm_with_tools.ainvoke(messages)

                if not response.tool_calls:
                    # Stream the final answer — don't keep the ainvoke result.
                    async for chunk in llm.astream(messages):
                        if chunk.content:
                            yield str(chunk.content)
                    return

                messages.append(response)
                await self._execute_tool_calls(response.tool_calls, tool_map, messages)

            # Exhausted iterations — stream a final answer without tools
            async for chunk in llm.astream(messages):
                if chunk.content:
                    yield str(chunk.content)
        finally:
            clear_tool_result_collector()
            self.tool_results = collector
|
||||
|
||||
|
||||
class AgentRegistry:
    """Singleton registry for ChatAgent subclasses.

    Every ``AgentRegistry()`` call returns the same instance. Agents are
    stored by the name their ``get_name()`` reports, and a fresh agent
    instance is created on each lookup.
    """

    _instance: AgentRegistry | None = None

    def __new__(cls) -> AgentRegistry:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._agents = {}
        return cls._instance

    def __init__(self) -> None:
        # Python calls __init__ on every AgentRegistry() even though __new__
        # hands back the shared instance. Resetting the mapping here would
        # silently drop every registered agent, so only create it if missing.
        if not hasattr(self, "_agents"):
            self._agents: dict[str, type[ChatAgent]] = {}

    # ── public API ───────────────────────────────────────────────────

    def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
        """Class decorator — registers an agent by its name.

        The class is instantiated once to read ``get_name()``; the class
        itself is returned unchanged so it can be used as a decorator.
        """
        name = agent_class().get_name()
        self._agents[name] = agent_class
        return agent_class

    def get(self, name: str) -> ChatAgent:
        """Return a fresh instance of the named agent.

        Raises ``KeyError`` if no agent is registered under *name*.
        """
        cls = self._agents.get(name)
        if cls is None:
            raise KeyError(f"Agent not found: {name}")
        return cls()

    def list_agents(self) -> list[dict[str, str]]:
        """Return ``[{name, description}]`` for the orchestrator prompt."""
        result: list[dict[str, str]] = []
        for cls in self._agents.values():
            inst = cls()
            result.append(
                {"name": inst.get_name(), "description": inst.get_description()}
            )
        return result

    async def call_agent(
        self, name: str, query: str, context: dict[str, Any]
    ) -> str:
        """Instantiate the named agent and call its ``handle`` method."""
        agent = self.get(name)
        return await agent.handle(query, context)
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
registry = AgentRegistry()
|
||||
|
||||
576
app/core/deep_agent.py
Normal file
576
app/core/deep_agent.py
Normal file
@@ -0,0 +1,576 @@
|
||||
"""Deep orchestrator-worker graphs for home and floating chat contexts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import operator
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.types import Send
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agents.note_agent import NOTE_SYSTEM_PROMPT, NOTE_TOOLS
|
||||
from app.agents.project_agent import PROJECT_SYSTEM_PROMPT, PROJECT_TOOLS
|
||||
from app.agents.task_agent import TASK_SYSTEM_PROMPT, TASK_TOOLS
|
||||
from app.agents.timeline_agent import TIMELINE_SYSTEM_PROMPT, TIMELINE_TOOLS
|
||||
from app.core.llm import get_llm
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
||||
from app.db import async_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WorkerName = Literal["task_agent", "project_agent", "note_agent", "timeline_agent"]
|
||||
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]
|
||||
|
||||
|
||||
class WorkerTask(BaseModel):
    """A single planned unit of work: one worker plus the instruction it receives."""

    # Name of the worker that should handle the instruction.
    worker: WorkerName
    # Text passed to the worker as its instruction.
    instruction: str
|
||||
|
||||
|
||||
class WorkerPlan(BaseModel):
    """Planner output: the worker tasks to run and, optionally, a floating domain."""

    # Tasks to dispatch; defaults to an empty list.
    tasks: list[WorkerTask] = Field(default_factory=list)
    # Domain chosen for the floating context; None when not set.
    floating_domain: FloatingDomain | None = None
|
||||
|
||||
|
||||
class WorkerResult(TypedDict):
    """Record produced by one worker run."""

    # Worker that produced this result.
    worker: WorkerName
    # Instruction the worker was given.
    instruction: str
    # Final text returned by the worker.
    response: str
    # Entity ids grouped by tag, collected from the worker's tool results.
    entity_ids: dict[str, list[str]]
|
||||
|
||||
|
||||
class OrchestratorState(TypedDict, total=False):
    """Keys shared by the orchestrator graphs; every key is optional."""

    # Request identity and text.
    user_id: str
    user_message: str
    # Caller-supplied context and memory context.
    context: dict[str, Any]
    memory_context: dict[str, Any]
    # Planned worker tasks (dumped WorkerTask models) and the floating domain.
    plan: list[dict[str, Any]]
    floating_domain: FloatingDomain
    # Task payload delivered to an individual worker node.
    task: dict[str, Any]
    # Results gathered from workers and the synthesized answer.
    worker_results: list[WorkerResult]
    final_response: str
    # Optional async callback that receives streamed tokens.
    stream_callback: Callable[[str], Awaitable[None]] | None
|
||||
|
||||
|
||||
class GraphState(OrchestratorState):
    """State schema passed to ``StateGraph`` for the orchestrator graphs."""

    # Several worker nodes can run in the same step and each returns its own
    # ``worker_results`` list. Without a reducer LangGraph accepts only one
    # write per key per step, so the lists are concatenated with operator.add.
    worker_results: Annotated[list[WorkerResult], operator.add]
|
||||
|
||||
|
||||
class ReducerState(OrchestratorState):
    """Orchestrator state variant that re-declares ``worker_results``."""

    # NOTE(review): no use of this class is visible in this chunk — confirm it is still needed.
    worker_results: list[WorkerResult]
|
||||
|
||||
|
||||
class AggregatedState(TypedDict, total=False):
    """Partial state update returned by worker nodes."""

    # Results contributed by a single worker run (may be empty).
    worker_results: list[WorkerResult]
|
||||
|
||||
|
||||
WORKER_CONFIG: dict[WorkerName, dict[str, Any]] = {
|
||||
"task_agent": {
|
||||
"prompt": TASK_SYSTEM_PROMPT,
|
||||
"tools": TASK_TOOLS,
|
||||
"tag": "task",
|
||||
"table": "tasks",
|
||||
"floating_domain": "tasks",
|
||||
},
|
||||
"project_agent": {
|
||||
"prompt": PROJECT_SYSTEM_PROMPT,
|
||||
"tools": PROJECT_TOOLS,
|
||||
"tag": "project",
|
||||
"table": "projects",
|
||||
"floating_domain": "projects",
|
||||
},
|
||||
"note_agent": {
|
||||
"prompt": NOTE_SYSTEM_PROMPT,
|
||||
"tools": NOTE_TOOLS,
|
||||
"tag": "note",
|
||||
"table": "notes",
|
||||
"floating_domain": "notes",
|
||||
},
|
||||
"timeline_agent": {
|
||||
"prompt": TIMELINE_SYSTEM_PROMPT,
|
||||
"tools": TIMELINE_TOOLS,
|
||||
"tag": "timeline",
|
||||
"table": "timelines",
|
||||
"floating_domain": "timelines",
|
||||
},
|
||||
}
|
||||
|
||||
_HOME_ORCHESTRATOR_SYSTEM = (
|
||||
"You are an orchestrator. Plan which workers should be invoked for the user request. "
|
||||
"Workers: task_agent, project_agent, note_agent, timeline_agent. "
|
||||
"Return only the workers needed."
|
||||
)
|
||||
|
||||
_FLOATING_ORCHESTRATOR_SYSTEM = (
|
||||
"You are an orchestrator for floating context. Pick focused workers and set floating_domain "
|
||||
"as one of: tasks, projects, notes, timelines."
|
||||
)
|
||||
|
||||
_HOME_SYNTH_SYSTEM = (
|
||||
"You are the final response synthesizer. Return markdown only. "
|
||||
"Embed inline component tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
||||
"<note>[ids]</note>, <timeline>[ids]</timeline>, and <chart>{json}</chart>. "
|
||||
"Only include IDs that are truly relevant to the request."
|
||||
)
|
||||
|
||||
_FLOATING_SYNTH_SYSTEM = (
|
||||
"You are the final response synthesizer for floating UI context. "
|
||||
"Return concise markdown and stay focused on the requested scope."
|
||||
)
|
||||
|
||||
|
||||
def _as_text(content: Any) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
text = item.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
return "".join(parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _fallback_plan(message: str, floating: bool) -> WorkerPlan:
    """Build a plan from keyword matching.

    Each worker whose keywords occur as a substring of the lower-cased
    *message* gets a task carrying the full message. With no match,
    ``task_agent`` is used. When *floating* is true the domain of the first
    selected worker is attached to the plan.
    """
    keyword_routes: tuple[tuple[WorkerName, tuple[str, ...]], ...] = (
        ("task_agent", ("task", "todo", "deadline", "due")),
        ("project_agent", ("project", "client", "milestone")),
        ("note_agent", ("note", "document", "memo")),
        ("timeline_agent", ("timeline", "event", "schedule", "release")),
    )

    haystack = message.lower()
    selected = [
        WorkerTask(worker=worker_name, instruction=message)
        for worker_name, keywords in keyword_routes
        if any(keyword in haystack for keyword in keywords)
    ]
    if not selected:
        selected = [WorkerTask(worker="task_agent", instruction=message)]

    domain: FloatingDomain | None = (
        WORKER_CONFIG[selected[0].worker]["floating_domain"] if floating else None
    )
    return WorkerPlan(tasks=selected, floating_domain=domain)
|
||||
|
||||
|
||||
async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan:
    """Ask the LLM for a structured ``WorkerPlan``.

    Falls back to the keyword-based plan when the structured call raises,
    returns something other than a ``WorkerPlan``, or returns no tasks.
    """
    llm = get_llm()
    system_text = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM
    request_json = json.dumps(
        {
            "message": message,
            "context": context,
            "workers": list(WORKER_CONFIG.keys()),
        },
        ensure_ascii=True,
    )
    conversation = [
        SystemMessage(content=system_text),
        HumanMessage(content=request_json),
    ]

    try:
        planner = llm.with_structured_output(WorkerPlan)
        candidate = await planner.ainvoke(conversation)
        # Only a populated WorkerPlan is accepted; anything else falls through.
        if isinstance(candidate, WorkerPlan) and candidate.tasks:
            return candidate
    except Exception as exc:
        logger.warning("deep_agent: structured planner failed, using fallback: %s", exc)

    return _fallback_plan(message, floating)
|
||||
|
||||
|
||||
def _extract_entity_ids(tool_results: list[dict[str, Any]]) -> dict[str, list[str]]:
|
||||
out: dict[str, list[str]] = {
|
||||
"task": [],
|
||||
"project": [],
|
||||
"note": [],
|
||||
"timeline": [],
|
||||
}
|
||||
table_to_tag = {
|
||||
"tasks": "task",
|
||||
"projects": "project",
|
||||
"notes": "note",
|
||||
"timelines": "timeline",
|
||||
}
|
||||
|
||||
for item in tool_results:
|
||||
table = item.get("table")
|
||||
tag = table_to_tag.get(table)
|
||||
if tag is None:
|
||||
continue
|
||||
|
||||
payload = item.get("data") or {}
|
||||
rows: list[dict[str, Any]] = []
|
||||
row = payload.get("row")
|
||||
if isinstance(row, dict):
|
||||
rows.append(row)
|
||||
if isinstance(payload.get("rows"), list):
|
||||
rows.extend([r for r in payload["rows"] if isinstance(r, dict)])
|
||||
if isinstance(payload.get("results"), list):
|
||||
rows.extend([r for r in payload["results"] if isinstance(r, dict)])
|
||||
|
||||
for r in rows:
|
||||
entity_id = r.get("id")
|
||||
if isinstance(entity_id, str) and entity_id not in out[tag]:
|
||||
out[tag].append(entity_id)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
async def _run_tool_loop(
    worker: WorkerName,
    instruction: str,
    context: dict[str, Any],
) -> tuple[str, list[dict[str, Any]]]:
    """Run one worker's tool-calling loop.

    The worker's prompt and tools come from ``WORKER_CONFIG``. The model is
    invoked up to six times with tools bound; once it stops requesting tools
    its text is returned. If the budget is exhausted, one final call without
    tools produces the answer.

    Returns:
        ``(response_text, collected)`` where ``collected`` is the list
        registered through ``set_tool_result_collector`` for this run.
    """
    worker_prompt = WORKER_CONFIG[worker]["prompt"]
    tools = WORKER_CONFIG[worker]["tools"]

    llm = get_llm()
    llm_with_tools = llm.bind_tools(tools) if tools else llm
    # The name -> tool mapping never changes during the loop; build it once.
    tool_map = {t.name: t for t in (tools or [])}

    messages: list[Any] = [
        SystemMessage(content=worker_prompt),
        HumanMessage(
            content=(
                "Worker instruction:\n"
                f"{instruction}\n\n"
                "Conversation context:\n"
                # Serialized context is capped at 2000 characters.
                f"{json.dumps(context, ensure_ascii=True)[:2000]}"
            )
        ),
    ]

    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        for _ in range(6):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            messages.append(response)

            if not response.tool_calls:
                return _as_text(response.content), collected

            for call in response.tool_calls:
                tool_fn = tool_map.get(call["name"])
                if tool_fn is None:
                    tool_output = f"Unknown tool: {call['name']}"
                else:
                    tool_output = await tool_fn.ainvoke(call.get("args", {}))
                messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))

        # Iteration budget exhausted: ask for a final answer without tools.
        final = await llm.ainvoke(messages)
        return _as_text(final.content), collected
    finally:
        clear_tool_result_collector()
|
||||
|
||||
|
||||
def _worker_node(worker: WorkerName):
    """Create the graph node coroutine bound to *worker*."""

    async def _node(state: GraphState) -> AggregatedState:
        assignment = state.get("task") or {}
        # Payloads addressed to a different worker contribute nothing.
        if assignment.get("worker") != worker:
            return {"worker_results": []}

        # Prefer the task's own instruction, then the user message, then "".
        instruction = str(assignment.get("instruction") or state.get("user_message") or "")
        response, tool_results = await _run_tool_loop(
            worker,
            instruction,
            {
                "memory": state.get("memory_context", {}),
                "context": state.get("context", {}),
            },
        )

        outcome = {
            "worker": worker,
            "instruction": instruction,
            "response": response,
            "entity_ids": _extract_entity_ids(tool_results),
        }
        return {"worker_results": [outcome]}

    return _node
|
||||
|
||||
|
||||
def _build_synthesis_prompt(state: GraphState, floating: bool) -> str:
|
||||
worker_results = state.get("worker_results", [])
|
||||
formatted_results = []
|
||||
for result in worker_results:
|
||||
formatted_results.append(
|
||||
{
|
||||
"worker": result.get("worker"),
|
||||
"instruction": result.get("instruction"),
|
||||
"response": result.get("response"),
|
||||
"entity_ids": result.get("entity_ids", {}),
|
||||
}
|
||||
)
|
||||
|
||||
payload = {
|
||||
"user_message": state.get("user_message", ""),
|
||||
"memory_context": state.get("memory_context", {}),
|
||||
"worker_results": formatted_results,
|
||||
"floating_domain": state.get("floating_domain") if floating else None,
|
||||
}
|
||||
return json.dumps(payload, ensure_ascii=True)
|
||||
|
||||
|
||||
async def _stream_with_memory_tool(
    *,
    user_id: str,
    system_prompt: str,
    user_prompt: str,
    stream_callback: Callable[[str], Awaitable[None]] | None,
) -> str:
    """Produce the final answer, letting the model save core memory first.

    The model gets up to two tool-enabled turns in which it may call
    ``update_core_memory``. The answer is then streamed from the plain model;
    each non-empty token is passed to *stream_callback* (when given) and the
    concatenated text is returned.
    """

    @tool
    async def update_core_memory(key: str, value: str) -> str:
        """Save stable user preference/profile data to core memory."""
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            await memory.update_core(user_id, key, value)
        return f"Saved core memory key '{key}'."

    llm = get_llm()
    messages: list[Any] = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_prompt),
    ]
    llm_with_tools = llm.bind_tools([update_core_memory])

    for _ in range(2):
        response: AIMessage = await llm_with_tools.ainvoke(messages)
        messages.append(response)
        if not response.tool_calls:
            break

        for call in response.tool_calls:
            if call["name"] == "update_core_memory":
                reply = str(await update_core_memory.ainvoke(call.get("args", {})))
            else:
                # Only the memory tool is bound; reject anything else.
                reply = "Unsupported tool."
            messages.append(ToolMessage(content=reply, tool_call_id=call["id"]))

    pieces: list[str] = []
    async for chunk in llm.astream(messages):
        token = _as_text(getattr(chunk, "content", ""))
        if token:
            pieces.append(token)
            if stream_callback is not None:
                await stream_callback(token)

    return "".join(pieces)
|
||||
|
||||
|
||||
def _synthesizer_node(floating: bool):
    """Create the synthesizer node for the home or floating graph."""
    system_prompt = _FLOATING_SYNTH_SYSTEM if floating else _HOME_SYNTH_SYSTEM

    async def _node(state: GraphState) -> GraphState:
        # Build the JSON prompt from the accumulated state, then stream the answer.
        answer = await _stream_with_memory_tool(
            user_id=str(state.get("user_id", "")),
            system_prompt=system_prompt,
            user_prompt=_build_synthesis_prompt(state, floating=floating),
            stream_callback=state.get("stream_callback"),
        )
        return {"final_response": answer}

    return _node
|
||||
|
||||
|
||||
async def _orchestrator_node_home(state: GraphState) -> GraphState:
    """Plan the worker tasks for the home graph unless a plan already exists."""
    if state.get("plan"):
        # A plan was supplied up front; leave the state untouched.
        return {}

    # Memory context entries override same-named context entries.
    merged_context = {**state.get("context", {}), **state.get("memory_context", {})}
    user_message = str(state.get("user_message", ""))
    plan = await _plan_with_llm(user_message, merged_context, floating=False)
    return {"plan": [item.model_dump() for item in plan.tasks]}
|
||||
|
||||
|
||||
async def _orchestrator_node_floating(state: GraphState) -> GraphState:
    """Plan the worker tasks and floating domain unless a plan already exists."""
    if state.get("plan"):
        # A plan was supplied up front; leave the state untouched.
        return {}

    # Memory context entries override same-named context entries.
    merged_context = {**state.get("context", {}), **state.get("memory_context", {})}
    user_message = str(state.get("user_message", ""))
    plan = await _plan_with_llm(user_message, merged_context, floating=True)

    domain = plan.floating_domain
    if domain is None and plan.tasks:
        # Derive the domain from the first planned worker.
        domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]

    return {
        "plan": [item.model_dump() for item in plan.tasks],
        "floating_domain": domain or "tasks",
    }
|
||||
|
||||
|
||||
def _route_workers(state: GraphState) -> list[Send] | str:
    """Fan the plan out to worker nodes, or go straight to the synthesizer.

    Returns one ``Send`` per planned task whose worker is in ``WORKER_CONFIG``.
    With no plan, or no recognized worker, returns ``"synthesizer"``.
    """
    plan = state.get("plan") or []
    if not plan:
        return "synthesizer"

    # A node reached via Send receives only the Send payload as its input, not
    # the graph state. Forward the fields the worker node reads so workers
    # still see the user message, the context and the memory context.
    shared = {
        key: state[key]
        for key in ("user_message", "context", "memory_context")
        if key in state
    }

    sends: list[Send] = []
    for task in plan:
        worker = task.get("worker")
        if worker in WORKER_CONFIG:
            sends.append(Send(worker, {**shared, "task": task}))

    return sends or "synthesizer"
|
||||
|
||||
|
||||
def _build_graph(*, floating: bool):
    """Assemble and compile the orchestrator → workers → synthesizer graph."""
    graph = StateGraph(GraphState)
    worker_names = list(WORKER_CONFIG)

    # Nodes: one orchestrator, one node per configured worker, one synthesizer.
    graph.add_node(
        "orchestrator",
        _orchestrator_node_floating if floating else _orchestrator_node_home,
    )
    for name in worker_names:
        graph.add_node(name, _worker_node(name))
    graph.add_node("synthesizer", _synthesizer_node(floating=floating))

    # Edges: the orchestrator routes to workers (or directly to the
    # synthesizer), and every worker flows into the synthesizer.
    graph.add_edge(START, "orchestrator")
    graph.add_conditional_edges(
        "orchestrator",
        _route_workers,
        ["task_agent", "project_agent", "note_agent", "timeline_agent", "synthesizer"],
    )
    for name in worker_names:
        graph.add_edge(name, "synthesizer")
    graph.add_edge("synthesizer", END)

    return graph.compile()
|
||||
|
||||
|
||||
HOME_GRAPH = _build_graph(floating=False)
|
||||
FLOATING_GRAPH = _build_graph(floating=True)
|
||||
|
||||
|
||||
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
    """Run the home graph to completion and return the final response text."""
    initial_state = {
        "user_id": user_id,
        "user_message": message,
        "context": context,
        # The same dict is supplied as both context and memory context.
        "memory_context": context,
        "worker_results": [],
        "stream_callback": None,
    }
    final_state = await HOME_GRAPH.ainvoke(initial_state)
    return str(final_state.get("final_response", ""))
|
||||
|
||||
|
||||
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
    """Run the floating graph and return ``(final_response, floating_domain)``.

    Planning happens here, before the graph runs, and the plan is passed in
    through the initial state.
    """
    plan = await _plan_with_llm(message, context, floating=True)
    # Fall back to the first planned worker's domain when none was chosen.
    domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]

    initial_state = {
        "user_id": user_id,
        "user_message": message,
        "context": context,
        "memory_context": context,
        "plan": [item.model_dump() for item in plan.tasks],
        "floating_domain": domain,
        "worker_results": [],
        "stream_callback": None,
    }
    final_state = await FLOATING_GRAPH.ainvoke(initial_state)
    return str(final_state.get("final_response", "")), str(domain)
|
||||
|
||||
|
||||
async def run_home_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Run the home graph in the background and yield ``("token", text)`` events.

    Tokens pushed by the synthesizer's stream callback are relayed as they
    arrive. If nothing was streamed, the complete final response is yielded
    as a single token. An exception raised by the graph propagates to the
    consumer once the queue has drained.
    """
    queue: asyncio.Queue[str] = asyncio.Queue()

    async def _on_token(token: str) -> None:
        await queue.put(token)

    task = asyncio.create_task(
        HOME_GRAPH.ainvoke(
            {
                "user_id": user_id,
                "user_message": message,
                "context": context,
                "memory_context": context,
                "worker_results": [],
                "stream_callback": _on_token,
            }
        )
    )

    emitted = False
    try:
        # Poll with a short timeout so completion of the task is noticed even
        # when no more tokens arrive.
        while not task.done() or not queue.empty():
            try:
                token = await asyncio.wait_for(queue.get(), timeout=0.15)
            except asyncio.TimeoutError:
                continue
            emitted = True
            yield "token", token

        final_state = await task
        if not emitted and final_state.get("final_response"):
            yield "token", str(final_state["final_response"])
    finally:
        # If the consumer stops early (generator closed or cancelled), do not
        # leave the graph running in the background.
        if not task.done():
            task.cancel()
|
||||
|
||||
|
||||
async def run_floating_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Run the floating graph in the background and yield stream events.

    The first event is ``("floating_domain", domain)``, followed by
    ``("token", text)`` events relayed from the synthesizer's stream
    callback. If nothing was streamed, the complete final response is yielded
    as a single token. An exception raised by the graph propagates to the
    consumer once the queue has drained.
    """
    plan = await _plan_with_llm(message, context, floating=True)
    # Fall back to the first planned worker's domain when none was chosen.
    domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
    yield "floating_domain", domain

    queue: asyncio.Queue[str] = asyncio.Queue()

    async def _on_token(token: str) -> None:
        await queue.put(token)

    task = asyncio.create_task(
        FLOATING_GRAPH.ainvoke(
            {
                "user_id": user_id,
                "user_message": message,
                "context": context,
                "memory_context": context,
                "plan": [t.model_dump() for t in plan.tasks],
                "floating_domain": domain,
                "worker_results": [],
                "stream_callback": _on_token,
            }
        )
    )

    emitted = False
    try:
        # Poll with a short timeout so completion of the task is noticed even
        # when no more tokens arrive.
        while not task.done() or not queue.empty():
            try:
                token = await asyncio.wait_for(queue.get(), timeout=0.15)
            except asyncio.TimeoutError:
                continue
            emitted = True
            yield "token", token

        final_state = await task
        if not emitted and final_state.get("final_response"):
            yield "token", str(final_state["final_response"])
    finally:
        # If the consumer stops early (generator closed or cancelled), do not
        # leave the graph running in the background.
        if not task.done():
            task.cancel()
|
||||
@@ -1,222 +0,0 @@
|
||||
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from app.schemas import ExecutionPlan, PlanStep
|
||||
|
||||
|
||||
# ── Prompt Template Registry ──────────────────────────────────────────
|
||||
|
||||
|
||||
class PromptTemplateRegistry:
    """Server-side store mapping template IDs to prompt text.

    Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``).
    The actual prompt text is resolved here on the server, keeping prompt IP
    out of API responses.
    """

    def __init__(self) -> None:
        self._templates: dict[str, str] = {}

    def register(self, template_id: str, prompt_text: str) -> None:
        """Store *prompt_text* under *template_id*, replacing any earlier entry."""
        self._templates[template_id] = prompt_text

    def get(self, template_id: str) -> str:
        """Resolve a template ID to its prompt text.

        Raises ``KeyError`` if the template is not registered.
        """
        prompt = self._templates.get(template_id)
        if prompt is not None:
            return prompt
        raise KeyError(f"Template not found: {template_id!r}")

    def has(self, template_id: str) -> bool:
        """Report whether *template_id* is registered."""
        return template_id in self._templates

    def list_ids(self) -> list[str]:
        """Return all registered template IDs (never the text)."""
        return list(self._templates)
|
||||
|
||||
|
||||
# ── Execution Plan Builder ────────────────────────────────────────────
|
||||
|
||||
|
||||
class ExecutionPlanBuilder:
    """Fluent builder for ``ExecutionPlan`` objects.

    Example::

        plan = (
            ExecutionPlanBuilder("task_agent")
            .add_llm_step("tpl_task_agent_default", {"message": user_msg})
            .add_data_step("create_record", data_from_step=0)
            .build()
        )
    """

    def __init__(self, agent: str) -> None:
        self._agent = agent
        self._steps: list[PlanStep] = []

    # ── step adders ──────────────────────────────────────────────────

    def add_step(
        self, action: str, params: dict[str, Any] | None = None
    ) -> ExecutionPlanBuilder:
        """Append a generic action step with optional parameters."""
        step = PlanStep(action=action, variables=params)
        self._steps.append(step)
        return self

    def add_llm_step(
        self, template_id: str, variables: dict[str, Any] | None = None
    ) -> ExecutionPlanBuilder:
        """Append an LLM step referencing a server-side template by ID."""
        step = PlanStep(action="llm", prompt_template=template_id, variables=variables)
        self._steps.append(step)
        return self

    def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
        """Append a step whose input comes from the output of an earlier step."""
        step = PlanStep(action=action, data_from_step=data_from_step)
        self._steps.append(step)
        return self

    # ── build ────────────────────────────────────────────────────────

    def build(self) -> ExecutionPlan:
        """Validate step references and return the ``ExecutionPlan``.

        Raises ``ValueError`` if any ``data_from_step`` references a
        non-existent or future step index.
        """
        for index, step in enumerate(self._steps):
            source = step.data_from_step
            # A step may only consume the output of a strictly earlier step.
            if source is None or 0 <= source < index:
                continue
            raise ValueError(
                f"Step {index}: data_from_step={source} must "
                f"reference a preceding step index in range 0..{index - 1}"
            )
        return ExecutionPlan(agent=self._agent, steps=list(self._steps))
|
||||
|
||||
|
||||
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class PlanCache:
    """In-memory LRU cache for ``ExecutionPlan`` objects.

    Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
    The cache also serves as a runtime memoisation layer so that repeated
    identical intent classifications can skip re-building the plan.

    A ``maxsize`` of zero or less disables caching: nothing is stored.
    """

    def __init__(self, maxsize: int = 1000) -> None:
        self._maxsize = maxsize
        self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()

    def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
        """Store *plan* under *key*, evicting the LRU entry if at capacity."""
        if self._maxsize <= 0:
            # With no capacity there is nothing to evict and nowhere to store;
            # evicting from the empty cache would raise KeyError.
            return
        if key in self._cache:
            # Refresh recency so the updated entry sits at the end.
            self._cache.move_to_end(key)
        elif len(self._cache) >= self._maxsize:
            self._cache.popitem(last=False)  # evict least-recently-used
        self._cache[key] = plan

    def get_plan(self, key: str) -> ExecutionPlan | None:
        """Return the cached plan for *key*, or ``None`` if not present.

        Accessing a plan marks it as most-recently used.
        """
        if key not in self._cache:
            return None
        self._cache.move_to_end(key)
        return self._cache[key]

    def get_all_playbooks(self) -> list[ExecutionPlan]:
        """Return all cached plans (most-recently used last)."""
        return list(self._cache.values())
|
||||
|
||||
|
||||
# ── Module-level singletons ───────────────────────────────────────────
|
||||
|
||||
# Shared registry of server-side prompt templates; filled by
# _register_builtin_templates() when this module is imported.
template_registry = PromptTemplateRegistry()

# Shared plan cache using the default capacity (1000 entries); seeded by
# _load_playbooks() when this module is imported.
plan_cache = PlanCache()
|
||||
|
||||
|
||||
def _register_builtin_templates() -> None:
    """Populate ``template_registry`` with the built-in prompt templates.

    The prompt text is registered server-side under its template ID; per the
    original note, clients are only ever handed the IDs.
    """
    builtin_templates: tuple[tuple[str, str], ...] = (
        (
            "tpl_task_agent_default",
            "You are a task management assistant. Help the user create, update, "
            "list, and track tasks. Use correct status values (todo, in_progress, "
            "done) and priority values (high, medium, low) from the workspace model.",
        ),
        (
            "tpl_timeline_agent_default",
            "You are a project timeline assistant. Help the user create and manage "
            "milestone timelines on their projects. Every timeline requires a "
            "project_id and a date expressed as a Unix timestamp in milliseconds.",
        ),
        (
            "tpl_project_agent_default",
            "You are a project management assistant. Help the user create, find, "
            "update, and archive projects. Projects have a name, an optional client, "
            "and a status of either active or archived.",
        ),
        (
            "tpl_note_agent_default",
            "You are a note-taking assistant. Help the user create, retrieve, update, "
            "and delete Markdown notes. Notes can optionally be linked to a project.",
        ),
        (
            "tpl_task_extract_from_project",
            "Extract all actionable tasks from the provided project context. "
            "Return a structured list of tasks, each with a title, inferred priority "
            "(high, medium, or low), suggested status (todo), and a due_date in "
            "milliseconds where a deadline can be inferred.",
        ),
        (
            "tpl_note_weekly_summary",
            "Generate a weekly project summary note from the provided workspace data. "
            "Include: tasks completed this week, tasks due soon, active projects, "
            "and upcoming timelines. Format the output as clean Markdown.",
        ),
    )
    # Register in declaration order, one ID at a time.
    for template_id, prompt in builtin_templates:
        template_registry.register(template_id, prompt)
|
||||
|
||||
|
||||
def _load_playbooks() -> None:
    """Build the built-in playbooks and store them in ``plan_cache``.

    Both plans are fully built before anything is cached, so a failure while
    building leaves the cache untouched.
    """
    # Each playbook is an LLM step whose output (step 0) feeds a
    # ``create_record`` data step.
    builtin: dict[str, ExecutionPlan] = {
        "create_tasks_from_project": (
            ExecutionPlanBuilder("project_agent")
            .add_llm_step(
                "tpl_task_extract_from_project",
                {"source": "project_context"},
            )
            .add_data_step("create_record", data_from_step=0)
            .build()
        ),
        "generate_weekly_note": (
            ExecutionPlanBuilder("note_agent")
            .add_llm_step(
                "tpl_note_weekly_summary",
                {"period": "last_7_days"},
            )
            .add_data_step("create_record", data_from_step=0)
            .build()
        ),
    }
    for playbook_key, playbook in builtin.items():
        plan_cache.cache_plan(playbook_key, playbook)
|
||||
|
||||
|
||||
# Initialise on module load
# Importing this module has side effects: the shared template registry is
# filled first, then the built-in playbooks are built and cached.
_register_builtin_templates()
_load_playbooks()
|
||||
@@ -1,210 +0,0 @@
|
||||
"""Orchestrator — LLM-based intent router and agent pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
||||
from app.core.llm import get_router_llm
|
||||
from app.core.agent_registry import registry as _default_registry
|
||||
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
||||
|
||||
# Agent name returned by classify_intent() when the registry lists no agents
# or the model's answer is not a registered agent name.
_FALLBACK_AGENT = "task_agent"

# System prompt for classify_intent(); ``{agents}`` is replaced with the
# JSON-encoded result of ``reg.list_agents()``.
_CLASSIFY_SYSTEM = (
    "You are an intent classifier. Given the user message and context, decide "
    "which agent to route to.\n"
    "Available agents: {agents}\n"
    "Respond with just the agent name, nothing else."
)

# Human prompt for route_pipeline()'s final synthesis call; ``{results}`` is
# the per-agent output listing and ``{message}`` the original user message.
_SYNTHESIZE_HUMAN = (
    "Combine the following agent results into one coherent response.\n\n"
    "Agent results:\n{results}\n\n"
    "Original message: {message}"
)
|
||||
|
||||
|
||||
def _make_llm():
    """Return the LLM client used for routing and synthesis calls.

    Kept as a thin indirection over ``get_router_llm`` so every call site in
    this module obtains its client the same way.
    """
    router_llm = get_router_llm()
    return router_llm
|
||||
|
||||
|
||||
async def classify_intent(
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry,
) -> str:
    """Ask the router LLM which registered agent should handle *message*.

    Args:
        message: The raw user message.
        context: Request context; a JSON rendering truncated to 500
            characters is included in the prompt.
        reg: Registry whose ``list_agents()`` entries (dicts with a
            ``"name"`` key) define the valid answers.

    Returns:
        The lower-cased agent name chosen by the model, or ``task_agent``
        when the registry is empty or the model returns a name that is not
        registered.
    """
    agents = reg.list_agents()
    if not agents:
        return _FALLBACK_AGENT

    system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
    # Truncate context to keep the classification prompt short.
    # ``default=str``: callers pass e.g. a ``model_dump()`` result, which may
    # hold values json cannot encode natively; a lossy string form is fine
    # for a prompt summary and avoids a TypeError before the LLM is called.
    context_summary = json.dumps(context, default=str)[:500]
    human = f"Message: {message}\nContext summary: {context_summary}"

    llm = _make_llm()
    response = await llm.ainvoke(
        [SystemMessage(content=system), HumanMessage(content=human)]
    )

    known = {a["name"] for a in agents}
    answer = str(response.content).strip().lower()
    if answer in known:
        return answer

    # Models sometimes wrap the name in quotes/backticks or add a trailing
    # period; accept those rather than silently routing to the fallback.
    cleaned = answer.strip("`'\". \t\r\n")
    return cleaned if cleaned in known else _FALLBACK_AGENT
|
||||
|
||||
|
||||
async def route_single(
    agent_name: str,
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry,
) -> ChatResponse:
    """Call one agent through the registry and return its reply as a ``ChatResponse``."""
    return ChatResponse(
        response=await reg.call_agent(agent_name, message, context)
    )
|
||||
|
||||
|
||||
async def route_pipeline(
    agent_names: list[str],
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry,
) -> ChatResponse:
    """Run the named agents one after another, then merge their outputs.

    Every agent is called with the same *message* and a copy of *context*
    extended with ``previous_results`` (the outputs gathered so far). Once
    all agents have run, one LLM call synthesises the labelled outputs into
    the final response text.
    """
    outputs: list[str] = []
    for current in agent_names:
        # Each agent gets its own snapshot so later appends do not leak back.
        step_context = dict(context)
        step_context["previous_results"] = outputs.copy()
        outputs.append(await reg.call_agent(current, message, step_context))

    labelled = map("[{}]: {}".format, agent_names, outputs)
    prompt = _SYNTHESIZE_HUMAN.format(
        results="\n\n".join(labelled), message=message
    )
    merged = await _make_llm().ainvoke([HumanMessage(content=prompt)])
    return ChatResponse(response=str(merged.content))
|
||||
|
||||
|
||||
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
    """Assemble a one-step ``ExecutionPlan`` for *agent_name*.

    When the template registry holds ``tpl_<agent_name>_default`` the step is
    an LLM step referencing that template ID; otherwise it is a plain
    ``handle`` step. Either way the step payload is ``{"message": message}``.
    """
    # Imported here rather than at module level, as in the original.
    from app.core.execution_plan import ExecutionPlanBuilder, template_registry

    default_template = f"tpl_{agent_name}_default"
    plan_builder = ExecutionPlanBuilder(agent_name)
    if not template_registry.has(default_template):
        plan_builder.add_step("handle", {"message": message})
    else:
        plan_builder.add_llm_step(default_template, {"message": message})
    return plan_builder.build()
|
||||
|
||||
|
||||
async def orchestrate(
    request: ChatRequest,
    reg: AgentRegistry | None = None,
) -> ChatResponse | ExecutionPlan:
    """Classify the request and either execute it or describe how to.

    The user's intent is classified first to pick an agent. With
    ``execution_mode == 'direct'`` the agent is called and a ``ChatResponse``
    is returned. Any other mode is treated as plan mode: an
    ``ExecutionPlan`` for the resolved agent is returned and nothing is
    executed.

    *reg* defaults to the module's default registry.
    """
    active_registry = _default_registry if reg is None else reg

    ctx = request.context.model_dump()
    chosen_agent = await classify_intent(request.message, ctx, active_registry)

    if request.execution_mode != "direct":
        # plan mode — hand back the plan without running it
        return _build_plan(chosen_agent, request.message)

    return await route_single(chosen_agent, request.message, ctx, active_registry)
|
||||
|
||||
|
||||
async def orchestrate_v3(
    user_id: str,
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry | None = None,
) -> tuple[str, ChatAgent]:
    """v3 orchestration — resolve the agent but leave execution to the caller.

    Classifies intent, looks the agent up in the registry and returns
    ``(agent_name, agent_instance)``. *user_id* is accepted but not used
    here. *reg* defaults to the module's default registry.
    """
    active_registry = _default_registry if reg is None else reg
    chosen_agent = await classify_intent(message, context, active_registry)
    return chosen_agent, active_registry.get(chosen_agent)
|
||||
|
||||
|
||||
async def orchestrate_v3_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry | None = None,
    agent_holder: list | None = None,
) -> AsyncGenerator[tuple[str, str], None]:
    """v3 streaming orchestration — yields ``(agent_name, token)`` pairs.

    The very first pair carries an empty token so consumers learn which
    agent was selected before any text arrives; every later pair forwards one
    item from the agent's ``handle_stream()``.

    When *agent_holder* is a list, the resolved agent instance is appended to
    it before streaming starts so the caller can inspect the agent afterwards.
    *user_id* is accepted but not used here.
    """
    active_registry = _default_registry if reg is None else reg

    routed_to = await classify_intent(message, context, active_registry)
    worker = active_registry.get(routed_to)
    if agent_holder is not None:
        agent_holder.append(worker)

    # domain signal — agent name only, no token yet
    yield routed_to, ""

    async for piece in worker.handle_stream(message, context):
        yield routed_to, piece
|
||||
|
||||
|
||||
async def orchestrate_stream(
    request: ChatRequest,
    reg: AgentRegistry | None = None,
) -> AsyncGenerator[str, None]:
    """Streaming orchestration — yields plain text chunks only.

    This is not token-level streaming: the selected agent's full response is
    awaited first and then emitted in consecutive 50-character slices (the
    last slice may be shorter; an empty response yields nothing).

    Per the original note, framing the chunks for the WebSocket (``text_chunk``
    and the closing ``final`` frame) is left to the handler consuming this
    generator.
    """
    active_registry = _default_registry if reg is None else reg

    ctx = request.context.model_dump()
    chosen_agent = await classify_intent(request.message, ctx, active_registry)
    full_text = await active_registry.call_agent(chosen_agent, request.message, ctx)

    chunk_size = 50
    offset = 0
    while offset < len(full_text):
        yield full_text[offset : offset + chunk_size]
        offset += chunk_size
|
||||
@@ -1,244 +1,43 @@
|
||||
"""Output Formatter — transforms orchestrator token streams into WS frame sequences.
|
||||
|
||||
HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
|
||||
FloatingFormatter: produces floating_domain, stream_text, stream_end
|
||||
"""
|
||||
"""Output formatter for deep-agent stream events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from app.schemas import (
|
||||
WsFloatingDomain,
|
||||
WsStreamBlock,
|
||||
WsStreamEnd,
|
||||
WsStreamStart,
|
||||
WsStreamText,
|
||||
)
|
||||
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Valid chart types (matching shadcn/ui Recharts wrappers in Electron)
# A ``chart`` block whose ``chartType`` is not in this set is skipped with a
# warning by HomeFormatter._dispatch_block.
_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}

# Map agent name → floating domain
# Lookups for unlisted agents fall back to "tasks" at the call site.
_AGENT_DOMAIN: dict[str, str] = {
    "task_agent": "tasks",
    "timeline_agent": "timelines",
    "note_agent": "notes",
    "project_agent": "projects",
}
|
||||
|
||||
WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain
|
||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||
|
||||
|
||||
class HomeFormatter:
    """Parses a token stream from orchestrate_v3_stream and yields WS frames.

    The LLM is expected to output a newline-delimited sequence of JSON objects,
    each with a ``type`` field:
      - ``text``       → WsStreamText carrying the object's ``content``
      - ``chart``      → validated, yields WsStreamBlock
      - ``entity_ref`` → resolved from tool_results, yields WsStreamBlock
      - ``table``      → validated, yields WsStreamBlock
      - ``timeline``   → validated, yields WsStreamBlock

    Input that does not start with ``{`` is treated as plain text and emitted
    one complete line at a time. Invalid or unknown blocks are logged and
    skipped — stream never crashes.
    """

    def __init__(self, request_id: str, tool_results: list[dict]) -> None:
        """Remember the request ID stamped on every frame and the tool results
        used to resolve ``entity_ref`` blocks."""
        self.request_id = request_id
        self.tool_results = tool_results

    async def format(
        self,
        token_stream: AsyncGenerator[tuple[str, str], None],
    ) -> AsyncGenerator[WsFrame, None]:
        """Yield ``WsStreamStart``, the parsed frames, then ``WsStreamEnd``.

        Tokens are accumulated in a buffer. After each token every complete
        line or JSON object is emitted and only the unconsumed tail is kept,
        so no frame is dropped when one token completes several objects and
        rejected objects are not re-parsed on later tokens.
        """
        yield WsStreamStart(request_id=self.request_id)

        buffer = ""
        async for _agent_name, token in token_stream:
            if not token:
                continue  # e.g. the initial routing signal carries no text
            buffer += token
            frames, buffer = self._drain(buffer)
            for frame in frames:
                yield frame

        # Flush any remaining content as a last resort.
        if buffer.strip():
            frames, _ = self._drain(buffer, final=True)
            for frame in frames:
                yield frame

        yield WsStreamEnd(request_id=self.request_id)

    async def _flush_complete_objects(
        self, text: str, final: bool = False
    ) -> AsyncGenerator[WsFrame, None]:
        """Try to parse and yield all complete JSON objects from *text*.

        Yields nothing for an incomplete trailing piece (unless *final* is
        True, in which case remaining text is emitted as plain stream_text).
        Retained as an async-generator view over ``_drain``.
        """
        frames, _ = self._drain(text, final=final)
        for frame in frames:
            yield frame

    def _drain(self, text: str, final: bool = False) -> tuple[list[WsFrame], str]:
        """Consume every complete unit at the start of *text*.

        Returns ``(frames, leftover)`` where *leftover* is the unconsumed tail
        to keep buffering. With ``final=True`` the tail is emitted as plain
        text instead and *leftover* is empty.
        """
        frames: list[WsFrame] = []
        # Only leading whitespace is dropped: trailing whitespace may separate
        # this token from the next one.
        remaining = text.lstrip()

        while remaining:
            if not remaining.startswith("{"):
                # Plain text: emit whole lines, hold back a partial line.
                line, newline, rest = remaining.partition("\n")
                if not newline:
                    if final:
                        frames.append(self._text_frame(remaining.strip()))
                        remaining = ""
                    break
                line = line.strip()
                if line:
                    frames.append(self._text_frame(line))
                remaining = rest.lstrip()
                continue

            try:
                obj, end_idx = _try_parse_json(remaining)
            except ValueError:
                obj, end_idx = None, 0

            if obj is None:
                # Incomplete or unparseable: wait for more tokens, or give up
                # and emit the raw text when the stream has ended.
                if final:
                    frames.append(self._text_frame(remaining.strip()))
                    remaining = ""
                break

            # Consume the object even if it is rejected below, so an invalid
            # block is logged once and never parsed again.
            remaining = remaining[end_idx:].lstrip()
            frame = self._dispatch_block(obj, obj.get("type"))
            if frame is not None:
                frames.append(frame)

        return frames, remaining

    def _text_frame(self, chunk: str) -> WsFrame:
        """Wrap *chunk* in a ``WsStreamText`` frame for this request."""
        return WsStreamText(request_id=self.request_id, chunk=chunk)

    def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
        """Validate *obj* according to *block_type* and build its frame.

        Returns ``None`` (after logging a warning, except for empty text) when
        the block is invalid or its type is unknown.
        """
        if block_type == "text":
            content = obj.get("content", "")
            if content:
                return WsStreamText(request_id=self.request_id, chunk=str(content))
            return None

        if block_type == "chart":
            chart_type = obj.get("chartType")
            if chart_type not in _VALID_CHART_TYPES:
                logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
                return None
            if not isinstance(obj.get("data"), list):
                logger.warning("HomeFormatter: chart missing data array — skipping")
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="chart",
                data=obj,
            )

        if block_type == "entity_ref":
            entity = obj.get("entity")
            resolved = self._resolve_entity(entity)
            if resolved is None:
                logger.warning(
                    "HomeFormatter: entity_ref %r not found in tool_results — skipping",
                    entity,
                )
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="entity_ref",
                data={"entity": entity, "items": resolved},
            )

        if block_type == "table":
            if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
                logger.warning("HomeFormatter: table missing headers/rows — skipping")
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="table",
                data=obj,
            )

        if block_type == "timeline":
            if not isinstance(obj.get("timelines"), list):
                logger.warning("HomeFormatter: timeline missing timelines — skipping")
                return None
            return WsStreamBlock(
                request_id=self.request_id,
                block_type="timeline",
                data=obj,
            )

        logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
        return None

    def _resolve_entity(self, entity: str | None) -> list[dict] | None:
        """Find matching items in tool_results by entity type.

        Returns the tool results whose ``"entity"`` value equals *entity*, or
        ``None`` when *entity* is empty or nothing matches.
        """
        if not entity:
            return None
        matches = [r for r in self.tool_results if r.get("entity") == entity]
        return matches if matches else None
|
||||
|
||||
|
||||
class FloatingFormatter:
|
||||
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
|
||||
|
||||
Emits floating_domain immediately (from agent_name), then streams all tokens
|
||||
as plain stream_text — no block parsing for floating context.
|
||||
"""
|
||||
class StreamFormatter:
|
||||
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
||||
|
||||
def __init__(self, request_id: str) -> None:
|
||||
self.request_id = request_id
|
||||
|
||||
async def format(
|
||||
self,
|
||||
token_stream: AsyncGenerator[tuple[str, str], None],
|
||||
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||
) -> AsyncGenerator[WsFrame, None]:
|
||||
domain_sent = False
|
||||
started = False
|
||||
|
||||
async for agent_name, token in token_stream:
|
||||
if not domain_sent:
|
||||
domain = _AGENT_DOMAIN.get(agent_name, "tasks")
|
||||
yield WsFloatingDomain(
|
||||
request_id=self.request_id,
|
||||
domain=domain, # type: ignore[arg-type]
|
||||
)
|
||||
async for event_type, data in event_stream:
|
||||
if event_type == "floating_domain":
|
||||
yield WsFloatingDomain(request_id=self.request_id, domain=str(data))
|
||||
continue
|
||||
|
||||
if event_type != "token":
|
||||
continue
|
||||
|
||||
if not started:
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
domain_sent = True
|
||||
started = True
|
||||
|
||||
if token:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=token)
|
||||
text = str(data or "")
|
||||
if text:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=text)
|
||||
|
||||
if not started:
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
yield WsStreamEnd(request_id=self.request_id)
|
||||
|
||||
|
||||
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]:
|
||||
"""Attempt to parse the first complete JSON object from *text*.
|
||||
|
||||
Returns ``(parsed_dict, end_index)`` on success, ``(None, 0)`` when the
|
||||
object is incomplete, and raises ``ValueError`` when text is not JSON.
|
||||
"""
|
||||
decoder = json.JSONDecoder()
|
||||
try:
|
||||
obj, end_idx = decoder.raw_decode(text)
|
||||
if not isinstance(obj, dict):
|
||||
raise ValueError("Expected JSON object")
|
||||
return obj, end_idx
|
||||
except json.JSONDecodeError as exc:
|
||||
# Incomplete JSON — need more tokens
|
||||
if "Unterminated" in str(exc) or exc.pos == len(text):
|
||||
return None, 0
|
||||
raise ValueError(str(exc)) from exc
|
||||
|
||||
Reference in New Issue
Block a user