577 lines
18 KiB
Python
577 lines
18 KiB
Python
"""Deep orchestrator-worker graphs for home and floating chat contexts."""
|
|
|
|
from __future__ import annotations

import asyncio
import json
import logging
import operator
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Annotated, Any, Literal, TypedDict

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from langgraph.types import Send
from pydantic import BaseModel, Field

from app.agents.note_agent import NOTE_SYSTEM_PROMPT, NOTE_TOOLS
from app.agents.project_agent import PROJECT_SYSTEM_PROMPT, PROJECT_TOOLS
from app.agents.task_agent import TASK_SYSTEM_PROMPT, TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_SYSTEM_PROMPT, TIMELINE_TOOLS
from app.core.llm import get_llm
from app.core.memory_middleware import MemoryMiddleware
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
from app.db import async_session
|
|
|
|
# Closed set of worker node names; each one is also a key of WORKER_CONFIG.
WorkerName = Literal["task_agent", "project_agent", "note_agent", "timeline_agent"]

# Scope a floating chat session is narrowed to; matches the per-worker
# "floating_domain" values in WORKER_CONFIG.
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]

# Module-level logger shared by the planner, workers and synthesizer.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# One unit of work in an orchestrator plan: which worker to run and what to tell it.
# Comments are used instead of a docstring on purpose: this model is nested in
# WorkerPlan, which is passed to `with_structured_output`, where a docstring would
# likely become part of the schema description sent to the LLM.
class WorkerTask(BaseModel):
    # Worker node to dispatch to; validated against the WorkerName literal.
    worker: WorkerName
    # Free-text instruction handed to that worker's tool loop.
    instruction: str
|
|
|
|
|
|
# Structured planner output: the schema _plan_with_llm requests from the LLM via
# `with_structured_output`. Comments rather than a docstring on purpose: a docstring
# would likely be sent to the LLM as part of the schema description.
class WorkerPlan(BaseModel):
    # Workers to dispatch. An empty list makes _plan_with_llm fall back to keyword
    # routing; the first entry determines the default floating domain.
    tasks: list[WorkerTask] = Field(default_factory=list)
    # Floating-context scope. Stays None for home-context fallback plans; when it is
    # missing in floating context, callers derive it from the first task's worker.
    floating_domain: FloatingDomain | None = None
|
|
|
|
|
|
class WorkerResult(TypedDict):
    """Outcome of one worker invocation, as built by the worker graph nodes."""

    # Worker node that produced this result.
    worker: WorkerName
    # Instruction the worker actually ran with (task instruction or the user message).
    instruction: str
    # Final text returned by the worker's tool loop.
    response: str
    # IDs of entities seen in the worker's tool results, keyed by component tag
    # ("task", "project", "note", "timeline").
    entity_ids: dict[str, list[str]]
|
|
|
|
|
|
class OrchestratorState(TypedDict, total=False):
    """Shared state for the home and floating orchestrator graphs (all keys optional)."""

    # Owner of the conversation; the synthesizer uses it for core-memory writes.
    user_id: str
    # Raw user request: planner input and fallback instruction for workers.
    user_message: str
    # Caller-supplied conversation context.
    context: dict[str, Any]
    # Memory context; the run_* entry points in this module pass the same dict as `context`.
    memory_context: dict[str, Any]
    # Serialized WorkerTask dicts, produced by the orchestrator node or pre-seeded by the caller.
    plan: list[dict[str, Any]]
    # Floating-context scope chosen by the planner (floating graph only).
    floating_domain: FloatingDomain
    # The single plan entry a worker node was dispatched with via Send.
    task: dict[str, Any]
    # Results from worker nodes. Subclasses (e.g. GraphState) redeclare this key.
    worker_results: list[WorkerResult]
    # Final markdown answer written by the synthesizer node.
    final_response: str
    # Optional async per-token sink used while the synthesizer streams its answer.
    stream_callback: Callable[[str], Awaitable[None]] | None
|
|
|
|
|
|
class GraphState(OrchestratorState):
    """State schema the orchestrator graphs are compiled with.

    ``worker_results`` carries an ``operator.add`` reducer. The orchestrator fans out
    to workers with ``Send``, so several worker nodes can write this key in the same
    superstep. Without a reducer, LangGraph treats the key as last-value-wins and
    rejects concurrent writes (``InvalidUpdateError``). With the reducer, every
    worker's result list is concatenated for the synthesizer.
    """

    worker_results: Annotated[list[WorkerResult], operator.add]
|
|
|
|
|
|
class ReducerState(OrchestratorState):
    """Orchestrator state whose ``worker_results`` accumulates across parallel writers.

    The ``operator.add`` reducer concatenates the lists returned by concurrently
    running worker nodes instead of rejecting or overwriting them.
    """

    worker_results: Annotated[list[WorkerResult], operator.add]
|
|
|
|
|
|
class AggregatedState(TypedDict, total=False):
    """Partial state update returned by worker nodes: only their ``worker_results``."""

    # Zero results (task addressed to another worker) or exactly one per invocation.
    worker_results: list[WorkerResult]
|
|
|
|
|
|
# Per-worker wiring:
# - "prompt" / "tools": the system prompt and tools the worker's tool loop runs with.
# - "tag": the inline component tag its entities map to.
# - "table": the table name expected in its tool results.
# - "floating_domain": the floating-UI domain it corresponds to.
# For every worker, the table and floating domain are the plural of its tag, so both
# are derived from the tag.
WORKER_CONFIG: dict[WorkerName, dict[str, Any]] = {
    worker_name: {
        "prompt": system_prompt,
        "tools": worker_tools,
        "tag": tag,
        "table": f"{tag}s",
        "floating_domain": f"{tag}s",
    }
    for worker_name, system_prompt, worker_tools, tag in (
        ("task_agent", TASK_SYSTEM_PROMPT, TASK_TOOLS, "task"),
        ("project_agent", PROJECT_SYSTEM_PROMPT, PROJECT_TOOLS, "project"),
        ("note_agent", NOTE_SYSTEM_PROMPT, NOTE_TOOLS, "note"),
        ("timeline_agent", TIMELINE_SYSTEM_PROMPT, TIMELINE_TOOLS, "timeline"),
    )
}
|
|
|
|
# System prompt for the home-context planner call in _plan_with_llm.
_HOME_ORCHESTRATOR_SYSTEM = (
    "You are an orchestrator. "
    "Plan which workers should be invoked for the user request. "
    "Workers: task_agent, project_agent, note_agent, timeline_agent. "
    "Return only the workers needed."
)

# System prompt for the floating-context planner call; also asks for a floating_domain.
_FLOATING_ORCHESTRATOR_SYSTEM = (
    "You are an orchestrator for floating context. "
    "Pick focused workers and set floating_domain as one of: "
    "tasks, projects, notes, timelines."
)

# System prompt for the home synthesizer, which may embed inline component tags.
_HOME_SYNTH_SYSTEM = (
    "You are the final response synthesizer. "
    "Return markdown only. "
    "Embed inline component tags when relevant: "
    "<project>[ids]</project>, <task>[ids]</task>, <note>[ids]</note>, "
    "<timeline>[ids]</timeline>, and <chart>{json}</chart>. "
    "Only include IDs that are truly relevant to the request."
)

# System prompt for the floating synthesizer: concise and scope-focused.
_FLOATING_SYNTH_SYSTEM = (
    "You are the final response synthesizer "
    "for floating UI context. Return concise markdown "
    "and stay focused on the requested scope."
)
|
|
|
|
|
|
def _as_text(content: Any) -> str:
    """Flatten LangChain message content into plain text.

    ``None`` becomes ``""`` and strings pass through unchanged. For a list of
    content blocks, the plain-string blocks and the string ``"text"`` field of
    dict blocks are concatenated; all other blocks are skipped. Any other value
    is converted with ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    def _piece(block: Any) -> str:
        # Only bare strings and dict blocks with a string "text" contribute text.
        if isinstance(block, str):
            return block
        if isinstance(block, dict) and isinstance(block.get("text"), str):
            return block["text"]
        return ""

    return "".join(_piece(block) for block in content)
|
|
|
|
|
|
def _fallback_plan(message: str, floating: bool) -> WorkerPlan:
    """Build a keyword-routed plan for when the LLM planner cannot be used.

    Each worker is selected when any of its keywords occurs in the lower-cased
    message. This is plain substring matching, so e.g. "due" also matches
    "produce". Every selected worker receives the full message as its
    instruction. If nothing matches, the task agent is used. In floating mode,
    the domain comes from the first selected worker.
    """
    text = message.lower()
    keyword_routes: tuple[tuple[WorkerName, tuple[str, ...]], ...] = (
        ("task_agent", ("task", "todo", "deadline", "due")),
        ("project_agent", ("project", "client", "milestone")),
        ("note_agent", ("note", "document", "memo")),
        ("timeline_agent", ("timeline", "event", "schedule", "release")),
    )

    selected = [
        WorkerTask(worker=worker, instruction=message)
        for worker, keywords in keyword_routes
        if any(word in text for word in keywords)
    ] or [WorkerTask(worker="task_agent", instruction=message)]

    domain = WORKER_CONFIG[selected[0].worker]["floating_domain"] if floating else None
    return WorkerPlan(tasks=selected, floating_domain=domain)
|
|
|
|
|
|
async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan:
    """Ask the LLM for a structured WorkerPlan, falling back to keyword routing.

    The planner receives a JSON payload with the message, the context and the
    worker names. Any of the following leads to ``_fallback_plan``, so the
    returned plan always has at least one task:

    - an exception from the structured call (logged as a warning);
    - a result that is not a WorkerPlan;
    - a plan with no tasks.
    """
    llm = get_llm()
    conversation = [
        SystemMessage(
            content=_FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM
        ),
        HumanMessage(
            content=json.dumps(
                {"message": message, "context": context, "workers": list(WORKER_CONFIG)},
                ensure_ascii=True,
            )
        ),
    ]

    try:
        candidate = await llm.with_structured_output(WorkerPlan).ainvoke(conversation)
    except Exception as exc:
        logger.warning("deep_agent: structured planner failed, using fallback: %s", exc)
    else:
        if isinstance(candidate, WorkerPlan) and candidate.tasks:
            return candidate

    return _fallback_plan(message, floating)
|
|
|
|
|
|
def _extract_entity_ids(tool_results: list[dict[str, Any]]) -> dict[str, list[str]]:
    """Collect entity IDs from a worker's tool results, grouped by component tag.

    Each tool result is expected to look like ``{"table": <name>, "data": {...}}``.
    Rows are taken from ``data["row"]`` (a single dict) and from the lists
    ``data["rows"]`` and ``data["results"]``. String ``id`` values are kept in
    first-seen order without duplicates.

    Malformed entries are skipped rather than raising:
    - non-dict results;
    - unknown or non-string tables;
    - non-dict ``data``;
    - non-dict rows;
    - non-string IDs.

    Returns:
        ``{"task": [...], "project": [...], "note": [...], "timeline": [...]}``.
    """
    table_to_tag = {
        "tasks": "task",
        "projects": "project",
        "notes": "note",
        "timelines": "timeline",
    }
    out: dict[str, list[str]] = {tag: [] for tag in table_to_tag.values()}
    # Per-tag set gives O(1) duplicate checks while `out` preserves order.
    seen: dict[str, set[str]] = {tag: set() for tag in out}

    for item in tool_results:
        if not isinstance(item, dict):
            continue
        table = item.get("table")
        tag = table_to_tag.get(table) if isinstance(table, str) else None
        if tag is None:
            continue

        payload = item.get("data")
        if not isinstance(payload, dict):
            # Previously a truthy non-dict payload crashed with AttributeError.
            continue

        rows: list[Any] = [payload.get("row")]
        for key in ("rows", "results"):
            value = payload.get(key)
            if isinstance(value, list):
                rows.extend(value)

        for row in rows:
            if not isinstance(row, dict):
                continue
            entity_id = row.get("id")
            if isinstance(entity_id, str) and entity_id not in seen[tag]:
                seen[tag].add(entity_id)
                out[tag].append(entity_id)

    return out
|
|
|
|
|
|
async def _run_tool_loop(
    worker: WorkerName,
    instruction: str,
    context: dict[str, Any],
    *,
    max_rounds: int = 6,
    max_context_chars: int = 2000,
) -> tuple[str, list[dict[str, Any]]]:
    """Run *worker*'s LLM with its tools until it answers without calling a tool.

    Args:
        worker: Key into ``WORKER_CONFIG`` selecting the system prompt and tools.
        instruction: Task text for this worker.
        context: Extra context. It is JSON-serialized and truncated to
            ``max_context_chars`` characters before being shown to the model.
        max_rounds: Maximum number of tool-enabled LLM turns. If the model is
            still requesting tools after that, one final tool-less call is made.
        max_context_chars: Truncation limit for the serialized context.

    Returns:
        ``(final_text, tool_results)``, where ``tool_results`` is what the tools
        appended to the result collector registered for the duration of this call.

    A tool that raises does not abort the worker. The error is logged and
    reported back to the model as the tool's output so it can recover.
    """
    config = WORKER_CONFIG[worker]
    tools = config["tools"]
    tool_map = {t.name: t for t in tools}  # loop-invariant: build once

    llm = get_llm()
    llm_with_tools = llm.bind_tools(tools) if tools else llm

    messages: list[Any] = [
        SystemMessage(content=config["prompt"]),
        HumanMessage(
            content=(
                "Worker instruction:\n"
                f"{instruction}\n\n"
                "Conversation context:\n"
                f"{json.dumps(context, ensure_ascii=True)[:max_context_chars]}"
            )
        ),
    ]

    collected: list[dict[str, Any]] = []
    # NOTE(review): several workers can run concurrently when the planner selects
    # more than one; this is only safe if the collector is task-local (e.g. a
    # ContextVar). Confirm in app.core.ws_context.
    set_tool_result_collector(collected)
    try:
        for _ in range(max_rounds):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            messages.append(response)

            if not response.tool_calls:
                return _as_text(response.content), collected

            for call in response.tool_calls:
                name = call["name"]
                tool_fn = tool_map.get(name)
                if tool_fn is None:
                    tool_output: Any = f"Unknown tool: {name}"
                else:
                    try:
                        tool_output = await tool_fn.ainvoke(call.get("args", {}))
                    except Exception as exc:
                        logger.exception("deep_agent: %s tool %r failed", worker, name)
                        tool_output = f"Tool {name} failed: {exc}"
                messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))

        # Round budget exhausted: ask for a plain answer from the transcript so far.
        final = await llm.ainvoke(messages)
        return _as_text(final.content), collected
    finally:
        clear_tool_result_collector()
|
|
|
|
|
|
def _worker_node(worker: WorkerName):
    """Build the graph node coroutine for *worker*.

    The node acts only when the incoming ``task`` names this worker; otherwise it
    contributes an empty ``worker_results`` list. When it acts, it runs the
    worker's tool loop with the memory and conversation context. It returns one
    WorkerResult holding the reply and the entity IDs found in the tool results.
    """

    async def _node(state: GraphState) -> AggregatedState:
        assignment = state.get("task") or {}
        if assignment.get("worker") != worker:
            return {"worker_results": []}

        # Fall back to the raw user message when the plan entry has no instruction.
        instruction = str(assignment.get("instruction") or state.get("user_message") or "")
        reply, tool_results = await _run_tool_loop(
            worker,
            instruction,
            {
                "memory": state.get("memory_context", {}),
                "context": state.get("context", {}),
            },
        )

        result: WorkerResult = {
            "worker": worker,
            "instruction": instruction,
            "response": reply,
            "entity_ids": _extract_entity_ids(tool_results),
        }
        return {"worker_results": [result]}

    return _node
|
|
|
|
|
|
def _build_synthesis_prompt(state: GraphState, floating: bool) -> str:
    """Serialize what the synthesizer needs into an ASCII-only JSON string.

    The payload holds:
    - the user message;
    - the memory context;
    - each worker result, reduced to worker, instruction, response and entity_ids;
    - the floating domain, which is only filled in when *floating* is true.
    """
    summaries = [
        {
            "worker": item.get("worker"),
            "instruction": item.get("instruction"),
            "response": item.get("response"),
            "entity_ids": item.get("entity_ids", {}),
        }
        for item in state.get("worker_results", [])
    ]
    return json.dumps(
        {
            "user_message": state.get("user_message", ""),
            "memory_context": state.get("memory_context", {}),
            "worker_results": summaries,
            "floating_domain": state.get("floating_domain") if floating else None,
        },
        ensure_ascii=True,
    )
|
|
|
|
|
|
async def _stream_with_memory_tool(
    *,
    user_id: str,
    system_prompt: str,
    user_prompt: str,
    stream_callback: Callable[[str], Awaitable[None]] | None,
    max_tool_rounds: int = 2,
) -> str:
    """Generate the final answer, first letting the model save core memories.

    Up to ``max_tool_rounds`` non-streaming turns run with one tool bound:
    ``update_core_memory``. Once the model stops calling tools, or the budget is
    spent, the answer is produced by a streaming call on the tool-less LLM.
    Every non-empty text chunk is forwarded to ``stream_callback`` when one is
    given.

    Returns:
        The concatenated streamed text.
    """

    @tool
    async def update_core_memory(key: str, value: str) -> str:
        """Save stable user preference/profile data to core memory."""
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            await memory.update_core(user_id, key, value)
        return f"Saved core memory key '{key}'."

    llm = get_llm()
    messages: list[Any] = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_prompt),
    ]
    llm_with_tools = llm.bind_tools([update_core_memory])

    for _ in range(max_tool_rounds):
        response: AIMessage = await llm_with_tools.ainvoke(messages)
        if not response.tool_calls:
            # Discard this draft. Appending it would make the streaming call below
            # continue from an assistant turn (duplicated or truncated output)
            # instead of answering the latest user/tool turn.
            break

        messages.append(response)
        for call in response.tool_calls:
            if call["name"] != "update_core_memory":
                tool_output: Any = "Unsupported tool."
            else:
                try:
                    tool_output = await update_core_memory.ainvoke(call.get("args", {}))
                except Exception:
                    # A failed memory write should not cost the user their answer.
                    logger.exception("deep_agent: update_core_memory failed")
                    tool_output = "Failed to save core memory."
            messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))

    chunks: list[str] = []
    async for chunk in llm.astream(messages):
        token = _as_text(getattr(chunk, "content", ""))
        if not token:
            continue
        chunks.append(token)
        if stream_callback is not None:
            await stream_callback(token)

    return "".join(chunks)
|
|
|
|
|
|
def _synthesizer_node(floating: bool):
    """Build the node that turns worker results into the final markdown answer.

    The node streams through the state's ``stream_callback`` (if any) and stores
    the full text under ``final_response``.
    """

    async def _node(state: GraphState) -> GraphState:
        answer = await _stream_with_memory_tool(
            user_id=str(state.get("user_id", "")),
            system_prompt=_FLOATING_SYNTH_SYSTEM if floating else _HOME_SYNTH_SYSTEM,
            user_prompt=_build_synthesis_prompt(state, floating=floating),
            stream_callback=state.get("stream_callback"),
        )
        return {"final_response": answer}

    return _node
|
|
|
|
|
|
async def _orchestrator_node_home(state: GraphState) -> GraphState:
    """Plan workers for a home-context request.

    A plan already in the state is kept unchanged (empty update). Otherwise the
    planner sees ``context`` merged with ``memory_context`` (memory keys win).
    Its tasks are stored as plain dicts under ``plan``.
    """
    if not state.get("plan"):
        merged_context = {**state.get("context", {}), **state.get("memory_context", {})}
        worker_plan = await _plan_with_llm(
            str(state.get("user_message", "")), merged_context, floating=False
        )
        return {"plan": [entry.model_dump() for entry in worker_plan.tasks]}
    return {}
|
|
|
|
|
|
async def _orchestrator_node_floating(state: GraphState) -> GraphState:
    """Plan workers and choose the floating domain for a floating-context request.

    A plan already in the state is kept unchanged (empty update). Otherwise the
    domain is taken in this order: the planner's own choice, then the first
    task's worker domain, then ``"tasks"``.
    """
    if state.get("plan"):
        return {}

    worker_plan = await _plan_with_llm(
        str(state.get("user_message", "")),
        {**state.get("context", {}), **state.get("memory_context", {})},
        floating=True,
    )

    domain = worker_plan.floating_domain
    if domain is None and worker_plan.tasks:
        domain = WORKER_CONFIG[worker_plan.tasks[0].worker]["floating_domain"]

    serialized = [entry.model_dump() for entry in worker_plan.tasks]
    return {"plan": serialized, "floating_domain": domain or "tasks"}
|
|
|
|
|
|
def _route_workers(state: GraphState) -> list[Send] | str:
    """Fan the plan out to worker nodes, or go straight to the synthesizer.

    Each plan entry naming a known worker becomes a ``Send`` to that node. A
    ``Send`` payload is the *entire* input its target node receives; the node
    does not see the rest of the graph state. The payload therefore carries the
    task plus the state fields the worker node reads (``user_message``,
    ``context``, ``memory_context``). Previously only ``task`` was sent, so
    workers ran without any conversation or memory context.
    """
    plan = state.get("plan", [])
    if not plan:
        return "synthesizer"

    shared = {
        key: state[key]
        for key in ("user_message", "context", "memory_context")
        if key in state
    }
    sends = [
        Send(task["worker"], {**shared, "task": task})
        for task in plan
        if task.get("worker") in WORKER_CONFIG
    ]
    return sends or "synthesizer"
|
|
|
|
|
|
def _build_graph(*, floating: bool):
    """Compile the orchestrator -> workers -> synthesizer graph.

    The orchestrator node plans the work, the conditional edge fans out to one
    node per ``WORKER_CONFIG`` entry (or skips to the synthesizer), and every
    worker feeds the synthesizer, which ends the run. ``floating`` selects the
    floating-context orchestrator and synthesizer variants.
    """
    builder = StateGraph(GraphState)

    builder.add_node(
        "orchestrator",
        _orchestrator_node_floating if floating else _orchestrator_node_home,
    )
    for worker in WORKER_CONFIG:
        builder.add_node(worker, _worker_node(worker))
    builder.add_node("synthesizer", _synthesizer_node(floating=floating))

    builder.add_edge(START, "orchestrator")
    # Every possible destination of _route_workers. It is derived from WORKER_CONFIG
    # so a newly added worker cannot be missing from the path map.
    builder.add_conditional_edges(
        "orchestrator",
        _route_workers,
        [*WORKER_CONFIG, "synthesizer"],
    )
    for worker in WORKER_CONFIG:
        builder.add_edge(worker, "synthesizer")
    builder.add_edge("synthesizer", END)

    return builder.compile()
|
|
|
|
|
|
# Compiled once at import time and reused by the run_* entry points below.
HOME_GRAPH, FLOATING_GRAPH = _build_graph(floating=False), _build_graph(floating=True)
|
|
|
|
|
|
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
    """Answer *message* in home context and return the synthesized markdown.

    ``context`` is supplied both as the conversation context and as the memory
    context. Nothing is streamed (no callback).
    """
    initial_state: GraphState = {
        "user_id": user_id,
        "user_message": message,
        "context": context,
        "memory_context": context,
        "worker_results": [],
        "stream_callback": None,
    }
    final_state = await HOME_GRAPH.ainvoke(initial_state)
    return str(final_state.get("final_response", ""))
|
|
|
|
|
|
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
    """Answer *message* in floating context.

    Planning happens here, before the graph runs, so the chosen domain can be
    returned. The precomputed plan is passed in, so the graph's orchestrator node
    does not plan again.

    Returns:
        ``(final_markdown, floating_domain)``.
    """
    plan = await _plan_with_llm(message, context, floating=True)
    domain = plan.floating_domain
    if not domain:
        # _plan_with_llm always returns at least one task.
        domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]

    initial_state: GraphState = {
        "user_id": user_id,
        "user_message": message,
        "context": context,
        "memory_context": context,
        "plan": [entry.model_dump() for entry in plan.tasks],
        "floating_domain": domain,
        "worker_results": [],
        "stream_callback": None,
    }
    final_state = await FLOATING_GRAPH.ainvoke(initial_state)
    return str(final_state.get("final_response", "")), str(domain)
|
|
|
|
|
|
async def _stream_graph(
    graph: Any,
    initial_state: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Run ``graph.ainvoke`` in the background and yield its streamed tokens.

    A ``stream_callback`` that enqueues tokens is injected into *initial_state*,
    and each token is yielded as ``("token", text)``. If the graph finishes
    without streaming anything, a non-empty ``final_response`` is yielded as a
    single token instead. Exceptions raised by the graph propagate to the consumer.

    If the consumer stops early (the generator is closed or cancelled), the
    background graph run is cancelled instead of being left running unattended.
    A sentinel replaces the previous 150 ms polling loop.
    """
    queue: asyncio.Queue[object] = asyncio.Queue()
    done = object()  # sentinel enqueued once the graph task finishes

    async def _on_token(token: str) -> None:
        await queue.put(token)

    task = asyncio.create_task(graph.ainvoke({**initial_state, "stream_callback": _on_token}))
    # Every token is enqueued by the task before it completes, so the FIFO queue
    # delivers all of them before this sentinel.
    task.add_done_callback(lambda _finished: queue.put_nowait(done))

    emitted = False
    try:
        while True:
            item = await queue.get()
            if item is done:
                break
            emitted = True
            yield "token", item
        final_state = await task
    finally:
        if not task.done():
            task.cancel()

    if not emitted and final_state.get("final_response"):
        yield "token", str(final_state["final_response"])


async def run_home_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream the home-context answer as ``("token", text)`` events."""
    stream = _stream_graph(
        HOME_GRAPH,
        {
            "user_id": user_id,
            "user_message": message,
            "context": context,
            "memory_context": context,
            "worker_results": [],
        },
    )
    try:
        async for event in stream:
            yield event
    finally:
        # Close the inner generator promptly so it cancels the graph run on early exit.
        await stream.aclose()


async def run_floating_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream a floating-context answer.

    First yields ``("floating_domain", domain)`` from an up-front planning call,
    then the synthesizer's ``("token", text)`` events. The precomputed plan is
    passed in, so the graph's orchestrator node does not plan again.
    """
    plan = await _plan_with_llm(message, context, floating=True)
    domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
    yield "floating_domain", domain

    stream = _stream_graph(
        FLOATING_GRAPH,
        {
            "user_id": user_id,
            "user_message": message,
            "context": context,
            "memory_context": context,
            "plan": [t.model_dump() for t in plan.tasks],
            "floating_domain": domain,
            "worker_results": [],
        },
    )
    try:
        async for event in stream:
            yield event
    finally:
        # Close the inner generator promptly so it cancels the graph run on early exit.
        await stream.aclose()
|