"""Deep orchestrator-worker graphs for home and floating chat contexts."""
from __future__ import annotations

import asyncio
import json
import logging
import operator
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Annotated, Any, Literal, TypedDict

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from langgraph.types import Send
from pydantic import BaseModel, Field

from app.agents.note_agent import NOTE_SYSTEM_PROMPT, NOTE_TOOLS
from app.agents.project_agent import PROJECT_SYSTEM_PROMPT, PROJECT_TOOLS
from app.agents.task_agent import TASK_SYSTEM_PROMPT, TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_SYSTEM_PROMPT, TIMELINE_TOOLS
from app.core.llm import get_llm
from app.core.memory_middleware import MemoryMiddleware
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
from app.db import async_session
logger = logging.getLogger(__name__)

# Node names of the domain workers; these are also the WORKER_CONFIG keys.
WorkerName = Literal[
    "task_agent",
    "project_agent",
    "note_agent",
    "timeline_agent",
]

# Scope a floating-chat turn is pinned to; each worker maps to exactly one of these.
FloatingDomain = Literal[
    "tasks",
    "projects",
    "notes",
    "timelines",
]
# NOTE(review): WorkerTask / WorkerSummary / WorkerPlan are handed to
# ``llm.with_structured_output`` further down, which presumably derives the
# LLM-facing JSON schema from them (field names, Field descriptions, and likely
# class docstrings) — TODO confirm before adding docstrings; ``#`` comments are
# used here so the schema stays unchanged.
class WorkerTask(BaseModel):
    # One planned unit of work: the worker node to run and what to tell it.
    worker: WorkerName
    instruction: str


class WorkerSummary(BaseModel):
    # Structured answer requested by _run_tool_loop once its tool-call rounds
    # are exhausted without a plain-text reply.
    summary: str = Field(description="Strictly concise summary of tool findings. Max 3 sentences.")


class WorkerPlan(BaseModel):
    # Planner output: the tasks to fan out, plus the floating-chat domain.
    tasks: list[WorkerTask] = Field(default_factory=list)
    # None means "not chosen"; callers then derive it from the first task's worker.
    floating_domain: FloatingDomain | None = None


class WorkerResult(TypedDict):
    # One worker's contribution to graph state, consumed by the synthesizer prompt.
    worker: WorkerName
    instruction: str
    # The worker's final text (direct reply or structured summary).
    response: str
    # Entity tag ("task", "project", "note", "timeline") -> ids seen in tool results.
    entity_ids: dict[str, list[str]]
class OrchestratorState(TypedDict, total=False):
    """Keys shared by the orchestrator graphs; every key is optional (total=False)."""

    # Owner of the conversation; the synthesizer passes it to core-memory writes.
    user_id: str
    # Raw user request: planner input and the fallback worker instruction.
    user_message: str
    # Caller-supplied context. The run_* entry points pass the same dict both
    # here and as memory_context.
    context: dict[str, Any]
    memory_context: dict[str, Any]
    # Planned tasks as WorkerTask.model_dump() dicts.
    plan: list[dict[str, Any]]
    # Floating-chat scope; only set by the floating orchestrator / run_floating*.
    floating_domain: FloatingDomain
    # The single plan entry a worker node handles, delivered in a Send payload.
    task: dict[str, Any]
    # Collected worker outputs (GraphState redeclares this key for the graph).
    worker_results: list[WorkerResult]
    # Synthesizer's markdown answer.
    final_response: str
class GraphState(OrchestratorState):
    """State schema of the compiled home and floating graphs.

    Worker nodes are fanned out with ``Send`` and run in the same superstep.
    Each returns ``{"worker_results": [one_result]}``. Without a reducer,
    LangGraph rejects several writes to one key in a single step
    (``InvalidUpdateError``). ``operator.add`` concatenates the per-worker
    lists instead, so the synthesizer sees every result.
    """

    worker_results: Annotated[list[WorkerResult], operator.add]
class ReducerState(OrchestratorState):
    # NOTE(review): not referenced in the code visible here, and despite its
    # name it declares no reducer for worker_results — confirm whether it is
    # still needed.
    worker_results: list[WorkerResult]


class AggregatedState(TypedDict, total=False):
    # Partial state update returned by worker nodes (the return annotation of
    # the node function built in _worker_node).
    worker_results: list[WorkerResult]
# Per-worker wiring, built from one row per worker:
#   prompt / tools    -> system prompt and tool list for the worker's LLM loop
#   tag               -> entity tag name
#   table / floating_domain -> plural name used for both the source table and
#                              the floating-chat domain
# Insertion order is also the order in which _build_graph registers worker nodes.
WORKER_CONFIG: dict[WorkerName, dict[str, Any]] = {
    name: {
        "prompt": prompt,
        "tools": tools,
        "tag": tag,
        "table": plural,
        "floating_domain": plural,
    }
    for name, prompt, tools, tag, plural in (
        ("task_agent", TASK_SYSTEM_PROMPT, TASK_TOOLS, "task", "tasks"),
        ("project_agent", PROJECT_SYSTEM_PROMPT, PROJECT_TOOLS, "project", "projects"),
        ("note_agent", NOTE_SYSTEM_PROMPT, NOTE_TOOLS, "note", "notes"),
        ("timeline_agent", TIMELINE_SYSTEM_PROMPT, TIMELINE_TOOLS, "timeline", "timelines"),
    )
}
# System prompts. Each is assembled from whole sentences joined by single spaces.

# Planner prompt for the home chat graph.
_HOME_ORCHESTRATOR_SYSTEM = " ".join(
    [
        "You are an orchestrator.",
        "Plan which workers should be invoked for the user request.",
        "Workers: task_agent, project_agent, note_agent, timeline_agent.",
        "Return only the workers needed.",
    ]
)

# Planner prompt for floating chat; also asks for the floating_domain.
_FLOATING_ORCHESTRATOR_SYSTEM = " ".join(
    [
        "You are an orchestrator for floating context.",
        "Pick focused workers and set floating_domain as one of: tasks, projects, notes, timelines.",
    ]
)

# Final-answer prompt for the home graph.
# NOTE(review): the component-tag examples read as bare "[ids]" placeholders with
# no tag names, so the model cannot tell which inline tag to emit. The tag names
# look missing (WORKER_CONFIG "tag" values are task/project/note/timeline) —
# confirm the intended syntax against the frontend parser.
_HOME_SYNTH_SYSTEM = " ".join(
    [
        "You are the final response synthesizer.",
        "Return markdown only.",
        "Embed inline component tags when relevant: [ids], [ids], [ids], [ids], and {json}.",
        "Only include IDs that are truly relevant to the request.",
    ]
)

# Final-answer prompt for the floating graph.
_FLOATING_SYNTH_SYSTEM = " ".join(
    [
        "You are the final response synthesizer for floating UI context.",
        "Return concise markdown and stay focused on the requested scope.",
    ]
)
def _as_text(content: Any) -> str:
    """Flatten LangChain message content into plain text.

    ``None`` becomes ``""`` and strings pass through unchanged. A list is
    concatenated from its string items and the string ``"text"`` values of its
    dict items; other items are skipped. Any other value is ``str()``-ed.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)
    # Dict parts contribute their "text" entry; anything that isn't a str is dropped.
    candidates = (part.get("text") if isinstance(part, dict) else part for part in content)
    return "".join(text for text in candidates if isinstance(text, str))
def _fallback_plan(message: str, floating: bool) -> WorkerPlan:
    """Build a keyword-routed plan for when the LLM planner is unusable.

    Each worker whose keywords occur as plain substrings of the lowercased
    message gets a task carrying the full message. Workers are checked in the
    order task, project, note, timeline. With no match, a single task_agent
    task is used. For floating plans, the domain is that of the first task's
    worker; otherwise it is None.
    """
    keyword_routes = (
        ("task_agent", ("task", "todo", "deadline", "due")),
        ("project_agent", ("project", "client", "milestone")),
        ("note_agent", ("note", "document", "memo")),
        ("timeline_agent", ("timeline", "event", "schedule", "release")),
    )
    text = message.lower()
    chosen = [
        WorkerTask(worker=worker, instruction=message)
        for worker, keywords in keyword_routes
        if any(keyword in text for keyword in keywords)
    ]
    if not chosen:
        chosen = [WorkerTask(worker="task_agent", instruction=message)]
    domain = WORKER_CONFIG[chosen[0].worker]["floating_domain"] if floating else None
    return WorkerPlan(tasks=chosen, floating_domain=domain)
async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan:
    """Ask the LLM which workers to run, falling back to keyword routing.

    Args:
        message: The user's request.
        context: Extra planner context, serialized into the prompt as JSON.
        floating: Use the floating-context planner prompt when True.

    Returns:
        The LLM's ``WorkerPlan`` when it is one and has at least one task.
        Otherwise (empty plan, unexpected return type, or any exception from
        the structured call) ``_fallback_plan(message, floating)``. Either way
        the plan has at least one task.
    """
    llm = get_llm()
    system = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM
    prompt_payload = {
        "message": message,
        "context": context,
        "workers": list(WORKER_CONFIG.keys()),
    }
    # Serialization happens outside the try, so it must not raise.
    # default=str stringifies non-JSON values the caller's context may hold
    # (datetimes, UUIDs, ...) instead of crashing with TypeError.
    messages = [
        SystemMessage(content=system),
        HumanMessage(content=json.dumps(prompt_payload, ensure_ascii=True, default=str)),
    ]
    try:
        plan = await llm.with_structured_output(WorkerPlan).ainvoke(messages)
    except Exception as exc:
        logger.warning("deep_agent: structured planner failed, using fallback: %s", exc)
        return _fallback_plan(message, floating)
    if isinstance(plan, WorkerPlan) and plan.tasks:
        return plan
    return _fallback_plan(message, floating)
def _extract_entity_ids(tool_results: list[dict[str, Any]]) -> dict[str, list[str]]:
    """Collect unique entity ids per tag from collected tool results.

    Each result whose ``"table"`` is tasks/projects/notes/timelines contributes
    string ``"id"`` values from the dict rows in its ``"data"``. Rows are taken
    from ``data["row"]``, then ``data["rows"]``, then ``data["results"]``.
    Results with an unknown table, or whose ``data`` is not a dict (e.g. an
    error string), are skipped rather than raising.

    Returns:
        ``{"task": [...], "project": [...], "note": [...], "timeline": [...]}``
        with ids in first-seen order and no duplicates per tag.
    """
    table_to_tag = {
        "tasks": "task",
        "projects": "project",
        "notes": "note",
        "timelines": "timeline",
    }
    out: dict[str, list[str]] = {tag: [] for tag in table_to_tag.values()}
    seen: set[tuple[str, str]] = set()  # O(1) dedupe instead of list membership
    for item in tool_results:
        if not isinstance(item, dict):
            continue
        tag = table_to_tag.get(item.get("table"))
        payload = item.get("data")
        if tag is None or not isinstance(payload, dict):
            continue
        rows: list[Any] = [payload.get("row")]
        for key in ("rows", "results"):
            value = payload.get(key)
            if isinstance(value, list):
                rows.extend(value)
        for row in rows:
            if not isinstance(row, dict):
                continue
            entity_id = row.get("id")
            if isinstance(entity_id, str) and (tag, entity_id) not in seen:
                seen.add((tag, entity_id))
                out[tag].append(entity_id)
    return out
async def _run_tool_loop(
    worker: WorkerName,
    instruction: str,
    context: dict[str, Any],
) -> tuple[str, list[dict[str, Any]]]:
    """Run ``worker``'s tool-calling loop for one instruction.

    The worker's system prompt and tools come from ``WORKER_CONFIG``. The
    model gets up to 6 rounds, and a reply without tool calls is returned
    as-is. If every round requests tools, the model is asked for a structured
    ``WorkerSummary`` instead. A tool that raises is reported back to the
    model as that call's result, so one failing tool does not fail the whole
    graph run.

    Args:
        worker: Key into ``WORKER_CONFIG``.
        instruction: Task text for this worker.
        context: Extra context, JSON-serialized (non-JSON values stringified)
            and truncated to 2000 characters.

    Returns:
        ``(answer_text, tool_results)``. ``tool_results`` is the list registered
        via ``set_tool_result_collector`` for this call (presumably filled by
        the tools themselves — see app.core.ws_context).
    """
    config = WORKER_CONFIG[worker]
    tools = config["tools"]
    tool_map = {t.name: t for t in tools}  # invariant across rounds, so built once
    llm = get_llm()
    llm_with_tools = llm.bind_tools(tools) if tools else llm
    context_json = json.dumps(context, ensure_ascii=True, default=str)[:2000]
    messages: list[Any] = [
        SystemMessage(content=config["prompt"]),
        HumanMessage(
            content=(
                "Worker instruction:\n"
                f"{instruction}\n\n"
                "Conversation context:\n"
                f"{context_json}"
            )
        ),
    ]
    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        for _ in range(6):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            messages.append(response)
            if not response.tool_calls:
                return _as_text(response.content), collected
            for call in response.tool_calls:
                tool_fn = tool_map.get(call["name"])
                if tool_fn is None:
                    tool_output = f"Unknown tool: {call['name']}"
                else:
                    try:
                        tool_output = await tool_fn.ainvoke(call.get("args", {}))
                    except Exception as exc:
                        # Surface the failure to the model (it can retry or explain)
                        # instead of raising out of the graph node.
                        logger.warning(
                            "deep_agent: %s tool %s failed: %s", worker, call["name"], exc
                        )
                        tool_output = f"Tool {call['name']} failed: {exc}"
                messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
        # Round budget exhausted while the model was still calling tools.
        structured_llm = llm.with_structured_output(WorkerSummary)
        messages.append(SystemMessage(content="You have finished using tools. Summarize findings in max 3 sentences."))
        final_summary = await structured_llm.ainvoke(messages)
        if isinstance(final_summary, WorkerSummary):
            return final_summary.summary, collected
        return str(final_summary), collected
    finally:
        clear_tool_result_collector()
def _worker_node(worker: WorkerName):
    """Create the graph node function that runs ``worker`` for one planned task.

    The node reads the ``task`` entry of its input. If that task targets a
    different worker, it contributes an empty ``worker_results`` list.
    Otherwise it runs the worker's tool loop and returns a one-element
    ``worker_results`` list.
    """

    async def _run_worker(state: GraphState) -> AggregatedState:
        task = state.get("task") or {}
        if task.get("worker") != worker:
            return {"worker_results": []}
        # Prefer the planner's instruction; fall back to the raw user message.
        instruction = str(task.get("instruction") or state.get("user_message") or "")
        reply, tool_results = await _run_tool_loop(
            worker,
            instruction,
            {
                "memory": state.get("memory_context", {}),
                "context": state.get("context", {}),
            },
        )
        result = {
            "worker": worker,
            "instruction": instruction,
            "response": reply,
            "entity_ids": _extract_entity_ids(tool_results),
        }
        return {"worker_results": [result]}

    return _run_worker
def _build_synthesis_prompt(state: GraphState, floating: bool) -> str:
    """Serialize the synthesizer's input into an ASCII-safe JSON string.

    The payload carries the user message, memory context, each worker result
    (worker, instruction, response, entity_ids; other keys dropped), and the
    floating domain. The floating domain is always None for the home graph.
    Values json cannot encode natively (datetimes, UUIDs, ...) are stringified
    rather than crashing the synthesizer.
    """
    formatted_results = [
        {
            "worker": result.get("worker"),
            "instruction": result.get("instruction"),
            "response": result.get("response"),
            "entity_ids": result.get("entity_ids", {}),
        }
        for result in state.get("worker_results", [])
    ]
    payload = {
        "user_message": state.get("user_message", ""),
        "memory_context": state.get("memory_context", {}),
        "worker_results": formatted_results,
        "floating_domain": state.get("floating_domain") if floating else None,
    }
    return json.dumps(payload, ensure_ascii=True, default=str)
async def _stream_with_memory_tool(
    *,
    user_id: str,
    system_prompt: str,
    user_prompt: str,
) -> str:
    """Produce the synthesizer's final answer, allowing core-memory writes.

    The model is bound to a single ``update_core_memory`` tool for up to two
    rounds. A reply without tool calls is the final answer and is returned
    directly. Asking the model to continue after its own reply would generate
    a second answer; under ``astream_events`` (run_*_stream) that would
    typically double the streamed tokens. Only when both rounds request tools
    is the unbound model streamed once more over the full transcript.

    Tool failures (bad arguments, database errors) are logged and reported back
    to the model as the tool result instead of aborting the graph run.
    """

    @tool
    async def update_core_memory(key: str, value: str) -> str:
        """Save stable user preference/profile data to core memory."""
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            await memory.update_core(user_id, key, value)
        return f"Saved core memory key '{key}'."

    llm = get_llm()
    messages: list[Any] = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_prompt),
    ]
    llm_with_tools = llm.bind_tools([update_core_memory])
    for _ in range(2):
        response: AIMessage = await llm_with_tools.ainvoke(messages)
        messages.append(response)
        if not response.tool_calls:
            return _as_text(response.content)
        for call in response.tool_calls:
            if call["name"] != "update_core_memory":
                messages.append(ToolMessage(content="Unsupported tool.", tool_call_id=call["id"]))
                continue
            try:
                tool_output = await update_core_memory.ainvoke(call.get("args", {}))
            except Exception as exc:
                # Saving memory is auxiliary to answering; don't fail the response.
                logger.warning("deep_agent: update_core_memory failed: %s", exc)
                tool_output = f"Failed to save core memory: {exc}"
            messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
    # Tool-round budget exhausted: ask the plain model (no tools) for the answer.
    chunks: list[str] = []
    async for chunk in llm.astream(messages):
        token = _as_text(getattr(chunk, "content", ""))
        if token:
            chunks.append(token)
    return "".join(chunks)
def _synthesizer_node(floating: bool):
    """Create the graph node that turns worker results into ``final_response``.

    ``floating`` selects the floating synthesizer prompt and includes the
    floating domain in the serialized prompt.
    """

    async def _synthesize(state: GraphState) -> GraphState:
        user_prompt = _build_synthesis_prompt(state, floating=floating)
        if floating:
            system_prompt = _FLOATING_SYNTH_SYSTEM
        else:
            system_prompt = _HOME_SYNTH_SYSTEM
        answer = await _stream_with_memory_tool(
            user_id=str(state.get("user_id", "")),
            system_prompt=system_prompt,
            user_prompt=user_prompt,
        )
        return {"final_response": answer}

    return _synthesize
async def _orchestrator_node_home(state: GraphState) -> GraphState:
    """Plan workers for the home graph unless a plan is already in state.

    Context and memory context are merged (memory wins on key clashes) and
    passed to the planner. The plan is stored as a list of task dicts.
    """
    if not state.get("plan"):
        merged_context = {**state.get("context", {}), **state.get("memory_context", {})}
        planned = await _plan_with_llm(str(state.get("user_message", "")), merged_context, floating=False)
        return {"plan": [entry.model_dump() for entry in planned.tasks]}
    return {}
async def _orchestrator_node_floating(state: GraphState) -> GraphState:
    """Plan workers and a floating domain unless a plan is already in state.

    The domain is the planner's choice, else that of the first task's worker,
    else ``"tasks"``.
    """
    if state.get("plan"):
        return {}
    merged_context = {**state.get("context", {}), **state.get("memory_context", {})}
    planned = await _plan_with_llm(str(state.get("user_message", "")), merged_context, floating=True)
    tasks = planned.tasks
    domain = planned.floating_domain
    if domain is None and tasks:
        domain = WORKER_CONFIG[tasks[0].worker]["floating_domain"]
    plan_dicts = [entry.model_dump() for entry in tasks]
    return {"plan": plan_dicts, "floating_domain": domain or "tasks"}
def _route_workers(state: GraphState) -> list[Send] | str:
    """Fan the plan out to worker nodes, or go straight to the synthesizer.

    LangGraph gives a ``Send`` target the Send's payload as its whole input,
    not the parent graph state. The shared conversation fields (user id,
    message, context, memory context) are therefore copied into every payload
    next to the task. Without them, workers would see empty context/memory and
    lose the user_message fallback for their instruction.

    Returns:
        One ``Send`` per plan entry naming a known worker, or ``"synthesizer"``
        when the plan is empty or names no known worker.
    """
    plan = state.get("plan") or []
    shared = {
        key: state[key]
        for key in ("user_id", "user_message", "context", "memory_context")
        if key in state
    }
    sends: list[Send] = []
    for task in plan:
        worker = task.get("worker")
        if worker in WORKER_CONFIG:
            sends.append(Send(worker, {**shared, "task": task}))
    return sends or "synthesizer"
def _build_graph(*, floating: bool):
    """Compile an orchestrator -> workers -> synthesizer graph.

    ``floating`` selects the floating orchestrator and synthesizer variants.
    The orchestrator routes through ``_route_workers`` to worker nodes, or
    directly to the synthesizer. Every worker feeds the synthesizer, which
    ends the run.
    """
    builder = StateGraph(GraphState)
    builder.add_node(
        "orchestrator",
        _orchestrator_node_floating if floating else _orchestrator_node_home,
    )
    for name in WORKER_CONFIG:
        builder.add_node(name, _worker_node(name))
    builder.add_node("synthesizer", _synthesizer_node(floating=floating))

    builder.add_edge(START, "orchestrator")
    # Every destination _route_workers can return.
    destinations = [*WORKER_CONFIG, "synthesizer"]
    builder.add_conditional_edges("orchestrator", _route_workers, destinations)
    for name in WORKER_CONFIG:
        builder.add_edge(name, "synthesizer")
    builder.add_edge("synthesizer", END)
    return builder.compile()
# Both graphs are compiled once, at import time, and reused by the run_*
# entry points for every call.
HOME_GRAPH = _build_graph(floating=False)
FLOATING_GRAPH = _build_graph(floating=True)
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
    """Run the home graph end to end and return the synthesized markdown.

    The graph plans its own workers. ``context`` is used both as conversation
    context and as memory context.
    """
    initial_state = {
        "user_id": user_id,
        "user_message": message,
        "context": context,
        "memory_context": context,
        "worker_results": [],
    }
    final_state = await HOME_GRAPH.ainvoke(initial_state)
    return str(final_state.get("final_response", ""))
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
    """Run a floating-chat turn and return ``(markdown, floating_domain)``.

    Planning happens here, before the graph runs, so the domain is known up
    front. The plan is passed in, so the graph's orchestrator skips planning.
    """
    plan = await _plan_with_llm(message, context, floating=True)
    domain = plan.floating_domain
    if not domain:
        # _plan_with_llm always returns at least one task, so tasks[0] exists.
        domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
    initial_state = {
        "user_id": user_id,
        "user_message": message,
        "context": context,
        "memory_context": context,
        "plan": [entry.model_dump() for entry in plan.tasks],
        "floating_domain": domain,
        "worker_results": [],
    }
    final_state = await FLOATING_GRAPH.ainvoke(initial_state)
    return str(final_state.get("final_response", "")), str(domain)
async def run_home_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream the home graph's answer as ``("token", text)`` pairs.

    Only non-empty chat-model stream chunks emitted inside the ``synthesizer``
    node are forwarded; planner and worker model output is filtered out.
    """
    graph_input = {
        "user_id": user_id,
        "user_message": message,
        "context": context,
        "memory_context": context,
        "worker_results": [],
    }
    async for event in HOME_GRAPH.astream_events(graph_input, version="v2"):
        is_synth_token = (
            event["event"] == "on_chat_model_stream"
            and event.get("metadata", {}).get("langgraph_node") == "synthesizer"
        )
        if not is_synth_token:
            continue
        text = _as_text(getattr(event["data"]["chunk"], "content", ""))
        if text:
            yield "token", text
async def run_floating_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream a floating-chat turn.

    First yields ``("floating_domain", domain)`` once, before the graph runs.
    Then yields ``("token", text)`` for each non-empty chat-model chunk
    produced inside the ``synthesizer`` node.
    """
    plan = await _plan_with_llm(message, context, floating=True)
    domain = plan.floating_domain
    if not domain:
        # _plan_with_llm always returns at least one task, so tasks[0] exists.
        domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
    yield "floating_domain", domain

    graph_input = {
        "user_id": user_id,
        "user_message": message,
        "context": context,
        "memory_context": context,
        "plan": [entry.model_dump() for entry in plan.tasks],
        "floating_domain": domain,
        "worker_results": [],
    }
    async for event in FLOATING_GRAPH.astream_events(graph_input, version="v2"):
        is_synth_token = (
            event["event"] == "on_chat_model_stream"
            and event.get("metadata", {}).get("langgraph_node") == "synthesizer"
        )
        if not is_synth_token:
            continue
        text = _as_text(getattr(event["data"]["chunk"], "content", ""))
        if text:
            yield "token", text