"""Deep orchestrator-worker graphs for home and floating chat contexts."""
from __future__ import annotations
import asyncio
import json
import logging
import re
from collections.abc import AsyncGenerator, Awaitable, Callable
import operator
from typing import Annotated, Any, Literal, TypedDict
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from langgraph.types import Send
from pydantic import BaseModel, Field
from app.agents.note_agent import NOTE_SYSTEM_PROMPT, NOTE_TOOLS
from app.agents.project_agent import PROJECT_SYSTEM_PROMPT, PROJECT_TOOLS
from app.agents.task_agent import TASK_SYSTEM_PROMPT, TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_SYSTEM_PROMPT, TIMELINE_TOOLS
from app.core.llm import get_llm
from app.core.memory_middleware import MemoryMiddleware
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
from app.db import async_session
# Module-level logger.
logger = logging.getLogger(__name__)
# Quick test switch: home requests run as one agent with all tools.
# When True, run_home / run_home_stream bypass HOME_GRAPH and use the
# single-agent loop instead.
HOME_SINGLE_AGENT_TEST_MODE = True
# Names of the worker graph nodes; must match the keys of WORKER_CONFIG.
WorkerName = Literal["task_agent", "project_agent", "note_agent", "timeline_agent"]
# Scopes a floating chat can focus on; values of WORKER_CONFIG[...]["floating_domain"].
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]
# One entry of an orchestrator plan: which worker node runs and what it is told.
# (Documented with comments, not a docstring, so the pydantic schema is unchanged.)
class WorkerTask(BaseModel):
    # Target worker node; must be one of the WORKER_CONFIG keys.
    worker: WorkerName
    # Natural-language instruction handed to the worker's tool loop.
    instruction: str
# A single key/value write to the user's core memory, proposed by the planner
# and persisted by _apply_memory_updates via MemoryMiddleware.update_core.
class MemoryUpdate(BaseModel):
    key: str = Field(description="The memory key to set or update.")
    value: str = Field(description="The persistent fact or preference value.")
# Structured-output schema for a worker's closing summary, requested via
# llm.with_structured_output(WorkerSummary) when a worker exhausts its tool rounds.
# Kept docstring-free on purpose: a class docstring would become the schema
# description sent to the model.
class WorkerSummary(BaseModel):
    summary: str = Field(description="Strictly concise summary of tool findings. Max 3 sentences.")
# Normalized orchestrator plan, as produced by _coerce_plan, _fallback_plan or
# _build_full_project_snapshot_plan.
class WorkerPlan(BaseModel):
    # Worker assignments, fanned out to worker nodes by _route_workers.
    tasks: list[WorkerTask] = Field(default_factory=list)
    # UI scope for floating chats; None for home requests unless the planner set it.
    floating_domain: FloatingDomain | None = None
    memory_updates: list[MemoryUpdate] = Field(default_factory=list, description="Update long-term core memory with persistent user preferences/facts learned from this message.")
# Output of one worker node, accumulated in GraphState.worker_results and
# serialized for the synthesizer by _build_synthesis_prompt.
class WorkerResult(TypedDict):
    # Worker node that produced this result.
    worker: WorkerName
    # Instruction the worker actually ran with.
    instruction: str
    # Worker's final text (direct answer or structured summary).
    response: str
    # Entity ids per tag ("task", "project", "note", "timeline") from _extract_entity_ids.
    entity_ids: dict[str, list[str]]
    # Per-table compact row facts from _extract_facts.
    facts: dict[str, Any]
# Shared graph state; total=False because nodes return partial updates.
class OrchestratorState(TypedDict, total=False):
    # Request owner; used for core-memory writes.
    user_id: str
    # Raw user text.
    user_message: str
    # Request/UI context; the home orchestrator may add "resolved_project_id".
    context: dict[str, Any]
    # Core-memory view. NOTE(review): run_home seeds this with the request context.
    memory_context: dict[str, Any]
    # Plan entries as WorkerTask.model_dump() dicts.
    plan: list[dict[str, Any]]
    # UI scope for floating chats.
    floating_domain: FloatingDomain
    # The single plan entry a worker node receives via Send.
    task: dict[str, Any]
    # Results from worker nodes.
    worker_results: list[WorkerResult]
    # Synthesizer output.
    final_response: str
# Graph state schema: worker_results uses an operator.add reducer so that the
# lists returned by the worker nodes are concatenated rather than overwritten.
class GraphState(OrchestratorState):
    worker_results: Annotated[list[WorkerResult], operator.add]
# Variant of the state with a plain (non-reduced) worker_results list.
# NOTE(review): not referenced in the visible code — confirm it is still needed.
class ReducerState(OrchestratorState):
    worker_results: list[WorkerResult]
# Partial state update returned by worker nodes (only worker_results).
class AggregatedState(TypedDict, total=False):
    worker_results: list[WorkerResult]
# Per-worker wiring: system prompt and tools for the worker's tool loop, the
# inline-tag name, the backing client table, and the floating UI domain that
# the worker maps to. Dict order (task, project, note, timeline) is also the
# node/tool registration order.
# NOTE(review): "tag" and "table" are not read by the visible code (entity and
# fact extraction keep their own table maps) — keep them in sync or confirm use.
WORKER_CONFIG: dict[WorkerName, dict[str, Any]] = {
    "task_agent": {
        "prompt": TASK_SYSTEM_PROMPT,
        "tools": TASK_TOOLS,
        "tag": "task",
        "table": "tasks",
        "floating_domain": "tasks",
    },
    "project_agent": {
        "prompt": PROJECT_SYSTEM_PROMPT,
        "tools": PROJECT_TOOLS,
        "tag": "project",
        "table": "projects",
        "floating_domain": "projects",
    },
    "note_agent": {
        "prompt": NOTE_SYSTEM_PROMPT,
        "tools": NOTE_TOOLS,
        "tag": "note",
        "table": "notes",
        "floating_domain": "notes",
    },
    "timeline_agent": {
        "prompt": TIMELINE_SYSTEM_PROMPT,
        "tools": TIMELINE_TOOLS,
        "tag": "timeline",
        "table": "timelines",
        "floating_domain": "timelines",
    },
}
# System prompt for the home planner (_plan_with_llm); the exact JSON shape is
# spelled out again in the planner's user message.
_HOME_ORCHESTRATOR_SYSTEM = (
    "You are an orchestrator. Plan which workers should be invoked for the user request. "
    "Workers: task_agent, project_agent, note_agent, timeline_agent. "
    "Return JSON only with keys: tasks, floating_domain, memory_updates."
)
# System prompt for the floating-context planner.
_FLOATING_ORCHESTRATOR_SYSTEM = (
    "You are an orchestrator for floating context. Pick focused workers and set floating_domain "
    "as one of: tasks, projects, notes, timelines. Return JSON only with keys: tasks, floating_domain, memory_updates."
)
# System prompt for the home synthesizer node.
# NOTE(review): "[ids], [ids], [ids], [ids], and {json}" names no tags, so the
# model cannot tell which component tags to emit; the per-entity tag markup
# looks lost — confirm the intended syntax.
_HOME_SYNTH_SYSTEM = (
    "You are the final response synthesizer. Return markdown only. "
    "Embed inline component tags when relevant: [ids], [ids], "
    "[ids], [ids], and {json}. "
    "Only include IDs that are truly relevant to the request. "
    "Never invent missing values. If facts include a non-null clientId for a project, "
    "do not claim that the project has no owner/client."
)
# System prompt for the floating synthesizer node.
_FLOATING_SYNTH_SYSTEM = (
    "You are the final response synthesizer for floating UI context. "
    "Return concise markdown and stay focused on the requested scope."
)
# System prompt for the single all-tools agent (HOME_SINGLE_AGENT_TEST_MODE).
# NOTE(review): same tag-less "[ids]" placeholders as the home synthesizer prompt.
_HOME_SINGLE_AGENT_SYSTEM = (
    "You are the home assistant with direct access to all tools: tasks, projects, notes, timelines. "
    "Always use tools for factual data retrieval before answering. "
    "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
    "Return markdown and embed inline tags when relevant: [ids], [ids], "
    "[ids], [ids], {json}."
)
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
def _fallback_plan(message: str, floating: bool) -> WorkerPlan:
    """Build a keyword-routed plan for when the LLM planner cannot be used.

    Each worker whose keywords occur (as substrings) in the lower-cased message
    is assigned the full message, in the fixed order task, project, note,
    timeline. With no hits, the task agent handles the message alone. For
    floating requests the domain is taken from the first assigned worker;
    otherwise it is ``None``. No memory updates are proposed.
    """
    text = message.lower()
    keyword_routes: tuple[tuple[WorkerName, tuple[str, ...]], ...] = (
        ("task_agent", ("task", "todo", "deadline", "due")),
        ("project_agent", ("project", "client", "milestone")),
        ("note_agent", ("note", "document", "memo")),
        ("timeline_agent", ("timeline", "event", "schedule", "release")),
    )
    assigned = [
        WorkerTask(worker=worker, instruction=message)
        for worker, keywords in keyword_routes
        if any(keyword in text for keyword in keywords)
    ]
    if not assigned:
        assigned = [WorkerTask(worker="task_agent", instruction=message)]
    domain: FloatingDomain | None = (
        WORKER_CONFIG[assigned[0].worker]["floating_domain"] if floating else None
    )
    return WorkerPlan(tasks=assigned, floating_domain=domain)
def _extract_json_object(text: str) -> dict[str, Any] | None:
"""Best-effort extraction of the first JSON object from model output."""
stripped = text.strip()
if not stripped:
return None
# Common case: model returns raw JSON object.
try:
payload = json.loads(stripped)
if isinstance(payload, dict):
return payload
except json.JSONDecodeError:
pass
# Fenced JSON block fallback.
if "```" in stripped:
parts = stripped.split("```")
for part in parts:
candidate = part.strip()
if candidate.startswith("json"):
candidate = candidate[4:].strip()
try:
payload = json.loads(candidate)
if isinstance(payload, dict):
return payload
except json.JSONDecodeError:
continue
return None
def _coerce_plan(payload: dict[str, Any], message: str, floating: bool) -> WorkerPlan:
    """Normalize loosely-shaped planner JSON into a validated ``WorkerPlan``.

    - ``tasks``: keeps dict entries whose ``worker`` is a string key of
      ``WORKER_CONFIG`` and whose ``instruction`` is a string. If none
      survive, the payload is abandoned in favor of ``_fallback_plan``.
    - ``floating_domain``: kept when it is one of the four known domains;
      otherwise derived from the first task's worker for floating requests,
      and ``None`` for home requests.
    - ``memory_updates``: keeps dict entries with non-empty string ``key``
      and ``value``.
    """
    raw_tasks = payload.get("tasks")
    task_entries = raw_tasks if isinstance(raw_tasks, list) else []
    tasks: list[WorkerTask] = []
    for entry in task_entries:
        if not isinstance(entry, dict):
            continue
        worker, instruction = entry.get("worker"), entry.get("instruction")
        # The str check must precede the membership test (unhashable values).
        if not (isinstance(worker, str) and worker in WORKER_CONFIG):
            continue
        if isinstance(instruction, str):
            tasks.append(WorkerTask(worker=worker, instruction=instruction))
    if not tasks:
        return _fallback_plan(message, floating)

    domain: FloatingDomain | None = None
    requested = payload.get("floating_domain")
    if isinstance(requested, str) and requested in {"tasks", "projects", "notes", "timelines"}:
        domain = requested  # type: ignore[assignment]
    elif floating:
        domain = WORKER_CONFIG[tasks[0].worker]["floating_domain"]

    raw_updates = payload.get("memory_updates")
    update_entries = raw_updates if isinstance(raw_updates, list) else []
    updates: list[MemoryUpdate] = []
    for entry in update_entries:
        if not isinstance(entry, dict):
            continue
        key, value = entry.get("key"), entry.get("value")
        if isinstance(key, str) and key and isinstance(value, str) and value:
            updates.append(MemoryUpdate(key=key, value=value))

    return WorkerPlan(tasks=tasks, floating_domain=domain, memory_updates=updates)
def _needs_full_project_snapshot(message: str, floating: bool) -> bool:
"""Detect project status/update requests that should query all workers."""
if floating:
return False
lowered = message.lower()
has_project = any(k in lowered for k in ["project", "progetto", "progetto", "progetti", "progetto", "whitelist"])
has_status_intent = any(k in lowered for k in ["status", "stato", "aggiorn", "update", "situazione", "riepilogo", "summary"])
return has_project and has_status_intent
def _build_full_project_snapshot_plan(message: str) -> WorkerPlan:
    """Build a deterministic plan that fans a project status request out to every worker.

    Every instruction is ``"<goal> <project-id hint> Request: <message>"``, in
    the order project, task, timeline, note. ``floating_domain`` and
    ``memory_updates`` keep the ``WorkerPlan`` defaults.
    """
    project_hint = (
        "Use context.context.resolved_project_id when present as project_id. "
        "Do not pass project names as project_id."
    )
    goals: list[tuple[WorkerName, str]] = [
        ("project_agent", "Resolve the target project from this request and return core fields including id, name, status, clientId."),
        ("task_agent", "Collect tasks relevant to the project in this request; include pending/blocked highlights and IDs."),
        ("timeline_agent", "Collect timeline/milestone items relevant to the project in this request; include upcoming items and IDs."),
        ("note_agent", "Collect notes relevant to the project in this request; include latest useful notes and IDs."),
    ]
    return WorkerPlan(
        tasks=[
            WorkerTask(worker=worker, instruction=f"{goal} {project_hint} Request: {message}")
            for worker, goal in goals
        ]
    )
def _candidate_tokens(message: str) -> list[str]:
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
return [t for t in tokens if len(t) >= 3]
async def _resolve_project_id_from_message(message: str) -> str | None:
    """Resolve the most likely project id for *message* from the client's project list.

    Selects all rows of the ``projects`` table via ``execute_on_client`` and
    scores each project by how many message tokens (``_candidate_tokens``)
    occur as substrings of its lower-cased ``name``. Returns the ``id`` of the
    single best-scoring project, or ``None`` when:

    - the lookup raises (logged);
    - the result is not a dict with a non-empty ``rows`` list;
    - nothing matches, or the best score is tied;
    - the winning id is not a string.

    This is only a hint, so ambiguity yields ``None`` rather than a guess.
    """
    try:
        result = await execute_on_client(action="select", table="projects")
    except Exception as exc:
        logger.warning("deep_agent: project resolve select failed: %s", exc)
        return None
    # Guard the shape: a non-dict result used to raise AttributeError here.
    rows = result.get("rows") if isinstance(result, dict) else None
    if not isinstance(rows, list) or not rows:
        return None
    tokens = _candidate_tokens(message)
    if not tokens:
        return None

    best_score = 0
    best_rows: list[dict[str, Any]] = []
    for row in rows:
        if not isinstance(row, dict):
            continue
        # `or ""` so a null name is not stringified to "none" (which short
        # tokens such as the Italian "non" would match).
        name = str(row.get("name") or "").lower()
        score = sum(1 for token in tokens if token in name)
        if score == 0 or score < best_score:
            continue
        if score > best_score:
            best_score, best_rows = score, [row]
        else:
            best_rows.append(row)

    # Several equally good matches are ambiguous: prefer no hint to a wrong one.
    if len(best_rows) != 1:
        return None
    project_id = best_rows[0].get("id")
    return project_id if isinstance(project_id, str) else None
async def _prepare_home_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *context*, enriched with a resolved project id when possible.

    Resolution is attempted only for project status/update requests (see
    ``_needs_full_project_snapshot``). On success the id is stored under
    ``"resolved_project_id"`` and logged. The input mapping is never mutated.
    """
    enriched = dict(context)
    if not _needs_full_project_snapshot(message, floating=False):
        return enriched
    project_id = await _resolve_project_id_from_message(message)
    if project_id:
        enriched["resolved_project_id"] = project_id
        logger.info("deep_agent: resolved_project_id=%s for message=%s", project_id, message[:200])
    return enriched
def _all_tools() -> list[Any]:
    """Return every worker's tools concatenated in ``WORKER_CONFIG`` order.

    Duplicates are kept; each call builds a fresh list.
    """
    return [entry for config in WORKER_CONFIG.values() for entry in config["tools"]]
async def _run_home_single_agent(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> str:
    """Single-agent test mode: answer a home request with one tool-calling loop.

    The model is bound to every worker's tools and may run up to six
    tool-calling rounds. The first response without tool calls is returned.
    If the round limit is reached, one final call to the unbound model
    answers from the accumulated transcript.

    A tool that raises is reported back to the model as an error message
    instead of aborting the whole reply. ``user_id`` is accepted for interface
    parity with the graph path and is currently unused.
    """
    prepared_context = await _prepare_home_context(message, context)
    llm = get_llm()
    tools = _all_tools()
    llm_with_tools = llm.bind_tools(tools)
    # Built once: the tool set does not change between rounds.
    tool_map = {t.name: t for t in tools}
    messages: list[Any] = [
        SystemMessage(content=_HOME_SINGLE_AGENT_SYSTEM),
        HumanMessage(content=f"User message:\n{message}\n\nContext:\n{json.dumps({'context': prepared_context}, ensure_ascii=True)[:3500]}"),
    ]
    for _ in range(6):
        response: AIMessage = await llm_with_tools.ainvoke(messages)
        messages.append(response)
        if not response.tool_calls:
            return _as_text(response.content)
        for call in response.tool_calls:
            tool_fn = tool_map.get(call["name"])
            if tool_fn is None:
                tool_output = f"Unknown tool: {call['name']}"
            else:
                try:
                    tool_output = await tool_fn.ainvoke(call.get("args", {}))
                except Exception as exc:
                    # Let the model see the failure (and retry or explain)
                    # instead of failing the entire request.
                    logger.warning("deep_agent: single-agent tool=%s failed: %s", call["name"], exc)
                    tool_output = f"Tool error ({call['name']}): {exc}"
            messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
    # Round limit reached: answer from what was gathered, without tools.
    final = await llm.ainvoke(messages)
    return _as_text(final.content)
async def _run_home_single_agent_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Streaming variant of the single-agent home test mode.

    Runs the same tool loop as ``_run_home_single_agent`` (up to six rounds).
    Once the model responds without tool calls, or the round limit is hit,
    the answer is generated with ``llm.astream`` and yielded as
    ``("token", text)`` tuples for each non-empty chunk.

    The tool-free draft response is *not* added to the transcript before
    streaming. Otherwise the streamed call would end with an assistant turn
    and ask the model to continue after its own final answer.
    A tool that raises is reported back to the model as an error message.
    ``user_id`` is currently unused.
    """
    prepared_context = await _prepare_home_context(message, context)
    llm = get_llm()
    tools = _all_tools()
    llm_with_tools = llm.bind_tools(tools)
    # Built once: the tool set does not change between rounds.
    tool_map = {t.name: t for t in tools}
    messages: list[Any] = [
        SystemMessage(content=_HOME_SINGLE_AGENT_SYSTEM),
        HumanMessage(content=f"User message:\n{message}\n\nContext:\n{json.dumps({'context': prepared_context}, ensure_ascii=True)[:3500]}"),
    ]
    for _ in range(6):
        response: AIMessage = await llm_with_tools.ainvoke(messages)
        if not response.tool_calls:
            # Model is ready to answer: discard this non-streamed draft and
            # regenerate the answer as a stream below.
            break
        messages.append(response)
        for call in response.tool_calls:
            tool_fn = tool_map.get(call["name"])
            if tool_fn is None:
                tool_output = f"Unknown tool: {call['name']}"
            else:
                try:
                    tool_output = await tool_fn.ainvoke(call.get("args", {}))
                except Exception as exc:
                    logger.warning("deep_agent: single-agent tool=%s failed: %s", call["name"], exc)
                    tool_output = f"Tool error ({call['name']}): {exc}"
            messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
    async for chunk in llm.astream(messages):
        token = _as_text(getattr(chunk, "content", ""))
        if token:
            yield "token", token
async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan:
    """Ask the orchestrator LLM for a worker plan, with deterministic fallbacks.

    Home project status/update requests skip the LLM and receive the fixed
    all-workers snapshot plan. Otherwise the planner is prompted for a JSON
    plan, which is parsed leniently (``_extract_json_object``) and normalized
    (``_coerce_plan``). Any failure — LLM error, non-JSON output, validation
    error — is logged and answered with the keyword-based ``_fallback_plan``.
    """
    if _needs_full_project_snapshot(message, floating):
        logger.info("deep_agent: forcing full project snapshot plan for message=%s", message[:200])
        return _build_full_project_snapshot_plan(message)

    llm = get_llm()
    # Output contract restated in the user turn; the parser tolerates deviations.
    contract = (
        "Create a valid JSON object with this exact structure:\n"
        '{"tasks":[{"worker":"task_agent|project_agent|note_agent|timeline_agent","instruction":"..."}],'
        '"floating_domain":"tasks|projects|notes|timelines|null","memory_updates":[{"key":"...","value":"..."}]}\n\n'
        "Rules:\n"
        "- tasks must include at least one entry when possible\n"
        "- use floating_domain only when relevant\n"
        "- output JSON only (no markdown, no prose)\n\n"
    )
    planner_input = json.dumps(
        {"message": message, "context": context, "workers": list(WORKER_CONFIG.keys())},
        ensure_ascii=True,
    )
    messages = [
        SystemMessage(content=_FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM),
        HumanMessage(content=f"{contract}Input:\n{planner_input}"),
    ]
    try:
        reply = await llm.ainvoke(messages)
        payload = _extract_json_object(_as_text(reply.content))
        if payload is None:
            raise ValueError("planner returned non-JSON output")
        plan = _coerce_plan(payload, message, floating)
        logger.info(
            "deep_agent: planner produced tasks=%s floating=%s",
            [entry.worker for entry in plan.tasks],
            plan.floating_domain,
        )
        return plan
    except Exception as exc:
        logger.warning("deep_agent: planner failed, using fallback: %s", exc)
        return _fallback_plan(message, floating)
def _extract_entity_ids(tool_results: list[dict[str, Any]]) -> dict[str, list[str]]:
out: dict[str, list[str]] = {
"task": [],
"project": [],
"note": [],
"timeline": [],
}
table_to_tag = {
"tasks": "task",
"projects": "project",
"notes": "note",
"timelines": "timeline",
}
for item in tool_results:
table = item.get("table")
tag = table_to_tag.get(table)
if tag is None:
continue
payload = item.get("data") or {}
rows: list[dict[str, Any]] = []
row = payload.get("row")
if isinstance(row, dict):
rows.append(row)
if isinstance(payload.get("rows"), list):
rows.extend([r for r in payload["rows"] if isinstance(r, dict)])
if isinstance(payload.get("results"), list):
rows.extend([r for r in payload["results"] if isinstance(r, dict)])
for r in rows:
entity_id = r.get("id")
if isinstance(entity_id, str) and entity_id not in out[tag]:
out[tag].append(entity_id)
return out
def _extract_facts(tool_results: list[dict[str, Any]]) -> dict[str, Any]:
"""Extract small, structured facts for the synthesizer to avoid hallucinations."""
facts: dict[str, Any] = {"projects": [], "tasks": [], "notes": [], "timelines": []}
for item in tool_results:
table = item.get("table")
payload = item.get("data") or {}
rows: list[dict[str, Any]] = []
row = payload.get("row")
if isinstance(row, dict):
rows.append(row)
if isinstance(payload.get("rows"), list):
rows.extend([r for r in payload["rows"] if isinstance(r, dict)])
if table == "projects":
for r in rows:
facts["projects"].append(
{
"id": r.get("id"),
"name": r.get("name"),
"status": r.get("status"),
"clientId": r.get("clientId"),
}
)
elif table == "tasks":
for r in rows:
facts["tasks"].append(
{
"id": r.get("id"),
"title": r.get("title"),
"status": r.get("status"),
"projectId": r.get("projectId"),
}
)
elif table == "notes":
for r in rows:
facts["notes"].append(
{
"id": r.get("id"),
"title": r.get("title"),
"projectId": r.get("projectId"),
}
)
elif table == "timelines":
for r in rows:
facts["timelines"].append(
{
"id": r.get("id"),
"title": r.get("title"),
"date": r.get("date"),
"projectId": r.get("projectId"),
}
)
return facts
def _resolved_project_id_from(context: dict[str, Any]) -> str | None:
    """Return ``context["context"]["resolved_project_id"]`` when it is a non-empty string."""
    inner = context.get("context", {}) if isinstance(context, dict) else {}
    if not isinstance(inner, dict):
        return None
    project_id = inner.get("resolved_project_id")
    return project_id if isinstance(project_id, str) and project_id else None


def _mandatory_tool_policy(worker: WorkerName, project_id: str | None) -> str:
    """Build the prompt preamble that forces *worker* to query the resolved project.

    Returns ``""`` when no project id was resolved or the worker has no policy.
    """
    if not project_id:
        return ""
    if worker == "project_agent":
        return (
            "MANDATORY TOOL POLICY:\n"
            f"- You MUST call get_project(project_id=\"{project_id}\") before final answer.\n"
            "- Optionally call list_projects afterward only if needed for disambiguation.\n\n"
        )
    list_tool = {
        "task_agent": "list_tasks",
        "timeline_agent": "list_timelines",
        "note_agent": "list_notes",
    }.get(worker)
    if list_tool is None:
        return ""
    return (
        "MANDATORY TOOL POLICY:\n"
        f"- You MUST call {list_tool}(project_id=\"{project_id}\") before final answer.\n"
        "- Do not use project name as project_id.\n\n"
    )


async def _run_tool_loop(
    worker: WorkerName,
    instruction: str,
    context: dict[str, Any],
) -> tuple[str, list[dict[str, Any]]]:
    """Run one worker's tool-calling loop and return ``(answer, tool_results)``.

    The worker's model, bound to its ``WORKER_CONFIG`` tools, receives the
    worker system prompt plus one user message. That message holds:

    - an optional mandatory-tool preamble, added when
      ``context["context"]["resolved_project_id"]`` is set;
    - the instruction;
    - the JSON context, truncated to 2000 characters.

    The loop runs up to six rounds. A response without tool calls is returned
    as text. Otherwise each requested tool runs in order and its output is fed
    back as a ``ToolMessage``. A tool that raises is reported to the model as
    an error message instead of aborting the worker. If the round limit is
    hit, a structured ``WorkerSummary`` is requested instead.

    ``tool_results`` is the list registered via ``set_tool_result_collector``
    for the duration of the loop; the collector is always cleared afterwards.
    """
    config = WORKER_CONFIG[worker]
    tools = config["tools"]
    llm = get_llm()
    llm_with_tools = llm.bind_tools(tools) if tools else llm
    # Built once: the worker's tool set does not change between rounds.
    tool_map = {t.name: t for t in tools}
    policy = _mandatory_tool_policy(worker, _resolved_project_id_from(context))
    messages: list[Any] = [
        SystemMessage(content=config["prompt"]),
        HumanMessage(
            content=(
                policy
                + "Worker instruction:\n"
                f"{instruction}\n\n"
                "Conversation context:\n"
                f"{json.dumps(context, ensure_ascii=True)[:2000]}"
            )
        ),
    ]
    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        for _ in range(6):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            messages.append(response)
            if not response.tool_calls:
                return _as_text(response.content), collected
            for call in response.tool_calls:
                call_id = str(call.get("id", ""))
                call_name = str(call.get("name", ""))
                call_args = call.get("args", {})
                logger.info(
                    "deep_agent: worker=%s AI->Tool tool_call_id=%s tool=%s args=%s",
                    worker,
                    call_id,
                    call_name,
                    json.dumps(call_args, ensure_ascii=True)[:800],
                )
                tool_fn = tool_map.get(call["name"])
                if tool_fn is None:
                    tool_output = f"Unknown tool: {call['name']}"
                else:
                    try:
                        tool_output = await tool_fn.ainvoke(call_args)
                    except Exception as exc:
                        # Surface the failure to the model rather than failing
                        # the whole worker (and with it the graph run).
                        logger.warning(
                            "deep_agent: worker=%s tool=%s tool_call_id=%s failed: %s",
                            worker,
                            call_name,
                            call_id,
                            exc,
                        )
                        tool_output = f"Tool error ({call_name}): {exc}"
                tool_output_text = str(tool_output)
                logger.info(
                    "deep_agent: worker=%s Tool->AI tool_call_id=%s tool=%s output=%s",
                    worker,
                    call_id,
                    call_name,
                    tool_output_text[:1200],
                )
                messages.append(ToolMessage(content=tool_output_text, tool_call_id=call["id"]))
                logger.info(
                    "deep_agent: worker=%s appended ToolMessage tool_call_id=%s",
                    worker,
                    call_id,
                )
        # Round limit reached: ask for a short structured summary instead.
        # NOTE(review): some chat providers only accept a system message at the
        # start of the conversation — confirm this works with get_llm()'s model.
        structured_llm = llm.with_structured_output(WorkerSummary)
        messages.append(SystemMessage(content="You have finished using tools. Summarize findings in max 3 sentences."))
        final_summary = await structured_llm.ainvoke(messages)
        if isinstance(final_summary, WorkerSummary):
            return final_summary.summary, collected
        return str(final_summary), collected
    finally:
        clear_tool_result_collector()
def _worker_node(worker: WorkerName):
    """Create the graph node function for *worker*.

    The node reads its assignment from ``state["task"]``. If the task targets
    another worker, it contributes an empty ``worker_results`` update.
    Otherwise it runs ``_run_tool_loop`` with the memory and context found in
    its input state. It returns one ``WorkerResult`` under
    ``worker_results``, which ``GraphState`` merges with ``operator.add``.
    """

    async def _node(state: GraphState) -> AggregatedState:
        assignment = state.get("task") or {}
        if assignment.get("worker") != worker:
            return {"worker_results": []}
        instruction = str(assignment.get("instruction") or state.get("user_message") or "")
        logger.info("deep_agent: worker=%s start instruction=%s", worker, instruction[:240])
        worker_context = {
            "memory": state.get("memory_context", {}),
            "context": state.get("context", {}),
        }
        response, tool_results = await _run_tool_loop(worker, instruction, worker_context)
        # Computed once and reused for both the log line and the result.
        entity_ids = _extract_entity_ids(tool_results)
        logger.info(
            "deep_agent: worker=%s complete tool_calls=%d entity_counts=%s",
            worker,
            len(tool_results),
            {tag: len(ids) for tag, ids in entity_ids.items()},
        )
        result: WorkerResult = {
            "worker": worker,
            "instruction": instruction,
            "response": response,
            "entity_ids": entity_ids,
            "facts": _extract_facts(tool_results),
        }
        return {"worker_results": [result]}

    return _node
def _build_synthesis_prompt(state: GraphState, floating: bool) -> str:
worker_results = state.get("worker_results", [])
formatted_results = []
for result in worker_results:
formatted_results.append(
{
"worker": result.get("worker"),
"instruction": result.get("instruction"),
"response": result.get("response"),
"entity_ids": result.get("entity_ids", {}),
"facts": result.get("facts", {}),
}
)
payload = {
"user_message": state.get("user_message", ""),
"memory_context": state.get("memory_context", {}),
"worker_results": formatted_results,
"floating_domain": state.get("floating_domain") if floating else None,
}
return json.dumps(payload, ensure_ascii=True)
async def _stream_with_memory_tool(
    *,
    user_id: str,
    system_prompt: str,
    user_prompt: str,
) -> str:
    """Stream one completion for a system + user prompt and return the joined text.

    Despite the name, no memory tool is bound and ``user_id`` is currently
    unused. Empty chunks are skipped. Streaming makes the chunks visible as
    ``on_chat_model_stream`` events, which ``run_home_stream`` and
    ``run_floating_stream`` forward from the synthesizer node.
    """
    model = get_llm()
    prompt: list[Any] = [SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
    parts: list[str] = []
    async for piece in model.astream(prompt):
        text = _as_text(getattr(piece, "content", ""))
        if text:
            parts.append(text)
    return "".join(parts)
def _synthesizer_node(floating: bool):
    """Create the final graph node, which writes ``final_response`` from the worker results.

    The floating variant uses the concise floating synthesizer prompt; the home
    variant uses the inline-tag home prompt.
    """

    async def _synthesize(state: GraphState) -> GraphState:
        synthesis_input = _build_synthesis_prompt(state, floating=floating)
        text = await _stream_with_memory_tool(
            user_id=str(state.get("user_id", "")),
            system_prompt=_FLOATING_SYNTH_SYSTEM if floating else _HOME_SYNTH_SYSTEM,
            user_prompt=synthesis_input,
        )
        return {"final_response": text}

    return _synthesize
async def _apply_memory_updates(user_id: str, updates: list[MemoryUpdate], current_memory: dict[str, Any]) -> dict[str, Any]:
    """Persist planner memory updates and return the merged memory view.

    With no updates, *current_memory* itself is returned and no DB session is
    opened. Otherwise each update is written in order through
    ``MemoryMiddleware.update_core`` inside a single ``async_session``. The
    result is a shallow copy of *current_memory* with the updated keys.
    NOTE(review): no explicit commit is visible here — presumably
    ``update_core`` commits; confirm.
    """
    if not updates:
        return current_memory
    merged = {**current_memory}
    async with async_session() as session:
        middleware = MemoryMiddleware(session)
        for change in updates:
            await middleware.update_core(user_id, change.key, change.value)
            merged[change.key] = change.value
    return merged
async def _orchestrator_node_home(state: GraphState) -> GraphState:
    """Plan a home request and persist any memory updates the planner proposes.

    This is a no-op (returns ``{}``) when a plan is already present. Project
    status/update requests use the deterministic snapshot plan; their context
    is enriched with ``resolved_project_id`` when it can be resolved. All other
    requests are planned by the LLM, with memory merged over the context.
    Returns the plan (as dicts), the updated memory and the possibly enriched
    context.
    """
    if state.get("plan"):
        return {}
    user_message = str(state.get("user_message", ""))
    memory_context = state.get("memory_context", {})
    planning_context = {**dict(state.get("context", {})), **memory_context}
    # Same resolve-and-inject step the single-agent path uses.
    enriched_context = await _prepare_home_context(user_message, state.get("context", {}))
    if _needs_full_project_snapshot(user_message, floating=False):
        plan = _build_full_project_snapshot_plan(user_message)
    else:
        plan = await _plan_with_llm(user_message, planning_context, floating=False)
    updated_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, memory_context)
    return {
        "plan": [entry.model_dump() for entry in plan.tasks],
        "memory_context": updated_memory,
        "context": enriched_context,
    }
async def _orchestrator_node_floating(state: GraphState) -> GraphState:
    """Plan a floating-context request and persist proposed memory updates.

    This is a no-op (returns ``{}``) when a plan is already present. The
    floating domain comes from the planner, then from the first task's worker,
    and finally defaults to ``"tasks"``.
    """
    if state.get("plan"):
        return {}
    memory_context = state.get("memory_context", {})
    planning_context = {**state.get("context", {}), **memory_context}
    plan = await _plan_with_llm(str(state.get("user_message", "")), planning_context, floating=True)
    domain = plan.floating_domain
    if domain is None and plan.tasks:
        domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
    updated_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, memory_context)
    return {
        "plan": [entry.model_dump() for entry in plan.tasks],
        "floating_domain": domain or "tasks",
        "memory_context": updated_memory,
    }
def _route_workers(state: GraphState) -> list[Send] | str:
    """Fan the orchestrator's plan out to worker nodes, or go straight to synthesis.

    Returns one ``Send`` per plan entry whose ``worker`` is a known worker node.
    Returns ``"synthesizer"`` when the plan is empty or names no known worker.

    A ``Send`` payload is the *entire* input state of the target node; it is
    not merged with the parent graph state. The request fields that workers
    read are therefore copied into every payload alongside the task: user
    message, context (including ``resolved_project_id``) and memory. Without
    them, workers would run with empty context. ``worker_results`` is
    deliberately not forwarded.
    """
    plan = state.get("plan", [])
    if not plan:
        return "synthesizer"
    shared = {
        key: state[key]
        for key in ("user_id", "user_message", "context", "memory_context", "floating_domain")
        if key in state
    }
    sends = [
        Send(task["worker"], {**shared, "task": task})
        for task in plan
        if task.get("worker") in WORKER_CONFIG
    ]
    return sends or "synthesizer"
def _build_graph(*, floating: bool):
    """Compile the orchestrator → workers → synthesizer graph.

    START runs the home or floating orchestrator. ``_route_workers`` then
    dispatches ``Send``s to worker nodes, or routes directly to the
    synthesizer. Every worker node has an edge to the synthesizer, which is
    followed by END.
    """
    graph = StateGraph(GraphState)
    graph.add_node("orchestrator", _orchestrator_node_floating if floating else _orchestrator_node_home)
    for name in WORKER_CONFIG:
        graph.add_node(name, _worker_node(name))
    graph.add_node("synthesizer", _synthesizer_node(floating=floating))

    graph.add_edge(START, "orchestrator")
    # Possible destinations of the router: every worker node plus the synthesizer.
    graph.add_conditional_edges("orchestrator", _route_workers, [*WORKER_CONFIG, "synthesizer"])
    for name in WORKER_CONFIG:
        graph.add_edge(name, "synthesizer")
    graph.add_edge("synthesizer", END)
    return graph.compile()
# Compiled graphs, built once at import time and reused by the run_* entry points.
HOME_GRAPH = _build_graph(floating=False)
FLOATING_GRAPH = _build_graph(floating=True)
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
    """Answer a home-chat message and return the final markdown response.

    With ``HOME_SINGLE_AGENT_TEST_MODE`` on, the single all-tools agent
    handles the request; otherwise the orchestrator graph runs. The request
    *context* is also used as the initial ``memory_context``.
    """
    if HOME_SINGLE_AGENT_TEST_MODE:
        return await _run_home_single_agent(user_id, message, context)
    initial_state = {
        "user_id": user_id,
        "user_message": message,
        "context": context,
        "memory_context": context,
        "worker_results": [],
    }
    final_state = await HOME_GRAPH.ainvoke(initial_state)
    return str(final_state.get("final_response", ""))
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
    """Answer a floating-chat message; return ``(final_response, floating_domain)``.

    Planning and memory updates happen here, before the graph runs. Because
    the graph receives a non-empty plan, its orchestrator node does not plan
    again. When the planner sets no domain, it falls back to the first
    planned worker's domain.
    """
    plan = await _plan_with_llm(message, context, floating=True)
    domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
    memory = await _apply_memory_updates(user_id, plan.memory_updates, context)
    initial_state = {
        "user_id": user_id,
        "user_message": message,
        "context": context,
        "memory_context": memory,
        "plan": [entry.model_dump() for entry in plan.tasks],
        "floating_domain": domain,
        "worker_results": [],
    }
    final_state = await FLOATING_GRAPH.ainvoke(initial_state)
    return str(final_state.get("final_response", "")), str(domain)
async def run_home_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream a home-chat answer as ``("token", text)`` tuples.

    In ``HOME_SINGLE_AGENT_TEST_MODE`` the single-agent stream is relayed
    unchanged. Otherwise the graph runs with ``astream_events`` (v2), and only
    non-empty chat-model chunks emitted by the ``synthesizer`` node are
    forwarded.
    """
    if HOME_SINGLE_AGENT_TEST_MODE:
        async for item in _run_home_single_agent_stream(user_id, message, context):
            yield item
    else:
        graph_input = {
            "user_id": user_id,
            "user_message": message,
            "context": context,
            "memory_context": context,
            "worker_results": [],
        }
        async for event in HOME_GRAPH.astream_events(graph_input, version="v2"):
            if event["event"] != "on_chat_model_stream":
                continue
            if event.get("metadata", {}).get("langgraph_node") != "synthesizer":
                continue
            token = _as_text(getattr(event["data"]["chunk"], "content", ""))
            if token:
                yield "token", token
async def run_floating_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream a floating-chat answer.

    Yields ``("floating_domain", domain)`` first, right after planning and
    before memory updates are applied or the graph runs. It then yields
    ``("token", text)`` for each non-empty chat-model chunk streamed by the
    ``synthesizer`` node.
    """
    plan = await _plan_with_llm(message, context, floating=True)
    domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
    yield "floating_domain", domain
    memory = await _apply_memory_updates(user_id, plan.memory_updates, context)
    graph_input = {
        "user_id": user_id,
        "user_message": message,
        "context": context,
        "memory_context": memory,
        "plan": [entry.model_dump() for entry in plan.tasks],
        "floating_domain": domain,
        "worker_results": [],
    }
    async for event in FLOATING_GRAPH.astream_events(graph_input, version="v2"):
        if event["event"] != "on_chat_model_stream":
            continue
        if event.get("metadata", {}).get("langgraph_node") != "synthesizer":
            continue
        token = _as_text(getattr(event["data"]["chunk"], "content", ""))
        if token:
            yield "token", token