"""Deep orchestrator-worker graphs for home and floating chat contexts.""" from __future__ import annotations import asyncio import json import logging import re from collections.abc import AsyncGenerator, Awaitable, Callable import operator from typing import Annotated, Any, Literal, TypedDict from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.tools import tool from langgraph.constants import END, START from langgraph.graph import StateGraph from langgraph.types import Send from pydantic import BaseModel, Field from app.agents.note_agent import NOTE_SYSTEM_PROMPT, NOTE_TOOLS from app.agents.project_agent import PROJECT_SYSTEM_PROMPT, PROJECT_TOOLS from app.agents.task_agent import TASK_SYSTEM_PROMPT, TASK_TOOLS from app.agents.timeline_agent import TIMELINE_SYSTEM_PROMPT, TIMELINE_TOOLS from app.core.llm import get_llm from app.core.memory_middleware import MemoryMiddleware from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector from app.db import async_session logger = logging.getLogger(__name__) # Quick test switch: home requests run as one agent with all tools. HOME_SINGLE_AGENT_TEST_MODE = True WorkerName = Literal["task_agent", "project_agent", "note_agent", "timeline_agent"] FloatingDomain = Literal["tasks", "projects", "notes", "timelines"] class WorkerTask(BaseModel): worker: WorkerName instruction: str class MemoryUpdate(BaseModel): key: str = Field(description="The memory key to set or update.") value: str = Field(description="The persistent fact or preference value.") class WorkerSummary(BaseModel): summary: str = Field(description="Strictly concise summary of tool findings. Max 3 sentences.") class WorkerPlan(BaseModel): tasks: list[WorkerTask] = Field(default_factory=list) floating_domain: FloatingDomain | None = None memory_updates: list[MemoryUpdate] = Field(default_factory=list, description="Update long-term core memory with persistent user preferences/facts learned from this message.") class WorkerResult(TypedDict): worker: WorkerName instruction: str response: str entity_ids: dict[str, list[str]] facts: dict[str, Any] class OrchestratorState(TypedDict, total=False): user_id: str user_message: str context: dict[str, Any] memory_context: dict[str, Any] plan: list[dict[str, Any]] floating_domain: FloatingDomain task: dict[str, Any] worker_results: list[WorkerResult] final_response: str class GraphState(OrchestratorState): worker_results: Annotated[list[WorkerResult], operator.add] class ReducerState(OrchestratorState): worker_results: list[WorkerResult] class AggregatedState(TypedDict, total=False): worker_results: list[WorkerResult] WORKER_CONFIG: dict[WorkerName, dict[str, Any]] = { "task_agent": { "prompt": TASK_SYSTEM_PROMPT, "tools": TASK_TOOLS, "tag": "task", "table": "tasks", "floating_domain": "tasks", }, "project_agent": { "prompt": PROJECT_SYSTEM_PROMPT, "tools": PROJECT_TOOLS, "tag": "project", "table": "projects", "floating_domain": "projects", }, "note_agent": { "prompt": NOTE_SYSTEM_PROMPT, "tools": NOTE_TOOLS, "tag": "note", "table": "notes", "floating_domain": "notes", }, "timeline_agent": { "prompt": TIMELINE_SYSTEM_PROMPT, "tools": TIMELINE_TOOLS, "tag": "timeline", "table": "timelines", "floating_domain": "timelines", }, } _HOME_ORCHESTRATOR_SYSTEM = ( "You are an orchestrator. Plan which workers should be invoked for the user request. " "Workers: task_agent, project_agent, note_agent, timeline_agent. " "Return JSON only with keys: tasks, floating_domain, memory_updates." ) _FLOATING_ORCHESTRATOR_SYSTEM = ( "You are an orchestrator for floating context. Pick focused workers and set floating_domain " "as one of: tasks, projects, notes, timelines. Return JSON only with keys: tasks, floating_domain, memory_updates." ) _HOME_SYNTH_SYSTEM = ( "You are the final response synthesizer. Return markdown only. " "Embed inline component tags when relevant: [ids], [ids], " "[ids], [ids], and {json}. " "Only include IDs that are truly relevant to the request. " "Never invent missing values. If facts include a non-null clientId for a project, " "do not claim that the project has no owner/client." ) _FLOATING_SYNTH_SYSTEM = ( "You are the final response synthesizer for floating UI context. " "Return concise markdown and stay focused on the requested scope." ) _HOME_SINGLE_AGENT_SYSTEM = ( "You are the home assistant with direct access to all tools: tasks, projects, notes, timelines. " "Always use tools for factual data retrieval before answering. " "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. " "Return markdown and embed inline tags when relevant: [ids], [ids], " "[ids], [ids], {json}." ) def _as_text(content: Any) -> str: if content is None: return "" if isinstance(content, str): return content if isinstance(content, list): parts: list[str] = [] for item in content: if isinstance(item, str): parts.append(item) elif isinstance(item, dict): text = item.get("text") if isinstance(text, str): parts.append(text) return "".join(parts) return str(content) def _fallback_plan(message: str, floating: bool) -> WorkerPlan: lowered = message.lower() tasks: list[WorkerTask] = [] if any(k in lowered for k in ["task", "todo", "deadline", "due"]): tasks.append(WorkerTask(worker="task_agent", instruction=message)) if any(k in lowered for k in ["project", "client", "milestone"]): tasks.append(WorkerTask(worker="project_agent", instruction=message)) if any(k in lowered for k in ["note", "document", "memo"]): tasks.append(WorkerTask(worker="note_agent", instruction=message)) if any(k in lowered for k in ["timeline", "event", "schedule", "release"]): tasks.append(WorkerTask(worker="timeline_agent", instruction=message)) if not tasks: tasks = [WorkerTask(worker="task_agent", instruction=message)] domain: FloatingDomain | None = None if floating: domain = WORKER_CONFIG[tasks[0].worker]["floating_domain"] return WorkerPlan(tasks=tasks, floating_domain=domain) def _extract_json_object(text: str) -> dict[str, Any] | None: """Best-effort extraction of the first JSON object from model output.""" stripped = text.strip() if not stripped: return None # Common case: model returns raw JSON object. try: payload = json.loads(stripped) if isinstance(payload, dict): return payload except json.JSONDecodeError: pass # Fenced JSON block fallback. if "```" in stripped: parts = stripped.split("```") for part in parts: candidate = part.strip() if candidate.startswith("json"): candidate = candidate[4:].strip() try: payload = json.loads(candidate) if isinstance(payload, dict): return payload except json.JSONDecodeError: continue return None def _coerce_plan(payload: dict[str, Any], message: str, floating: bool) -> WorkerPlan: """Normalize loose model JSON into a validated WorkerPlan.""" tasks_raw = payload.get("tasks") tasks: list[WorkerTask] = [] if isinstance(tasks_raw, list): for item in tasks_raw: if not isinstance(item, dict): continue worker = item.get("worker") instruction = item.get("instruction") if isinstance(worker, str) and worker in WORKER_CONFIG and isinstance(instruction, str): tasks.append(WorkerTask(worker=worker, instruction=instruction)) if not tasks: return _fallback_plan(message, floating) domain = payload.get("floating_domain") floating_domain: FloatingDomain | None = None if isinstance(domain, str) and domain in {"tasks", "projects", "notes", "timelines"}: floating_domain = domain # type: ignore[assignment] elif floating: floating_domain = WORKER_CONFIG[tasks[0].worker]["floating_domain"] memory_updates: list[MemoryUpdate] = [] updates_raw = payload.get("memory_updates") if isinstance(updates_raw, list): for item in updates_raw: if isinstance(item, dict): key = item.get("key") value = item.get("value") if isinstance(key, str) and isinstance(value, str) and key and value: memory_updates.append(MemoryUpdate(key=key, value=value)) return WorkerPlan( tasks=tasks, floating_domain=floating_domain, memory_updates=memory_updates, ) def _needs_full_project_snapshot(message: str, floating: bool) -> bool: """Detect project status/update requests that should query all workers.""" if floating: return False lowered = message.lower() has_project = any(k in lowered for k in ["project", "progetto", "progetto", "progetti", "progetto", "whitelist"]) has_status_intent = any(k in lowered for k in ["status", "stato", "aggiorn", "update", "situazione", "riepilogo", "summary"]) return has_project and has_status_intent def _build_full_project_snapshot_plan(message: str) -> WorkerPlan: """Build a deterministic all-workers plan for project status snapshots.""" project_hint = ( "Use context.context.resolved_project_id when present as project_id. " "Do not pass project names as project_id." ) return WorkerPlan( tasks=[ WorkerTask(worker="project_agent", instruction=f"Resolve the target project from this request and return core fields including id, name, status, clientId. {project_hint} Request: {message}"), WorkerTask(worker="task_agent", instruction=f"Collect tasks relevant to the project in this request; include pending/blocked highlights and IDs. {project_hint} Request: {message}"), WorkerTask(worker="timeline_agent", instruction=f"Collect timeline/milestone items relevant to the project in this request; include upcoming items and IDs. {project_hint} Request: {message}"), WorkerTask(worker="note_agent", instruction=f"Collect notes relevant to the project in this request; include latest useful notes and IDs. {project_hint} Request: {message}"), ] ) def _candidate_tokens(message: str) -> list[str]: tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower()) return [t for t in tokens if len(t) >= 3] async def _resolve_project_id_from_message(message: str) -> str | None: """Resolve likely project UUID from user message using client project list.""" try: result = await execute_on_client(action="select", table="projects") except Exception as exc: logger.warning("deep_agent: project resolve select failed: %s", exc) return None rows = result.get("rows", []) if not isinstance(rows, list) or not rows: return None tokens = _candidate_tokens(message) scored: list[tuple[int, dict[str, Any]]] = [] for row in rows: if not isinstance(row, dict): continue name = str(row.get("name", "")).lower() score = sum(1 for token in tokens if token in name) if score > 0: scored.append((score, row)) if not scored: return None scored.sort(key=lambda item: item[0], reverse=True) top_score = scored[0][0] top_rows = [row for score, row in scored if score == top_score] if len(top_rows) != 1: return None project_id = top_rows[0].get("id") return project_id if isinstance(project_id, str) else None async def _prepare_home_context(message: str, context: dict[str, Any]) -> dict[str, Any]: """Resolve and inject project_id hints for home flows.""" prepared = dict(context) if _needs_full_project_snapshot(message, floating=False): resolved_project_id = await _resolve_project_id_from_message(message) if resolved_project_id: prepared["resolved_project_id"] = resolved_project_id logger.info("deep_agent: resolved_project_id=%s for message=%s", resolved_project_id, message[:200]) return prepared def _all_tools() -> list[Any]: tools: list[Any] = [] for config in WORKER_CONFIG.values(): tools.extend(config["tools"]) return tools async def _run_home_single_agent( user_id: str, message: str, context: dict[str, Any], ) -> str: """Single-agent test mode: one loop with all tools.""" prepared_context = await _prepare_home_context(message, context) llm = get_llm() tools = _all_tools() llm_with_tools = llm.bind_tools(tools) messages: list[Any] = [ SystemMessage(content=_HOME_SINGLE_AGENT_SYSTEM), HumanMessage(content=f"User message:\n{message}\n\nContext:\n{json.dumps({'context': prepared_context}, ensure_ascii=True)[:3500]}"), ] for _ in range(6): response: AIMessage = await llm_with_tools.ainvoke(messages) messages.append(response) if not response.tool_calls: return _as_text(response.content) tool_map = {t.name: t for t in tools} for call in response.tool_calls: tool_fn = tool_map.get(call["name"]) if tool_fn is None: tool_output = f"Unknown tool: {call['name']}" else: tool_output = await tool_fn.ainvoke(call.get("args", {})) messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) final = await llm.ainvoke(messages) return _as_text(final.content) async def _run_home_single_agent_stream( user_id: str, message: str, context: dict[str, Any], ) -> AsyncGenerator[tuple[str, Any], None]: """Streaming variant for single-agent home test mode.""" prepared_context = await _prepare_home_context(message, context) llm = get_llm() tools = _all_tools() llm_with_tools = llm.bind_tools(tools) messages: list[Any] = [ SystemMessage(content=_HOME_SINGLE_AGENT_SYSTEM), HumanMessage(content=f"User message:\n{message}\n\nContext:\n{json.dumps({'context': prepared_context}, ensure_ascii=True)[:3500]}"), ] for _ in range(6): response: AIMessage = await llm_with_tools.ainvoke(messages) messages.append(response) if not response.tool_calls: async for chunk in llm.astream(messages): token = _as_text(getattr(chunk, "content", "")) if token: yield "token", token return tool_map = {t.name: t for t in tools} for call in response.tool_calls: tool_fn = tool_map.get(call["name"]) if tool_fn is None: tool_output = f"Unknown tool: {call['name']}" else: tool_output = await tool_fn.ainvoke(call.get("args", {})) messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) async for chunk in llm.astream(messages): token = _as_text(getattr(chunk, "content", "")) if token: yield "token", token async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan: if _needs_full_project_snapshot(message, floating): logger.info("deep_agent: forcing full project snapshot plan for message=%s", message[:200]) return _build_full_project_snapshot_plan(message) llm = get_llm() system = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM prompt_payload = { "message": message, "context": context, "workers": list(WORKER_CONFIG.keys()), } messages = [ SystemMessage(content=system), HumanMessage( content=( "Create a valid JSON object with this exact structure:\n" '{"tasks":[{"worker":"task_agent|project_agent|note_agent|timeline_agent","instruction":"..."}],' '"floating_domain":"tasks|projects|notes|timelines|null","memory_updates":[{"key":"...","value":"..."}]}\n\n' "Rules:\n" "- tasks must include at least one entry when possible\n" "- use floating_domain only when relevant\n" "- output JSON only (no markdown, no prose)\n\n" f"Input:\n{json.dumps(prompt_payload, ensure_ascii=True)}" ) ), ] try: response = await llm.ainvoke(messages) payload = _extract_json_object(_as_text(response.content)) if payload is None: raise ValueError("planner returned non-JSON output") plan = _coerce_plan(payload, message, floating) logger.info( "deep_agent: planner produced tasks=%s floating=%s", [t.worker for t in plan.tasks], plan.floating_domain, ) return plan except Exception as exc: logger.warning("deep_agent: planner failed, using fallback: %s", exc) return _fallback_plan(message, floating) def _extract_entity_ids(tool_results: list[dict[str, Any]]) -> dict[str, list[str]]: out: dict[str, list[str]] = { "task": [], "project": [], "note": [], "timeline": [], } table_to_tag = { "tasks": "task", "projects": "project", "notes": "note", "timelines": "timeline", } for item in tool_results: table = item.get("table") tag = table_to_tag.get(table) if tag is None: continue payload = item.get("data") or {} rows: list[dict[str, Any]] = [] row = payload.get("row") if isinstance(row, dict): rows.append(row) if isinstance(payload.get("rows"), list): rows.extend([r for r in payload["rows"] if isinstance(r, dict)]) if isinstance(payload.get("results"), list): rows.extend([r for r in payload["results"] if isinstance(r, dict)]) for r in rows: entity_id = r.get("id") if isinstance(entity_id, str) and entity_id not in out[tag]: out[tag].append(entity_id) return out def _extract_facts(tool_results: list[dict[str, Any]]) -> dict[str, Any]: """Extract small, structured facts for the synthesizer to avoid hallucinations.""" facts: dict[str, Any] = {"projects": [], "tasks": [], "notes": [], "timelines": []} for item in tool_results: table = item.get("table") payload = item.get("data") or {} rows: list[dict[str, Any]] = [] row = payload.get("row") if isinstance(row, dict): rows.append(row) if isinstance(payload.get("rows"), list): rows.extend([r for r in payload["rows"] if isinstance(r, dict)]) if table == "projects": for r in rows: facts["projects"].append( { "id": r.get("id"), "name": r.get("name"), "status": r.get("status"), "clientId": r.get("clientId"), } ) elif table == "tasks": for r in rows: facts["tasks"].append( { "id": r.get("id"), "title": r.get("title"), "status": r.get("status"), "projectId": r.get("projectId"), } ) elif table == "notes": for r in rows: facts["notes"].append( { "id": r.get("id"), "title": r.get("title"), "projectId": r.get("projectId"), } ) elif table == "timelines": for r in rows: facts["timelines"].append( { "id": r.get("id"), "title": r.get("title"), "date": r.get("date"), "projectId": r.get("projectId"), } ) return facts async def _run_tool_loop( worker: WorkerName, instruction: str, context: dict[str, Any], ) -> tuple[str, list[dict[str, Any]]]: worker_prompt = WORKER_CONFIG[worker]["prompt"] tools = WORKER_CONFIG[worker]["tools"] llm = get_llm() llm_with_tools = llm.bind_tools(tools) if tools else llm resolved_project_id = None ctx = context.get("context", {}) if isinstance(context, dict) else {} if isinstance(ctx, dict): rpid = ctx.get("resolved_project_id") if isinstance(rpid, str) and rpid: resolved_project_id = rpid mandatory_tool_policy = "" if resolved_project_id: if worker == "project_agent": mandatory_tool_policy = ( "MANDATORY TOOL POLICY:\n" f"- You MUST call get_project(project_id=\"{resolved_project_id}\") before final answer.\n" "- Optionally call list_projects afterward only if needed for disambiguation.\n\n" ) elif worker == "task_agent": mandatory_tool_policy = ( "MANDATORY TOOL POLICY:\n" f"- You MUST call list_tasks(project_id=\"{resolved_project_id}\") before final answer.\n" "- Do not use project name as project_id.\n\n" ) elif worker == "timeline_agent": mandatory_tool_policy = ( "MANDATORY TOOL POLICY:\n" f"- You MUST call list_timelines(project_id=\"{resolved_project_id}\") before final answer.\n" "- Do not use project name as project_id.\n\n" ) elif worker == "note_agent": mandatory_tool_policy = ( "MANDATORY TOOL POLICY:\n" f"- You MUST call list_notes(project_id=\"{resolved_project_id}\") before final answer.\n" "- Do not use project name as project_id.\n\n" ) messages: list[Any] = [ SystemMessage(content=worker_prompt), HumanMessage( content=( mandatory_tool_policy + "Worker instruction:\n" f"{instruction}\n\n" "Conversation context:\n" f"{json.dumps(context, ensure_ascii=True)[:2000]}" ) ), ] collected: list[dict[str, Any]] = [] set_tool_result_collector(collected) try: for _ in range(6): response: AIMessage = await llm_with_tools.ainvoke(messages) messages.append(response) if not response.tool_calls: return _as_text(response.content), collected tool_map = {t.name: t for t in tools} for call in response.tool_calls: call_id = str(call.get("id", "")) call_name = str(call.get("name", "")) call_args = call.get("args", {}) logger.info( "deep_agent: worker=%s AI->Tool tool_call_id=%s tool=%s args=%s", worker, call_id, call_name, json.dumps(call_args, ensure_ascii=True)[:800], ) tool_fn = tool_map.get(call["name"]) if tool_fn is None: tool_output = f"Unknown tool: {call['name']}" else: tool_output = await tool_fn.ainvoke(call.get("args", {})) tool_output_text = str(tool_output) logger.info( "deep_agent: worker=%s Tool->AI tool_call_id=%s tool=%s output=%s", worker, call_id, call_name, tool_output_text[:1200], ) messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) logger.info( "deep_agent: worker=%s appended ToolMessage tool_call_id=%s", worker, call_id, ) structured_llm = llm.with_structured_output(WorkerSummary) messages.append(SystemMessage(content="You have finished using tools. Summarize findings in max 3 sentences.")) final_summary = await structured_llm.ainvoke(messages) if isinstance(final_summary, WorkerSummary): return final_summary.summary, collected return str(final_summary), collected finally: clear_tool_result_collector() def _worker_node(worker: WorkerName): async def _node(state: GraphState) -> AggregatedState: task_payload = state.get("task") or {} if task_payload.get("worker") != worker: return {"worker_results": []} instruction = str(task_payload.get("instruction") or state.get("user_message") or "") logger.info("deep_agent: worker=%s start instruction=%s", worker, instruction[:240]) worker_context = { "memory": state.get("memory_context", {}), "context": state.get("context", {}), } response, tool_results = await _run_tool_loop(worker, instruction, worker_context) logger.info( "deep_agent: worker=%s complete tool_calls=%d entity_counts=%s", worker, len(tool_results), {k: len(v) for k, v in _extract_entity_ids(tool_results).items()}, ) return { "worker_results": [ { "worker": worker, "instruction": instruction, "response": response, "entity_ids": _extract_entity_ids(tool_results), "facts": _extract_facts(tool_results), } ] } return _node def _build_synthesis_prompt(state: GraphState, floating: bool) -> str: worker_results = state.get("worker_results", []) formatted_results = [] for result in worker_results: formatted_results.append( { "worker": result.get("worker"), "instruction": result.get("instruction"), "response": result.get("response"), "entity_ids": result.get("entity_ids", {}), "facts": result.get("facts", {}), } ) payload = { "user_message": state.get("user_message", ""), "memory_context": state.get("memory_context", {}), "worker_results": formatted_results, "floating_domain": state.get("floating_domain") if floating else None, } return json.dumps(payload, ensure_ascii=True) async def _stream_with_memory_tool( *, user_id: str, system_prompt: str, user_prompt: str, ) -> str: llm = get_llm() messages: list[Any] = [ SystemMessage(content=system_prompt), HumanMessage(content=user_prompt), ] chunks: list[str] = [] async for chunk in llm.astream(messages): token = _as_text(getattr(chunk, "content", "")) if not token: continue chunks.append(token) return "".join(chunks) def _synthesizer_node(floating: bool): async def _node(state: GraphState) -> GraphState: prompt = _build_synthesis_prompt(state, floating=floating) system_prompt = _FLOATING_SYNTH_SYSTEM if floating else _HOME_SYNTH_SYSTEM final_response = await _stream_with_memory_tool( user_id=str(state.get("user_id", "")), system_prompt=system_prompt, user_prompt=prompt, ) return {"final_response": final_response} return _node async def _apply_memory_updates(user_id: str, updates: list[MemoryUpdate], current_memory: dict[str, Any]) -> dict[str, Any]: if not updates: return current_memory new_memory = dict(current_memory) async with async_session() as db: memory = MemoryMiddleware(db) for update in updates: await memory.update_core(user_id, update.key, update.value) new_memory[update.key] = update.value return new_memory async def _orchestrator_node_home(state: GraphState) -> GraphState: if state.get("plan"): return {} user_message = str(state.get("user_message", "")) base_context = dict(state.get("context", {})) context = {**base_context, **state.get("memory_context", {})} if _needs_full_project_snapshot(user_message, floating=False): resolved_project_id = await _resolve_project_id_from_message(user_message) if resolved_project_id: base_context["resolved_project_id"] = resolved_project_id logger.info("deep_agent: resolved_project_id=%s for message=%s", resolved_project_id, user_message[:200]) plan = _build_full_project_snapshot_plan(user_message) else: plan = await _plan_with_llm(user_message, context, floating=False) new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {})) return { "plan": [task.model_dump() for task in plan.tasks], "memory_context": new_memory, "context": base_context, } async def _orchestrator_node_floating(state: GraphState) -> GraphState: if state.get("plan"): return {} context = {**state.get("context", {}), **state.get("memory_context", {})} plan = await _plan_with_llm(str(state.get("user_message", "")), context, floating=True) floating_domain = plan.floating_domain if floating_domain is None and plan.tasks: floating_domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"] new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {})) return { "plan": [task.model_dump() for task in plan.tasks], "floating_domain": floating_domain or "tasks", "memory_context": new_memory } def _route_workers(state: GraphState) -> list[Send] | str: plan = state.get("plan", []) if not plan: return "synthesizer" sends: list[Send] = [] for task in plan: worker = task.get("worker") if worker in WORKER_CONFIG: sends.append(Send(worker, {"task": task})) return sends or "synthesizer" def _build_graph(*, floating: bool): builder = StateGraph(GraphState) orchestrator_node = _orchestrator_node_floating if floating else _orchestrator_node_home builder.add_node("orchestrator", orchestrator_node) for worker in WORKER_CONFIG: builder.add_node(worker, _worker_node(worker)) builder.add_node("synthesizer", _synthesizer_node(floating=floating)) builder.add_edge(START, "orchestrator") builder.add_conditional_edges( "orchestrator", _route_workers, ["task_agent", "project_agent", "note_agent", "timeline_agent", "synthesizer"], ) for worker in WORKER_CONFIG: builder.add_edge(worker, "synthesizer") builder.add_edge("synthesizer", END) return builder.compile() HOME_GRAPH = _build_graph(floating=False) FLOATING_GRAPH = _build_graph(floating=True) async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str: if HOME_SINGLE_AGENT_TEST_MODE: return await _run_home_single_agent(user_id, message, context) state = await HOME_GRAPH.ainvoke( { "user_id": user_id, "user_message": message, "context": context, "memory_context": context, "worker_results": [], } ) return str(state.get("final_response", "")) async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]: plan = await _plan_with_llm(message, context, floating=True) domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"] new_memory = await _apply_memory_updates(user_id, plan.memory_updates, context) state = await FLOATING_GRAPH.ainvoke( { "user_id": user_id, "user_message": message, "context": context, "memory_context": new_memory, "plan": [task.model_dump() for task in plan.tasks], "floating_domain": domain, "worker_results": [], } ) return str(state.get("final_response", "")), str(domain) async def run_home_stream( user_id: str, message: str, context: dict[str, Any], ) -> AsyncGenerator[tuple[str, Any], None]: if HOME_SINGLE_AGENT_TEST_MODE: async for event in _run_home_single_agent_stream(user_id, message, context): yield event return state_input = { "user_id": user_id, "user_message": message, "context": context, "memory_context": context, "worker_results": [], } async for event in HOME_GRAPH.astream_events(state_input, version="v2"): kind = event["event"] if kind == "on_chat_model_stream": node_name = event.get("metadata", {}).get("langgraph_node") if node_name == "synthesizer": chunk = event["data"]["chunk"] token = _as_text(getattr(chunk, "content", "")) if token: yield "token", token async def run_floating_stream( user_id: str, message: str, context: dict[str, Any], ) -> AsyncGenerator[tuple[str, Any], None]: plan = await _plan_with_llm(message, context, floating=True) domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"] yield "floating_domain", domain new_memory = await _apply_memory_updates(user_id, plan.memory_updates, context) state_input = { "user_id": user_id, "user_message": message, "context": context, "memory_context": new_memory, "plan": [t.model_dump() for t in plan.tasks], "floating_domain": domain, "worker_results": [], } async for event in FLOATING_GRAPH.astream_events(state_input, version="v2"): kind = event["event"] if kind == "on_chat_model_stream": node_name = event.get("metadata", {}).get("langgraph_node") if node_name == "synthesizer": chunk = event["data"]["chunk"] token = _as_text(getattr(chunk, "content", "")) if token: yield "token", token