diff --git a/app/core/deep_agent.py b/app/core/deep_agent.py
index 52f5166..22559a4 100644
--- a/app/core/deep_agent.py
+++ b/app/core/deep_agent.py
@@ -1,26 +1,19 @@
-"""Deep orchestrator-worker graphs for home and floating chat contexts."""
+"""Single-agent runners for home and floating chat contexts."""
from __future__ import annotations
-import asyncio
import json
import logging
import re
-from collections.abc import AsyncGenerator, Awaitable, Callable
-import operator
-from typing import Annotated, Any, Literal, TypedDict
+from collections.abc import AsyncGenerator
+from typing import Any, Literal
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
-from langchain_core.tools import tool
-from langgraph.constants import END, START
-from langgraph.graph import StateGraph
-from langgraph.types import Send
-from pydantic import BaseModel, Field
-from app.agents.note_agent import NOTE_SYSTEM_PROMPT, NOTE_TOOLS
-from app.agents.project_agent import PROJECT_SYSTEM_PROMPT, PROJECT_TOOLS
-from app.agents.task_agent import TASK_SYSTEM_PROMPT, TASK_TOOLS
-from app.agents.timeline_agent import TIMELINE_SYSTEM_PROMPT, TIMELINE_TOOLS
+from app.agents.note_agent import NOTE_TOOLS
+from app.agents.project_agent import PROJECT_TOOLS
+from app.agents.task_agent import TASK_TOOLS
+from app.agents.timeline_agent import TIMELINE_TOOLS
from app.core.llm import get_llm
from app.core.memory_middleware import MemoryMiddleware
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
@@ -28,121 +21,8 @@ from app.db import async_session
logger = logging.getLogger(__name__)
-# Quick test switch: home requests run as one agent with all tools.
-HOME_SINGLE_AGENT_TEST_MODE = True
-
-WorkerName = Literal["task_agent", "project_agent", "note_agent", "timeline_agent"]
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]
-
-class WorkerTask(BaseModel):
- worker: WorkerName
- instruction: str
-
-
-class MemoryUpdate(BaseModel):
- key: str = Field(description="The memory key to set or update.")
- value: str = Field(description="The persistent fact or preference value.")
-
-
-class WorkerSummary(BaseModel):
- summary: str = Field(description="Strictly concise summary of tool findings. Max 3 sentences.")
-
-
-class WorkerPlan(BaseModel):
- tasks: list[WorkerTask] = Field(default_factory=list)
- floating_domain: FloatingDomain | None = None
- memory_updates: list[MemoryUpdate] = Field(default_factory=list, description="Update long-term core memory with persistent user preferences/facts learned from this message.")
-
-
-class WorkerResult(TypedDict):
- worker: WorkerName
- instruction: str
- response: str
- entity_ids: dict[str, list[str]]
- facts: dict[str, Any]
-
-
-class OrchestratorState(TypedDict, total=False):
- user_id: str
- user_message: str
- context: dict[str, Any]
- memory_context: dict[str, Any]
- plan: list[dict[str, Any]]
- floating_domain: FloatingDomain
- task: dict[str, Any]
- worker_results: list[WorkerResult]
- final_response: str
-
-
-class GraphState(OrchestratorState):
- worker_results: Annotated[list[WorkerResult], operator.add]
-
-
-class ReducerState(OrchestratorState):
- worker_results: list[WorkerResult]
-
-
-class AggregatedState(TypedDict, total=False):
- worker_results: list[WorkerResult]
-
-
-WORKER_CONFIG: dict[WorkerName, dict[str, Any]] = {
- "task_agent": {
- "prompt": TASK_SYSTEM_PROMPT,
- "tools": TASK_TOOLS,
- "tag": "task",
- "table": "tasks",
- "floating_domain": "tasks",
- },
- "project_agent": {
- "prompt": PROJECT_SYSTEM_PROMPT,
- "tools": PROJECT_TOOLS,
- "tag": "project",
- "table": "projects",
- "floating_domain": "projects",
- },
- "note_agent": {
- "prompt": NOTE_SYSTEM_PROMPT,
- "tools": NOTE_TOOLS,
- "tag": "note",
- "table": "notes",
- "floating_domain": "notes",
- },
- "timeline_agent": {
- "prompt": TIMELINE_SYSTEM_PROMPT,
- "tools": TIMELINE_TOOLS,
- "tag": "timeline",
- "table": "timelines",
- "floating_domain": "timelines",
- },
-}
-
-_HOME_ORCHESTRATOR_SYSTEM = (
- "You are an orchestrator. Plan which workers should be invoked for the user request. "
- "Workers: task_agent, project_agent, note_agent, timeline_agent. "
- "Return JSON only with keys: tasks, floating_domain, memory_updates."
-)
-
-_FLOATING_ORCHESTRATOR_SYSTEM = (
- "You are an orchestrator for floating context. Pick focused workers and set floating_domain "
- "as one of: tasks, projects, notes, timelines. Return JSON only with keys: tasks, floating_domain, memory_updates."
-)
-
-_HOME_SYNTH_SYSTEM = (
- "You are the final response synthesizer. Return markdown only. "
- "Embed inline component tags when relevant: [ids], [ids], "
- "[ids], [ids], and {json}. "
- "Only include IDs that are truly relevant to the request. "
- "Never invent missing values. If facts include a non-null clientId for a project, "
- "do not claim that the project has no owner/client."
-)
-
-_FLOATING_SYNTH_SYSTEM = (
- "You are the final response synthesizer for floating UI context. "
- "Return concise markdown and stay focused on the requested scope."
-)
-
_HOME_SINGLE_AGENT_SYSTEM = (
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines. "
"Always use tools for factual data retrieval before answering. "
@@ -151,6 +31,15 @@ _HOME_SINGLE_AGENT_SYSTEM = (
"[ids], [ids], {json}."
)
+_FLOATING_SINGLE_AGENT_SYSTEM = (
+ "You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines. "
+ "Stay focused on the floating scope in context.scope and answer concisely. "
+ "Always use tools for factual data retrieval before answering. "
+ "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
+ "Return markdown and embed inline tags when relevant: [ids], [ids], "
+ "[ids], [ids], {json}."
+)
+
def _as_text(content: Any) -> str:
if content is None:
@@ -170,130 +59,9 @@ def _as_text(content: Any) -> str:
return str(content)
-def _fallback_plan(message: str, floating: bool) -> WorkerPlan:
- lowered = message.lower()
- tasks: list[WorkerTask] = []
-
- if any(k in lowered for k in ["task", "todo", "deadline", "due"]):
- tasks.append(WorkerTask(worker="task_agent", instruction=message))
- if any(k in lowered for k in ["project", "client", "milestone"]):
- tasks.append(WorkerTask(worker="project_agent", instruction=message))
- if any(k in lowered for k in ["note", "document", "memo"]):
- tasks.append(WorkerTask(worker="note_agent", instruction=message))
- if any(k in lowered for k in ["timeline", "event", "schedule", "release"]):
- tasks.append(WorkerTask(worker="timeline_agent", instruction=message))
-
- if not tasks:
- tasks = [WorkerTask(worker="task_agent", instruction=message)]
-
- domain: FloatingDomain | None = None
- if floating:
- domain = WORKER_CONFIG[tasks[0].worker]["floating_domain"]
-
- return WorkerPlan(tasks=tasks, floating_domain=domain)
-
-
-def _extract_json_object(text: str) -> dict[str, Any] | None:
- """Best-effort extraction of the first JSON object from model output."""
- stripped = text.strip()
- if not stripped:
- return None
-
- # Common case: model returns raw JSON object.
- try:
- payload = json.loads(stripped)
- if isinstance(payload, dict):
- return payload
- except json.JSONDecodeError:
- pass
-
- # Fenced JSON block fallback.
- if "```" in stripped:
- parts = stripped.split("```")
- for part in parts:
- candidate = part.strip()
- if candidate.startswith("json"):
- candidate = candidate[4:].strip()
- try:
- payload = json.loads(candidate)
- if isinstance(payload, dict):
- return payload
- except json.JSONDecodeError:
- continue
-
- return None
-
-
-def _coerce_plan(payload: dict[str, Any], message: str, floating: bool) -> WorkerPlan:
- """Normalize loose model JSON into a validated WorkerPlan."""
- tasks_raw = payload.get("tasks")
- tasks: list[WorkerTask] = []
-
- if isinstance(tasks_raw, list):
- for item in tasks_raw:
- if not isinstance(item, dict):
- continue
- worker = item.get("worker")
- instruction = item.get("instruction")
- if isinstance(worker, str) and worker in WORKER_CONFIG and isinstance(instruction, str):
- tasks.append(WorkerTask(worker=worker, instruction=instruction))
-
- if not tasks:
- return _fallback_plan(message, floating)
-
- domain = payload.get("floating_domain")
- floating_domain: FloatingDomain | None = None
- if isinstance(domain, str) and domain in {"tasks", "projects", "notes", "timelines"}:
- floating_domain = domain # type: ignore[assignment]
- elif floating:
- floating_domain = WORKER_CONFIG[tasks[0].worker]["floating_domain"]
-
- memory_updates: list[MemoryUpdate] = []
- updates_raw = payload.get("memory_updates")
- if isinstance(updates_raw, list):
- for item in updates_raw:
- if isinstance(item, dict):
- key = item.get("key")
- value = item.get("value")
- if isinstance(key, str) and isinstance(value, str) and key and value:
- memory_updates.append(MemoryUpdate(key=key, value=value))
-
- return WorkerPlan(
- tasks=tasks,
- floating_domain=floating_domain,
- memory_updates=memory_updates,
- )
-
-
-def _needs_full_project_snapshot(message: str, floating: bool) -> bool:
- """Detect project status/update requests that should query all workers."""
- if floating:
- return False
- lowered = message.lower()
- has_project = any(k in lowered for k in ["project", "progetto", "progetto", "progetti", "progetto", "whitelist"])
- has_status_intent = any(k in lowered for k in ["status", "stato", "aggiorn", "update", "situazione", "riepilogo", "summary"])
- return has_project and has_status_intent
-
-
-def _build_full_project_snapshot_plan(message: str) -> WorkerPlan:
- """Build a deterministic all-workers plan for project status snapshots."""
- project_hint = (
- "Use context.context.resolved_project_id when present as project_id. "
- "Do not pass project names as project_id."
- )
- return WorkerPlan(
- tasks=[
- WorkerTask(worker="project_agent", instruction=f"Resolve the target project from this request and return core fields including id, name, status, clientId. {project_hint} Request: {message}"),
- WorkerTask(worker="task_agent", instruction=f"Collect tasks relevant to the project in this request; include pending/blocked highlights and IDs. {project_hint} Request: {message}"),
- WorkerTask(worker="timeline_agent", instruction=f"Collect timeline/milestone items relevant to the project in this request; include upcoming items and IDs. {project_hint} Request: {message}"),
- WorkerTask(worker="note_agent", instruction=f"Collect notes relevant to the project in this request; include latest useful notes and IDs. {project_hint} Request: {message}"),
- ]
- )
-
-
def _candidate_tokens(message: str) -> list[str]:
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
- return [t for t in tokens if len(t) >= 3]
+ return [token for token in tokens if len(token) >= 3]
async def _resolve_project_id_from_message(message: str) -> str | None:
@@ -331,297 +99,64 @@ async def _resolve_project_id_from_message(message: str) -> str | None:
return project_id if isinstance(project_id, str) else None
-async def _prepare_home_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
- """Resolve and inject project_id hints for home flows."""
+def _needs_project_resolution(message: str) -> bool:
+ lowered = message.lower()
+ return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
+
+
+async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
prepared = dict(context)
- if _needs_full_project_snapshot(message, floating=False):
+ if _needs_project_resolution(message):
resolved_project_id = await _resolve_project_id_from_message(message)
if resolved_project_id:
prepared["resolved_project_id"] = resolved_project_id
- logger.info("deep_agent: resolved_project_id=%s for message=%s", resolved_project_id, message[:200])
+ logger.info("deep_agent: resolved_project_id=%s", resolved_project_id)
return prepared
def _all_tools() -> list[Any]:
- tools: list[Any] = []
- for config in WORKER_CONFIG.values():
- tools.extend(config["tools"])
- return tools
+ return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
-async def _run_home_single_agent(
- user_id: str,
+def _infer_floating_domain(message: str, context: dict[str, Any]) -> FloatingDomain:
+ scope = context.get("scope") if isinstance(context, dict) else None
+ if isinstance(scope, dict):
+ scope_type = str(scope.get("type") or "").strip().lower()
+ if scope_type in {"task", "tasks"}:
+ return "tasks"
+ if scope_type in {"project", "projects"}:
+ return "projects"
+ if scope_type in {"note", "notes"}:
+ return "notes"
+ if scope_type in {"timeline", "timelines"}:
+ return "timelines"
+
+ lowered = message.lower()
+ if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
+ return "timelines"
+ if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
+ return "notes"
+ if any(keyword in lowered for keyword in ["project", "progetto", "client"]):
+ return "projects"
+ return "tasks"
+
+
+async def _run_single_agent(
+ *,
+ system_prompt: str,
message: str,
context: dict[str, Any],
+ max_steps: int = 6,
) -> str:
- """Single-agent test mode: one loop with all tools."""
- prepared_context = await _prepare_home_context(message, context)
-
llm = get_llm()
tools = _all_tools()
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
- SystemMessage(content=_HOME_SINGLE_AGENT_SYSTEM),
- HumanMessage(content=f"User message:\n{message}\n\nContext:\n{json.dumps({'context': prepared_context}, ensure_ascii=True)[:3500]}"),
- ]
-
- for _ in range(6):
- response: AIMessage = await llm_with_tools.ainvoke(messages)
- messages.append(response)
- if not response.tool_calls:
- return _as_text(response.content)
-
- tool_map = {t.name: t for t in tools}
- for call in response.tool_calls:
- tool_fn = tool_map.get(call["name"])
- if tool_fn is None:
- tool_output = f"Unknown tool: {call['name']}"
- else:
- tool_output = await tool_fn.ainvoke(call.get("args", {}))
- messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
-
- final = await llm.ainvoke(messages)
- return _as_text(final.content)
-
-
-async def _run_home_single_agent_stream(
- user_id: str,
- message: str,
- context: dict[str, Any],
-) -> AsyncGenerator[tuple[str, Any], None]:
- """Streaming variant for single-agent home test mode."""
- prepared_context = await _prepare_home_context(message, context)
-
- llm = get_llm()
- tools = _all_tools()
- llm_with_tools = llm.bind_tools(tools)
- messages: list[Any] = [
- SystemMessage(content=_HOME_SINGLE_AGENT_SYSTEM),
- HumanMessage(content=f"User message:\n{message}\n\nContext:\n{json.dumps({'context': prepared_context}, ensure_ascii=True)[:3500]}"),
- ]
-
- for _ in range(6):
- response: AIMessage = await llm_with_tools.ainvoke(messages)
- messages.append(response)
- if not response.tool_calls:
- async for chunk in llm.astream(messages):
- token = _as_text(getattr(chunk, "content", ""))
- if token:
- yield "token", token
- return
-
- tool_map = {t.name: t for t in tools}
- for call in response.tool_calls:
- tool_fn = tool_map.get(call["name"])
- if tool_fn is None:
- tool_output = f"Unknown tool: {call['name']}"
- else:
- tool_output = await tool_fn.ainvoke(call.get("args", {}))
- messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
-
- async for chunk in llm.astream(messages):
- token = _as_text(getattr(chunk, "content", ""))
- if token:
- yield "token", token
-
-
-async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan:
- if _needs_full_project_snapshot(message, floating):
- logger.info("deep_agent: forcing full project snapshot plan for message=%s", message[:200])
- return _build_full_project_snapshot_plan(message)
-
- llm = get_llm()
- system = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM
-
- prompt_payload = {
- "message": message,
- "context": context,
- "workers": list(WORKER_CONFIG.keys()),
- }
- messages = [
- SystemMessage(content=system),
+ SystemMessage(content=system_prompt),
HumanMessage(
content=(
- "Create a valid JSON object with this exact structure:\n"
- '{"tasks":[{"worker":"task_agent|project_agent|note_agent|timeline_agent","instruction":"..."}],'
- '"floating_domain":"tasks|projects|notes|timelines|null","memory_updates":[{"key":"...","value":"..."}]}\n\n'
- "Rules:\n"
- "- tasks must include at least one entry when possible\n"
- "- use floating_domain only when relevant\n"
- "- output JSON only (no markdown, no prose)\n\n"
- f"Input:\n{json.dumps(prompt_payload, ensure_ascii=True)}"
- )
- ),
- ]
-
- try:
- response = await llm.ainvoke(messages)
- payload = _extract_json_object(_as_text(response.content))
- if payload is None:
- raise ValueError("planner returned non-JSON output")
- plan = _coerce_plan(payload, message, floating)
- logger.info(
- "deep_agent: planner produced tasks=%s floating=%s",
- [t.worker for t in plan.tasks],
- plan.floating_domain,
- )
- return plan
- except Exception as exc:
- logger.warning("deep_agent: planner failed, using fallback: %s", exc)
-
- return _fallback_plan(message, floating)
-
-
-def _extract_entity_ids(tool_results: list[dict[str, Any]]) -> dict[str, list[str]]:
- out: dict[str, list[str]] = {
- "task": [],
- "project": [],
- "note": [],
- "timeline": [],
- }
- table_to_tag = {
- "tasks": "task",
- "projects": "project",
- "notes": "note",
- "timelines": "timeline",
- }
-
- for item in tool_results:
- table = item.get("table")
- tag = table_to_tag.get(table)
- if tag is None:
- continue
-
- payload = item.get("data") or {}
- rows: list[dict[str, Any]] = []
- row = payload.get("row")
- if isinstance(row, dict):
- rows.append(row)
- if isinstance(payload.get("rows"), list):
- rows.extend([r for r in payload["rows"] if isinstance(r, dict)])
- if isinstance(payload.get("results"), list):
- rows.extend([r for r in payload["results"] if isinstance(r, dict)])
-
- for r in rows:
- entity_id = r.get("id")
- if isinstance(entity_id, str) and entity_id not in out[tag]:
- out[tag].append(entity_id)
-
- return out
-
-
-def _extract_facts(tool_results: list[dict[str, Any]]) -> dict[str, Any]:
- """Extract small, structured facts for the synthesizer to avoid hallucinations."""
- facts: dict[str, Any] = {"projects": [], "tasks": [], "notes": [], "timelines": []}
-
- for item in tool_results:
- table = item.get("table")
- payload = item.get("data") or {}
-
- rows: list[dict[str, Any]] = []
- row = payload.get("row")
- if isinstance(row, dict):
- rows.append(row)
- if isinstance(payload.get("rows"), list):
- rows.extend([r for r in payload["rows"] if isinstance(r, dict)])
-
- if table == "projects":
- for r in rows:
- facts["projects"].append(
- {
- "id": r.get("id"),
- "name": r.get("name"),
- "status": r.get("status"),
- "clientId": r.get("clientId"),
- }
- )
- elif table == "tasks":
- for r in rows:
- facts["tasks"].append(
- {
- "id": r.get("id"),
- "title": r.get("title"),
- "status": r.get("status"),
- "projectId": r.get("projectId"),
- }
- )
- elif table == "notes":
- for r in rows:
- facts["notes"].append(
- {
- "id": r.get("id"),
- "title": r.get("title"),
- "projectId": r.get("projectId"),
- }
- )
- elif table == "timelines":
- for r in rows:
- facts["timelines"].append(
- {
- "id": r.get("id"),
- "title": r.get("title"),
- "date": r.get("date"),
- "projectId": r.get("projectId"),
- }
- )
-
- return facts
-
-
-async def _run_tool_loop(
- worker: WorkerName,
- instruction: str,
- context: dict[str, Any],
-) -> tuple[str, list[dict[str, Any]]]:
- worker_prompt = WORKER_CONFIG[worker]["prompt"]
- tools = WORKER_CONFIG[worker]["tools"]
-
- llm = get_llm()
- llm_with_tools = llm.bind_tools(tools) if tools else llm
-
- resolved_project_id = None
- ctx = context.get("context", {}) if isinstance(context, dict) else {}
- if isinstance(ctx, dict):
- rpid = ctx.get("resolved_project_id")
- if isinstance(rpid, str) and rpid:
- resolved_project_id = rpid
-
- mandatory_tool_policy = ""
- if resolved_project_id:
- if worker == "project_agent":
- mandatory_tool_policy = (
- "MANDATORY TOOL POLICY:\n"
- f"- You MUST call get_project(project_id=\"{resolved_project_id}\") before final answer.\n"
- "- Optionally call list_projects afterward only if needed for disambiguation.\n\n"
- )
- elif worker == "task_agent":
- mandatory_tool_policy = (
- "MANDATORY TOOL POLICY:\n"
- f"- You MUST call list_tasks(project_id=\"{resolved_project_id}\") before final answer.\n"
- "- Do not use project name as project_id.\n\n"
- )
- elif worker == "timeline_agent":
- mandatory_tool_policy = (
- "MANDATORY TOOL POLICY:\n"
- f"- You MUST call list_timelines(project_id=\"{resolved_project_id}\") before final answer.\n"
- "- Do not use project name as project_id.\n\n"
- )
- elif worker == "note_agent":
- mandatory_tool_policy = (
- "MANDATORY TOOL POLICY:\n"
- f"- You MUST call list_notes(project_id=\"{resolved_project_id}\") before final answer.\n"
- "- Do not use project name as project_id.\n\n"
- )
-
- messages: list[Any] = [
- SystemMessage(content=worker_prompt),
- HumanMessage(
- content=(
- mandatory_tool_policy +
- "Worker instruction:\n"
- f"{instruction}\n\n"
- "Conversation context:\n"
- f"{json.dumps(context, ensure_ascii=True)[:2000]}"
+ f"User message:\n{message}\n\n"
+ f"Context:\n{json.dumps({'context': context}, ensure_ascii=True)[:3500]}"
)
),
]
@@ -629,284 +164,133 @@ async def _run_tool_loop(
collected: list[dict[str, Any]] = []
set_tool_result_collector(collected)
try:
- for _ in range(6):
+ for _ in range(max_steps):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
- return _as_text(response.content), collected
+ return _as_text(response.content)
- tool_map = {t.name: t for t in tools}
+ tool_map = {tool_def.name: tool_def for tool_def in tools}
for call in response.tool_calls:
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
- "deep_agent: worker=%s AI->Tool tool_call_id=%s tool=%s args=%s",
- worker,
+ "deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
call_id,
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
- tool_fn = tool_map.get(call["name"])
+ tool_fn = tool_map.get(call_name)
if tool_fn is None:
- tool_output = f"Unknown tool: {call['name']}"
+ tool_output = f"Unknown tool: {call_name}"
else:
- tool_output = await tool_fn.ainvoke(call.get("args", {}))
+ tool_output = await tool_fn.ainvoke(call_args)
- tool_output_text = str(tool_output)
logger.info(
- "deep_agent: worker=%s Tool->AI tool_call_id=%s tool=%s output=%s",
- worker,
+ "deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
call_id,
call_name,
- tool_output_text[:1200],
+ str(tool_output)[:1200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
- logger.info(
- "deep_agent: worker=%s appended ToolMessage tool_call_id=%s",
- worker,
- call_id,
- )
- structured_llm = llm.with_structured_output(WorkerSummary)
- messages.append(SystemMessage(content="You have finished using tools. Summarize findings in max 3 sentences."))
- final_summary = await structured_llm.ainvoke(messages)
-
- if isinstance(final_summary, WorkerSummary):
- return final_summary.summary, collected
- return str(final_summary), collected
+ final = await llm.ainvoke(messages)
+ return _as_text(final.content)
finally:
clear_tool_result_collector()
-def _worker_node(worker: WorkerName):
- async def _node(state: GraphState) -> AggregatedState:
- task_payload = state.get("task") or {}
- if task_payload.get("worker") != worker:
- return {"worker_results": []}
-
- instruction = str(task_payload.get("instruction") or state.get("user_message") or "")
- logger.info("deep_agent: worker=%s start instruction=%s", worker, instruction[:240])
- worker_context = {
- "memory": state.get("memory_context", {}),
- "context": state.get("context", {}),
- }
- response, tool_results = await _run_tool_loop(worker, instruction, worker_context)
- logger.info(
- "deep_agent: worker=%s complete tool_calls=%d entity_counts=%s",
- worker,
- len(tool_results),
- {k: len(v) for k, v in _extract_entity_ids(tool_results).items()},
- )
-
- return {
- "worker_results": [
- {
- "worker": worker,
- "instruction": instruction,
- "response": response,
- "entity_ids": _extract_entity_ids(tool_results),
- "facts": _extract_facts(tool_results),
- }
- ]
- }
-
- return _node
-
-
-def _build_synthesis_prompt(state: GraphState, floating: bool) -> str:
- worker_results = state.get("worker_results", [])
- formatted_results = []
- for result in worker_results:
- formatted_results.append(
- {
- "worker": result.get("worker"),
- "instruction": result.get("instruction"),
- "response": result.get("response"),
- "entity_ids": result.get("entity_ids", {}),
- "facts": result.get("facts", {}),
- }
- )
-
- payload = {
- "user_message": state.get("user_message", ""),
- "memory_context": state.get("memory_context", {}),
- "worker_results": formatted_results,
- "floating_domain": state.get("floating_domain") if floating else None,
- }
- return json.dumps(payload, ensure_ascii=True)
-
-
-async def _stream_with_memory_tool(
+async def _run_single_agent_stream(
*,
- user_id: str,
system_prompt: str,
- user_prompt: str,
-) -> str:
+ message: str,
+ context: dict[str, Any],
+ max_steps: int = 6,
+) -> AsyncGenerator[tuple[str, Any], None]:
llm = get_llm()
+ tools = _all_tools()
+ llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
- HumanMessage(content=user_prompt),
+ HumanMessage(
+ content=(
+ f"User message:\n{message}\n\n"
+ f"Context:\n{json.dumps({'context': context}, ensure_ascii=True)[:3500]}"
+ )
+ ),
]
- chunks: list[str] = []
- async for chunk in llm.astream(messages):
- token = _as_text(getattr(chunk, "content", ""))
- if not token:
- continue
- chunks.append(token)
+ collected: list[dict[str, Any]] = []
+ set_tool_result_collector(collected)
+ try:
+ for _ in range(max_steps):
+ response: AIMessage = await llm_with_tools.ainvoke(messages)
+ messages.append(response)
- return "".join(chunks)
+ if not response.tool_calls:
+ async for chunk in llm.astream(messages):
+ token = _as_text(getattr(chunk, "content", ""))
+ if token:
+ yield "token", token
+ return
+ tool_map = {tool_def.name: tool_def for tool_def in tools}
+ for call in response.tool_calls:
+ call_id = str(call.get("id", ""))
+ call_name = str(call.get("name", ""))
+ call_args = call.get("args", {})
+ logger.info(
+ "deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
+ call_id,
+ call_name,
+ json.dumps(call_args, ensure_ascii=True)[:800],
+ )
-def _synthesizer_node(floating: bool):
- async def _node(state: GraphState) -> GraphState:
- prompt = _build_synthesis_prompt(state, floating=floating)
- system_prompt = _FLOATING_SYNTH_SYSTEM if floating else _HOME_SYNTH_SYSTEM
+ tool_fn = tool_map.get(call_name)
+ if tool_fn is None:
+ tool_output = f"Unknown tool: {call_name}"
+ else:
+ tool_output = await tool_fn.ainvoke(call_args)
- final_response = await _stream_with_memory_tool(
- user_id=str(state.get("user_id", "")),
- system_prompt=system_prompt,
- user_prompt=prompt,
- )
+ logger.info(
+ "deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
+ call_id,
+ call_name,
+ str(tool_output)[:1200],
+ )
- return {"final_response": final_response}
+ messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
- return _node
-
-
-async def _apply_memory_updates(user_id: str, updates: list[MemoryUpdate], current_memory: dict[str, Any]) -> dict[str, Any]:
- if not updates:
- return current_memory
-
- new_memory = dict(current_memory)
- async with async_session() as db:
- memory = MemoryMiddleware(db)
- for update in updates:
- await memory.update_core(user_id, update.key, update.value)
- new_memory[update.key] = update.value
- return new_memory
-
-async def _orchestrator_node_home(state: GraphState) -> GraphState:
- if state.get("plan"):
- return {}
-
- user_message = str(state.get("user_message", ""))
- base_context = dict(state.get("context", {}))
- context = {**base_context, **state.get("memory_context", {})}
-
- if _needs_full_project_snapshot(user_message, floating=False):
- resolved_project_id = await _resolve_project_id_from_message(user_message)
- if resolved_project_id:
- base_context["resolved_project_id"] = resolved_project_id
- logger.info("deep_agent: resolved_project_id=%s for message=%s", resolved_project_id, user_message[:200])
- plan = _build_full_project_snapshot_plan(user_message)
- else:
- plan = await _plan_with_llm(user_message, context, floating=False)
-
- new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {}))
-
- return {
- "plan": [task.model_dump() for task in plan.tasks],
- "memory_context": new_memory,
- "context": base_context,
- }
-
-
-async def _orchestrator_node_floating(state: GraphState) -> GraphState:
- if state.get("plan"):
- return {}
-
- context = {**state.get("context", {}), **state.get("memory_context", {})}
- plan = await _plan_with_llm(str(state.get("user_message", "")), context, floating=True)
- floating_domain = plan.floating_domain
- if floating_domain is None and plan.tasks:
- floating_domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
-
- new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {}))
-
- return {
- "plan": [task.model_dump() for task in plan.tasks],
- "floating_domain": floating_domain or "tasks",
- "memory_context": new_memory
- }
-
-
-def _route_workers(state: GraphState) -> list[Send] | str:
- plan = state.get("plan", [])
- if not plan:
- return "synthesizer"
-
- sends: list[Send] = []
- for task in plan:
- worker = task.get("worker")
- if worker in WORKER_CONFIG:
- sends.append(Send(worker, {"task": task}))
-
- return sends or "synthesizer"
-
-
-def _build_graph(*, floating: bool):
- builder = StateGraph(GraphState)
-
- orchestrator_node = _orchestrator_node_floating if floating else _orchestrator_node_home
- builder.add_node("orchestrator", orchestrator_node)
- for worker in WORKER_CONFIG:
- builder.add_node(worker, _worker_node(worker))
- builder.add_node("synthesizer", _synthesizer_node(floating=floating))
-
- builder.add_edge(START, "orchestrator")
- builder.add_conditional_edges(
- "orchestrator",
- _route_workers,
- ["task_agent", "project_agent", "note_agent", "timeline_agent", "synthesizer"],
- )
- for worker in WORKER_CONFIG:
- builder.add_edge(worker, "synthesizer")
- builder.add_edge("synthesizer", END)
-
- return builder.compile()
-
-
-HOME_GRAPH = _build_graph(floating=False)
-FLOATING_GRAPH = _build_graph(floating=True)
+ async for chunk in llm.astream(messages):
+ token = _as_text(getattr(chunk, "content", ""))
+ if token:
+ yield "token", token
+ finally:
+ clear_tool_result_collector()
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
- if HOME_SINGLE_AGENT_TEST_MODE:
- return await _run_home_single_agent(user_id, message, context)
-
- state = await HOME_GRAPH.ainvoke(
- {
- "user_id": user_id,
- "user_message": message,
- "context": context,
- "memory_context": context,
- "worker_results": [],
- }
+ prepared_context = await _prepare_context(message, context)
+ return await _run_single_agent(
+ system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
+ message=message,
+ context=prepared_context,
)
- return str(state.get("final_response", ""))
+
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
-    plan = await _plan_with_llm(message, context, floating=True)
-    domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
-    new_memory = await _apply_memory_updates(user_id, plan.memory_updates, context)
-
-    state = await FLOATING_GRAPH.ainvoke(
-        {
-            "user_id": user_id,
-            "user_message": message,
-            "context": context,
-            "memory_context": new_memory,
-            "plan": [task.model_dump() for task in plan.tasks],
-            "floating_domain": domain,
-            "worker_results": [],
-        }
+    """Answer a floating-chat message and return ``(reply, domain)``.
+
+    The domain is taken from ``_infer_floating_domain(message, context)``
+    before the context is prepared; the reply comes from ``_run_single_agent``
+    using the floating system prompt. ``user_id`` is part of the signature but
+    is not read by this function.
+    """
+    floating_domain = _infer_floating_domain(message, context)
+    agent_context = await _prepare_context(message, context)
+    reply = await _run_single_agent(
+        system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
+        message=message,
+        context=agent_context,
     )
-    return str(state.get("final_response", "")), str(domain)
+    return reply, floating_domain
async def run_home_stream(
@@ -914,60 +298,34 @@ async def run_home_stream(
     message: str,
     context: dict[str, Any],
 ) -> AsyncGenerator[tuple[str, Any], None]:
-    if HOME_SINGLE_AGENT_TEST_MODE:
-        async for event in _run_home_single_agent_stream(user_id, message, context):
-            yield event
-        return
+    """Stream home-chat events produced by the single agent.
+
+    The context is passed through ``_prepare_context`` first; every event
+    yielded by ``_run_single_agent_stream`` is then re-yielded unchanged.
+    """
+    agent_context = await _prepare_context(message, context)
+    event_stream = _run_single_agent_stream(
+        system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
+        message=message,
+        context=agent_context,
+    )
+    async for item in event_stream:
+        yield item
-    state_input = {
-        "user_id": user_id,
-        "user_message": message,
-        "context": context,
-        "memory_context": context,
-        "worker_results": [],
-    }
-
-    async for event in HOME_GRAPH.astream_events(state_input, version="v2"):
-        kind = event["event"]
-
-        if kind == "on_chat_model_stream":
-            node_name = event.get("metadata", {}).get("langgraph_node")
-
-            if node_name == "synthesizer":
-                chunk = event["data"]["chunk"]
-                token = _as_text(getattr(chunk, "content", ""))
-                if token:
-                    yield "token", token
async def run_floating_stream(
     user_id: str,
     message: str,
     context: dict[str, Any],
 ) -> AsyncGenerator[tuple[str, Any], None]:
-    plan = await _plan_with_llm(message, context, floating=True)
-    domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
+    """Stream floating-chat events, starting with the inferred domain.
+
+    The first event is ``("floating_domain", domain)``, emitted before the
+    context is prepared. After that, every event yielded by
+    ``_run_single_agent_stream`` is re-yielded unchanged. ``user_id`` is part
+    of the signature but is not read by this function.
+    """
+    domain = _infer_floating_domain(message, context)
     yield "floating_domain", domain
-    new_memory = await _apply_memory_updates(user_id, plan.memory_updates, context)
+    agent_context = await _prepare_context(message, context)
+    event_stream = _run_single_agent_stream(
+        system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
+        message=message,
+        context=agent_context,
+    )
+    async for item in event_stream:
+        yield item
-    state_input = {
-        "user_id": user_id,
-        "user_message": message,
-        "context": context,
-        "memory_context": new_memory,
-        "plan": [t.model_dump() for t in plan.tasks],
-        "floating_domain": domain,
-        "worker_results": [],
-    }
-    async for event in FLOATING_GRAPH.astream_events(state_input, version="v2"):
-        kind = event["event"]
-
-        if kind == "on_chat_model_stream":
-            node_name = event.get("metadata", {}).get("langgraph_node")
-
-            if node_name == "synthesizer":
-                chunk = event["data"]["chunk"]
-                token = _as_text(getattr(chunk, "content", ""))
-                if token:
-                    yield "token", token
+async def update_core_memory(user_id: str, key: str, value: str) -> None:
+    """Forward a core-memory update for ``user_id`` to the memory middleware.
+
+    Kept as a compatibility helper for callers that expect an explicit
+    memory-update API. A session is opened with ``async_session`` and the
+    arguments are passed straight to ``MemoryMiddleware.update_core``.
+    """
+    async with async_session() as session:
+        await MemoryMiddleware(session).update_core(user_id, key, value)
diff --git a/requirements.txt b/requirements.txt
index 8202519..ea10f59 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,6 @@ langchain>=0.3.0
langchain-openai>=0.3.0
langchain-litellm>=0.1.0
litellm>=1.50.0
-langgraph>=0.4.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
python-jose[cryptography]>=3.3.0
diff --git a/tests/test_deep_agent.py b/tests/test_deep_agent.py
new file mode 100644
index 0000000..deddfa3
--- /dev/null
+++ b/tests/test_deep_agent.py
@@ -0,0 +1,81 @@
+"""Unit tests for single-agent deep_agent flows with mocked tool results."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+from unittest.mock import patch
+
+import pytest
+from langchain_core.messages import AIMessage, ToolMessage
+
+from app.core.deep_agent import run_floating_stream, run_home
+
+
+class _FakeTool:
+    """Tool double named ``list_tasks`` that returns one canned task row."""
+
+    name = "list_tasks"
+
+    async def ainvoke(self, args):
+        """Return the canned rows and echo back the arguments received."""
+        canned_rows = [{"id": "task-1", "title": "Mock Task"}]
+        return {"rows": canned_rows, "echo": args}
+
+
+class _FakeLLM:
+    """Scripted chat-model double used by the tests in this module.
+
+    The first ``ainvoke`` call requests the ``list_tasks`` tool; any other
+    call answers with the content of the most recent ``ToolMessage`` found in
+    the conversation. ``astream`` always yields the chunks ``"stream-"`` and
+    ``"ok"``.
+    """
+
+    def __init__(self) -> None:
+        # Count of ainvoke calls so far; it selects which scripted reply is sent.
+        self.calls = 0
+
+    def bind_tools(self, _tools):
+        """Ignore the supplied tools and return this same instance."""
+        return self
+
+    async def ainvoke(self, messages):
+        """Return a tool call on the first invocation, a final answer otherwise."""
+        self.calls += 1
+        if self.calls != 1:
+            tool_replies = [msg for msg in messages if isinstance(msg, ToolMessage)]
+            assert tool_replies, "Expected at least one tool message"
+            return AIMessage(content=f"Final answer from mocked tool: {tool_replies[-1].content}")
+
+        tool_call = {
+            "id": "call-1",
+            "name": "list_tasks",
+            "args": {"project_id": "proj-1"},
+        }
+        return AIMessage(content="", tool_calls=[tool_call])
+
+    async def astream(self, _messages):
+        """Yield two fixed content chunks regardless of the input messages."""
+        for piece in ("stream-", "ok"):
+            yield SimpleNamespace(content=piece)
+
+
+@pytest.mark.asyncio
+async def test_run_home_uses_mocked_tool_result():
+    """Check that run_home's final answer embeds the mocked tool output."""
+    llm_patch = patch("app.core.deep_agent.get_llm", return_value=_FakeLLM())
+    tools_patch = patch("app.core.deep_agent._all_tools", return_value=[_FakeTool()])
+
+    with llm_patch, tools_patch:
+        answer = await run_home("user-1", "list my tasks", {})
+
+    for expected in ("Final answer from mocked tool", "Mock Task"):
+        assert expected in answer
+
+
+@pytest.mark.asyncio
+async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
+    """Check the floating stream opens with the domain event and carries both tokens."""
+    llm_patch = patch("app.core.deep_agent.get_llm", return_value=_FakeLLM())
+    tools_patch = patch("app.core.deep_agent._all_tools", return_value=[_FakeTool()])
+    scoped_context = {"scope": {"type": "timeline", "id": "tl-1"}}
+
+    with llm_patch, tools_patch:
+        stream = run_floating_stream("user-1", "show me timeline updates", scoped_context)
+        events = [event async for event in stream]
+
+    # The domain event must come first; the two streamed tokens only need to be present.
+    assert events[0] == ("floating_domain", "timelines")
+    assert ("token", "stream-") in events
+    assert ("token", "ok") in events