From a1e364c9c061427d8ebb4eebf9fdb23c098b2790 Mon Sep 17 00:00:00 2001 From: roberto Date: Fri, 13 Mar 2026 08:20:42 +0100 Subject: [PATCH] refactor: switch to single-agent deep runner and add mock memory/tool tests --- app/core/deep_agent.py | 950 +++++++-------------------------------- requirements.txt | 1 - tests/test_deep_agent.py | 81 ++++ 3 files changed, 235 insertions(+), 797 deletions(-) create mode 100644 tests/test_deep_agent.py diff --git a/app/core/deep_agent.py b/app/core/deep_agent.py index 52f5166..22559a4 100644 --- a/app/core/deep_agent.py +++ b/app/core/deep_agent.py @@ -1,26 +1,19 @@ -"""Deep orchestrator-worker graphs for home and floating chat contexts.""" +"""Single-agent runners for home and floating chat contexts.""" from __future__ import annotations -import asyncio import json import logging import re -from collections.abc import AsyncGenerator, Awaitable, Callable -import operator -from typing import Annotated, Any, Literal, TypedDict +from collections.abc import AsyncGenerator +from typing import Any, Literal from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage -from langchain_core.tools import tool -from langgraph.constants import END, START -from langgraph.graph import StateGraph -from langgraph.types import Send -from pydantic import BaseModel, Field -from app.agents.note_agent import NOTE_SYSTEM_PROMPT, NOTE_TOOLS -from app.agents.project_agent import PROJECT_SYSTEM_PROMPT, PROJECT_TOOLS -from app.agents.task_agent import TASK_SYSTEM_PROMPT, TASK_TOOLS -from app.agents.timeline_agent import TIMELINE_SYSTEM_PROMPT, TIMELINE_TOOLS +from app.agents.note_agent import NOTE_TOOLS +from app.agents.project_agent import PROJECT_TOOLS +from app.agents.task_agent import TASK_TOOLS +from app.agents.timeline_agent import TIMELINE_TOOLS from app.core.llm import get_llm from app.core.memory_middleware import MemoryMiddleware from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector @@ -28,121 +21,8 @@ from app.db import async_session logger = logging.getLogger(__name__) -# Quick test switch: home requests run as one agent with all tools. -HOME_SINGLE_AGENT_TEST_MODE = True - -WorkerName = Literal["task_agent", "project_agent", "note_agent", "timeline_agent"] FloatingDomain = Literal["tasks", "projects", "notes", "timelines"] - -class WorkerTask(BaseModel): - worker: WorkerName - instruction: str - - -class MemoryUpdate(BaseModel): - key: str = Field(description="The memory key to set or update.") - value: str = Field(description="The persistent fact or preference value.") - - -class WorkerSummary(BaseModel): - summary: str = Field(description="Strictly concise summary of tool findings. Max 3 sentences.") - - -class WorkerPlan(BaseModel): - tasks: list[WorkerTask] = Field(default_factory=list) - floating_domain: FloatingDomain | None = None - memory_updates: list[MemoryUpdate] = Field(default_factory=list, description="Update long-term core memory with persistent user preferences/facts learned from this message.") - - -class WorkerResult(TypedDict): - worker: WorkerName - instruction: str - response: str - entity_ids: dict[str, list[str]] - facts: dict[str, Any] - - -class OrchestratorState(TypedDict, total=False): - user_id: str - user_message: str - context: dict[str, Any] - memory_context: dict[str, Any] - plan: list[dict[str, Any]] - floating_domain: FloatingDomain - task: dict[str, Any] - worker_results: list[WorkerResult] - final_response: str - - -class GraphState(OrchestratorState): - worker_results: Annotated[list[WorkerResult], operator.add] - - -class ReducerState(OrchestratorState): - worker_results: list[WorkerResult] - - -class AggregatedState(TypedDict, total=False): - worker_results: list[WorkerResult] - - -WORKER_CONFIG: dict[WorkerName, dict[str, Any]] = { - "task_agent": { - "prompt": TASK_SYSTEM_PROMPT, - "tools": TASK_TOOLS, - "tag": "task", - "table": "tasks", - "floating_domain": "tasks", - }, - "project_agent": { - "prompt": PROJECT_SYSTEM_PROMPT, - "tools": PROJECT_TOOLS, - "tag": "project", - "table": "projects", - "floating_domain": "projects", - }, - "note_agent": { - "prompt": NOTE_SYSTEM_PROMPT, - "tools": NOTE_TOOLS, - "tag": "note", - "table": "notes", - "floating_domain": "notes", - }, - "timeline_agent": { - "prompt": TIMELINE_SYSTEM_PROMPT, - "tools": TIMELINE_TOOLS, - "tag": "timeline", - "table": "timelines", - "floating_domain": "timelines", - }, -} - -_HOME_ORCHESTRATOR_SYSTEM = ( - "You are an orchestrator. Plan which workers should be invoked for the user request. " - "Workers: task_agent, project_agent, note_agent, timeline_agent. " - "Return JSON only with keys: tasks, floating_domain, memory_updates." -) - -_FLOATING_ORCHESTRATOR_SYSTEM = ( - "You are an orchestrator for floating context. Pick focused workers and set floating_domain " - "as one of: tasks, projects, notes, timelines. Return JSON only with keys: tasks, floating_domain, memory_updates." -) - -_HOME_SYNTH_SYSTEM = ( - "You are the final response synthesizer. Return markdown only. " - "Embed inline component tags when relevant: [ids], [ids], " - "[ids], [ids], and {json}. " - "Only include IDs that are truly relevant to the request. " - "Never invent missing values. If facts include a non-null clientId for a project, " - "do not claim that the project has no owner/client." -) - -_FLOATING_SYNTH_SYSTEM = ( - "You are the final response synthesizer for floating UI context. " - "Return concise markdown and stay focused on the requested scope." -) - _HOME_SINGLE_AGENT_SYSTEM = ( "You are the home assistant with direct access to all tools: tasks, projects, notes, timelines. " "Always use tools for factual data retrieval before answering. " @@ -151,6 +31,15 @@ _HOME_SINGLE_AGENT_SYSTEM = ( "[ids], [ids], {json}." ) +_FLOATING_SINGLE_AGENT_SYSTEM = ( + "You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines. " + "Stay focused on the floating scope in context.scope and answer concisely. " + "Always use tools for factual data retrieval before answering. " + "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. " + "Return markdown and embed inline tags when relevant: [ids], [ids], " + "[ids], [ids], {json}." +) + def _as_text(content: Any) -> str: if content is None: @@ -170,130 +59,9 @@ def _as_text(content: Any) -> str: return str(content) -def _fallback_plan(message: str, floating: bool) -> WorkerPlan: - lowered = message.lower() - tasks: list[WorkerTask] = [] - - if any(k in lowered for k in ["task", "todo", "deadline", "due"]): - tasks.append(WorkerTask(worker="task_agent", instruction=message)) - if any(k in lowered for k in ["project", "client", "milestone"]): - tasks.append(WorkerTask(worker="project_agent", instruction=message)) - if any(k in lowered for k in ["note", "document", "memo"]): - tasks.append(WorkerTask(worker="note_agent", instruction=message)) - if any(k in lowered for k in ["timeline", "event", "schedule", "release"]): - tasks.append(WorkerTask(worker="timeline_agent", instruction=message)) - - if not tasks: - tasks = [WorkerTask(worker="task_agent", instruction=message)] - - domain: FloatingDomain | None = None - if floating: - domain = WORKER_CONFIG[tasks[0].worker]["floating_domain"] - - return WorkerPlan(tasks=tasks, floating_domain=domain) - - -def _extract_json_object(text: str) -> dict[str, Any] | None: - """Best-effort extraction of the first JSON object from model output.""" - stripped = text.strip() - if not stripped: - return None - - # Common case: model returns raw JSON object. - try: - payload = json.loads(stripped) - if isinstance(payload, dict): - return payload - except json.JSONDecodeError: - pass - - # Fenced JSON block fallback. - if "```" in stripped: - parts = stripped.split("```") - for part in parts: - candidate = part.strip() - if candidate.startswith("json"): - candidate = candidate[4:].strip() - try: - payload = json.loads(candidate) - if isinstance(payload, dict): - return payload - except json.JSONDecodeError: - continue - - return None - - -def _coerce_plan(payload: dict[str, Any], message: str, floating: bool) -> WorkerPlan: - """Normalize loose model JSON into a validated WorkerPlan.""" - tasks_raw = payload.get("tasks") - tasks: list[WorkerTask] = [] - - if isinstance(tasks_raw, list): - for item in tasks_raw: - if not isinstance(item, dict): - continue - worker = item.get("worker") - instruction = item.get("instruction") - if isinstance(worker, str) and worker in WORKER_CONFIG and isinstance(instruction, str): - tasks.append(WorkerTask(worker=worker, instruction=instruction)) - - if not tasks: - return _fallback_plan(message, floating) - - domain = payload.get("floating_domain") - floating_domain: FloatingDomain | None = None - if isinstance(domain, str) and domain in {"tasks", "projects", "notes", "timelines"}: - floating_domain = domain # type: ignore[assignment] - elif floating: - floating_domain = WORKER_CONFIG[tasks[0].worker]["floating_domain"] - - memory_updates: list[MemoryUpdate] = [] - updates_raw = payload.get("memory_updates") - if isinstance(updates_raw, list): - for item in updates_raw: - if isinstance(item, dict): - key = item.get("key") - value = item.get("value") - if isinstance(key, str) and isinstance(value, str) and key and value: - memory_updates.append(MemoryUpdate(key=key, value=value)) - - return WorkerPlan( - tasks=tasks, - floating_domain=floating_domain, - memory_updates=memory_updates, - ) - - -def _needs_full_project_snapshot(message: str, floating: bool) -> bool: - """Detect project status/update requests that should query all workers.""" - if floating: - return False - lowered = message.lower() - has_project = any(k in lowered for k in ["project", "progetto", "progetto", "progetti", "progetto", "whitelist"]) - has_status_intent = any(k in lowered for k in ["status", "stato", "aggiorn", "update", "situazione", "riepilogo", "summary"]) - return has_project and has_status_intent - - -def _build_full_project_snapshot_plan(message: str) -> WorkerPlan: - """Build a deterministic all-workers plan for project status snapshots.""" - project_hint = ( - "Use context.context.resolved_project_id when present as project_id. " - "Do not pass project names as project_id." - ) - return WorkerPlan( - tasks=[ - WorkerTask(worker="project_agent", instruction=f"Resolve the target project from this request and return core fields including id, name, status, clientId. {project_hint} Request: {message}"), - WorkerTask(worker="task_agent", instruction=f"Collect tasks relevant to the project in this request; include pending/blocked highlights and IDs. {project_hint} Request: {message}"), - WorkerTask(worker="timeline_agent", instruction=f"Collect timeline/milestone items relevant to the project in this request; include upcoming items and IDs. {project_hint} Request: {message}"), - WorkerTask(worker="note_agent", instruction=f"Collect notes relevant to the project in this request; include latest useful notes and IDs. {project_hint} Request: {message}"), - ] - ) - - def _candidate_tokens(message: str) -> list[str]: tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower()) - return [t for t in tokens if len(t) >= 3] + return [token for token in tokens if len(token) >= 3] async def _resolve_project_id_from_message(message: str) -> str | None: @@ -331,297 +99,64 @@ async def _resolve_project_id_from_message(message: str) -> str | None: return project_id if isinstance(project_id, str) else None -async def _prepare_home_context(message: str, context: dict[str, Any]) -> dict[str, Any]: - """Resolve and inject project_id hints for home flows.""" +def _needs_project_resolution(message: str) -> bool: + lowered = message.lower() + return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"]) + + +async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]: prepared = dict(context) - if _needs_full_project_snapshot(message, floating=False): + if _needs_project_resolution(message): resolved_project_id = await _resolve_project_id_from_message(message) if resolved_project_id: prepared["resolved_project_id"] = resolved_project_id - logger.info("deep_agent: resolved_project_id=%s for message=%s", resolved_project_id, message[:200]) + logger.info("deep_agent: resolved_project_id=%s", resolved_project_id) return prepared def _all_tools() -> list[Any]: - tools: list[Any] = [] - for config in WORKER_CONFIG.values(): - tools.extend(config["tools"]) - return tools + return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS] -async def _run_home_single_agent( - user_id: str, +def _infer_floating_domain(message: str, context: dict[str, Any]) -> FloatingDomain: + scope = context.get("scope") if isinstance(context, dict) else None + if isinstance(scope, dict): + scope_type = str(scope.get("type") or "").strip().lower() + if scope_type in {"task", "tasks"}: + return "tasks" + if scope_type in {"project", "projects"}: + return "projects" + if scope_type in {"note", "notes"}: + return "notes" + if scope_type in {"timeline", "timelines"}: + return "timelines" + + lowered = message.lower() + if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]): + return "timelines" + if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]): + return "notes" + if any(keyword in lowered for keyword in ["project", "progetto", "client"]): + return "projects" + return "tasks" + + +async def _run_single_agent( + *, + system_prompt: str, message: str, context: dict[str, Any], + max_steps: int = 6, ) -> str: - """Single-agent test mode: one loop with all tools.""" - prepared_context = await _prepare_home_context(message, context) - llm = get_llm() tools = _all_tools() llm_with_tools = llm.bind_tools(tools) messages: list[Any] = [ - SystemMessage(content=_HOME_SINGLE_AGENT_SYSTEM), - HumanMessage(content=f"User message:\n{message}\n\nContext:\n{json.dumps({'context': prepared_context}, ensure_ascii=True)[:3500]}"), - ] - - for _ in range(6): - response: AIMessage = await llm_with_tools.ainvoke(messages) - messages.append(response) - if not response.tool_calls: - return _as_text(response.content) - - tool_map = {t.name: t for t in tools} - for call in response.tool_calls: - tool_fn = tool_map.get(call["name"]) - if tool_fn is None: - tool_output = f"Unknown tool: {call['name']}" - else: - tool_output = await tool_fn.ainvoke(call.get("args", {})) - messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) - - final = await llm.ainvoke(messages) - return _as_text(final.content) - - -async def _run_home_single_agent_stream( - user_id: str, - message: str, - context: dict[str, Any], -) -> AsyncGenerator[tuple[str, Any], None]: - """Streaming variant for single-agent home test mode.""" - prepared_context = await _prepare_home_context(message, context) - - llm = get_llm() - tools = _all_tools() - llm_with_tools = llm.bind_tools(tools) - messages: list[Any] = [ - SystemMessage(content=_HOME_SINGLE_AGENT_SYSTEM), - HumanMessage(content=f"User message:\n{message}\n\nContext:\n{json.dumps({'context': prepared_context}, ensure_ascii=True)[:3500]}"), - ] - - for _ in range(6): - response: AIMessage = await llm_with_tools.ainvoke(messages) - messages.append(response) - if not response.tool_calls: - async for chunk in llm.astream(messages): - token = _as_text(getattr(chunk, "content", "")) - if token: - yield "token", token - return - - tool_map = {t.name: t for t in tools} - for call in response.tool_calls: - tool_fn = tool_map.get(call["name"]) - if tool_fn is None: - tool_output = f"Unknown tool: {call['name']}" - else: - tool_output = await tool_fn.ainvoke(call.get("args", {})) - messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) - - async for chunk in llm.astream(messages): - token = _as_text(getattr(chunk, "content", "")) - if token: - yield "token", token - - -async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan: - if _needs_full_project_snapshot(message, floating): - logger.info("deep_agent: forcing full project snapshot plan for message=%s", message[:200]) - return _build_full_project_snapshot_plan(message) - - llm = get_llm() - system = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM - - prompt_payload = { - "message": message, - "context": context, - "workers": list(WORKER_CONFIG.keys()), - } - messages = [ - SystemMessage(content=system), + SystemMessage(content=system_prompt), HumanMessage( content=( - "Create a valid JSON object with this exact structure:\n" - '{"tasks":[{"worker":"task_agent|project_agent|note_agent|timeline_agent","instruction":"..."}],' - '"floating_domain":"tasks|projects|notes|timelines|null","memory_updates":[{"key":"...","value":"..."}]}\n\n' - "Rules:\n" - "- tasks must include at least one entry when possible\n" - "- use floating_domain only when relevant\n" - "- output JSON only (no markdown, no prose)\n\n" - f"Input:\n{json.dumps(prompt_payload, ensure_ascii=True)}" - ) - ), - ] - - try: - response = await llm.ainvoke(messages) - payload = _extract_json_object(_as_text(response.content)) - if payload is None: - raise ValueError("planner returned non-JSON output") - plan = _coerce_plan(payload, message, floating) - logger.info( - "deep_agent: planner produced tasks=%s floating=%s", - [t.worker for t in plan.tasks], - plan.floating_domain, - ) - return plan - except Exception as exc: - logger.warning("deep_agent: planner failed, using fallback: %s", exc) - - return _fallback_plan(message, floating) - - -def _extract_entity_ids(tool_results: list[dict[str, Any]]) -> dict[str, list[str]]: - out: dict[str, list[str]] = { - "task": [], - "project": [], - "note": [], - "timeline": [], - } - table_to_tag = { - "tasks": "task", - "projects": "project", - "notes": "note", - "timelines": "timeline", - } - - for item in tool_results: - table = item.get("table") - tag = table_to_tag.get(table) - if tag is None: - continue - - payload = item.get("data") or {} - rows: list[dict[str, Any]] = [] - row = payload.get("row") - if isinstance(row, dict): - rows.append(row) - if isinstance(payload.get("rows"), list): - rows.extend([r for r in payload["rows"] if isinstance(r, dict)]) - if isinstance(payload.get("results"), list): - rows.extend([r for r in payload["results"] if isinstance(r, dict)]) - - for r in rows: - entity_id = r.get("id") - if isinstance(entity_id, str) and entity_id not in out[tag]: - out[tag].append(entity_id) - - return out - - -def _extract_facts(tool_results: list[dict[str, Any]]) -> dict[str, Any]: - """Extract small, structured facts for the synthesizer to avoid hallucinations.""" - facts: dict[str, Any] = {"projects": [], "tasks": [], "notes": [], "timelines": []} - - for item in tool_results: - table = item.get("table") - payload = item.get("data") or {} - - rows: list[dict[str, Any]] = [] - row = payload.get("row") - if isinstance(row, dict): - rows.append(row) - if isinstance(payload.get("rows"), list): - rows.extend([r for r in payload["rows"] if isinstance(r, dict)]) - - if table == "projects": - for r in rows: - facts["projects"].append( - { - "id": r.get("id"), - "name": r.get("name"), - "status": r.get("status"), - "clientId": r.get("clientId"), - } - ) - elif table == "tasks": - for r in rows: - facts["tasks"].append( - { - "id": r.get("id"), - "title": r.get("title"), - "status": r.get("status"), - "projectId": r.get("projectId"), - } - ) - elif table == "notes": - for r in rows: - facts["notes"].append( - { - "id": r.get("id"), - "title": r.get("title"), - "projectId": r.get("projectId"), - } - ) - elif table == "timelines": - for r in rows: - facts["timelines"].append( - { - "id": r.get("id"), - "title": r.get("title"), - "date": r.get("date"), - "projectId": r.get("projectId"), - } - ) - - return facts - - -async def _run_tool_loop( - worker: WorkerName, - instruction: str, - context: dict[str, Any], -) -> tuple[str, list[dict[str, Any]]]: - worker_prompt = WORKER_CONFIG[worker]["prompt"] - tools = WORKER_CONFIG[worker]["tools"] - - llm = get_llm() - llm_with_tools = llm.bind_tools(tools) if tools else llm - - resolved_project_id = None - ctx = context.get("context", {}) if isinstance(context, dict) else {} - if isinstance(ctx, dict): - rpid = ctx.get("resolved_project_id") - if isinstance(rpid, str) and rpid: - resolved_project_id = rpid - - mandatory_tool_policy = "" - if resolved_project_id: - if worker == "project_agent": - mandatory_tool_policy = ( - "MANDATORY TOOL POLICY:\n" - f"- You MUST call get_project(project_id=\"{resolved_project_id}\") before final answer.\n" - "- Optionally call list_projects afterward only if needed for disambiguation.\n\n" - ) - elif worker == "task_agent": - mandatory_tool_policy = ( - "MANDATORY TOOL POLICY:\n" - f"- You MUST call list_tasks(project_id=\"{resolved_project_id}\") before final answer.\n" - "- Do not use project name as project_id.\n\n" - ) - elif worker == "timeline_agent": - mandatory_tool_policy = ( - "MANDATORY TOOL POLICY:\n" - f"- You MUST call list_timelines(project_id=\"{resolved_project_id}\") before final answer.\n" - "- Do not use project name as project_id.\n\n" - ) - elif worker == "note_agent": - mandatory_tool_policy = ( - "MANDATORY TOOL POLICY:\n" - f"- You MUST call list_notes(project_id=\"{resolved_project_id}\") before final answer.\n" - "- Do not use project name as project_id.\n\n" - ) - - messages: list[Any] = [ - SystemMessage(content=worker_prompt), - HumanMessage( - content=( - mandatory_tool_policy + - "Worker instruction:\n" - f"{instruction}\n\n" - "Conversation context:\n" - f"{json.dumps(context, ensure_ascii=True)[:2000]}" + f"User message:\n{message}\n\n" + f"Context:\n{json.dumps({'context': context}, ensure_ascii=True)[:3500]}" ) ), ] @@ -629,284 +164,133 @@ async def _run_tool_loop( collected: list[dict[str, Any]] = [] set_tool_result_collector(collected) try: - for _ in range(6): + for _ in range(max_steps): response: AIMessage = await llm_with_tools.ainvoke(messages) messages.append(response) if not response.tool_calls: - return _as_text(response.content), collected + return _as_text(response.content) - tool_map = {t.name: t for t in tools} + tool_map = {tool_def.name: tool_def for tool_def in tools} for call in response.tool_calls: call_id = str(call.get("id", "")) call_name = str(call.get("name", "")) call_args = call.get("args", {}) logger.info( - "deep_agent: worker=%s AI->Tool tool_call_id=%s tool=%s args=%s", - worker, + "deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s", call_id, call_name, json.dumps(call_args, ensure_ascii=True)[:800], ) - tool_fn = tool_map.get(call["name"]) + tool_fn = tool_map.get(call_name) if tool_fn is None: - tool_output = f"Unknown tool: {call['name']}" + tool_output = f"Unknown tool: {call_name}" else: - tool_output = await tool_fn.ainvoke(call.get("args", {})) + tool_output = await tool_fn.ainvoke(call_args) - tool_output_text = str(tool_output) logger.info( - "deep_agent: worker=%s Tool->AI tool_call_id=%s tool=%s output=%s", - worker, + "deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s", call_id, call_name, - tool_output_text[:1200], + str(tool_output)[:1200], ) messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) - logger.info( - "deep_agent: worker=%s appended ToolMessage tool_call_id=%s", - worker, - call_id, - ) - structured_llm = llm.with_structured_output(WorkerSummary) - messages.append(SystemMessage(content="You have finished using tools. Summarize findings in max 3 sentences.")) - final_summary = await structured_llm.ainvoke(messages) - - if isinstance(final_summary, WorkerSummary): - return final_summary.summary, collected - return str(final_summary), collected + final = await llm.ainvoke(messages) + return _as_text(final.content) finally: clear_tool_result_collector() -def _worker_node(worker: WorkerName): - async def _node(state: GraphState) -> AggregatedState: - task_payload = state.get("task") or {} - if task_payload.get("worker") != worker: - return {"worker_results": []} - - instruction = str(task_payload.get("instruction") or state.get("user_message") or "") - logger.info("deep_agent: worker=%s start instruction=%s", worker, instruction[:240]) - worker_context = { - "memory": state.get("memory_context", {}), - "context": state.get("context", {}), - } - response, tool_results = await _run_tool_loop(worker, instruction, worker_context) - logger.info( - "deep_agent: worker=%s complete tool_calls=%d entity_counts=%s", - worker, - len(tool_results), - {k: len(v) for k, v in _extract_entity_ids(tool_results).items()}, - ) - - return { - "worker_results": [ - { - "worker": worker, - "instruction": instruction, - "response": response, - "entity_ids": _extract_entity_ids(tool_results), - "facts": _extract_facts(tool_results), - } - ] - } - - return _node - - -def _build_synthesis_prompt(state: GraphState, floating: bool) -> str: - worker_results = state.get("worker_results", []) - formatted_results = [] - for result in worker_results: - formatted_results.append( - { - "worker": result.get("worker"), - "instruction": result.get("instruction"), - "response": result.get("response"), - "entity_ids": result.get("entity_ids", {}), - "facts": result.get("facts", {}), - } - ) - - payload = { - "user_message": state.get("user_message", ""), - "memory_context": state.get("memory_context", {}), - "worker_results": formatted_results, - "floating_domain": state.get("floating_domain") if floating else None, - } - return json.dumps(payload, ensure_ascii=True) - - -async def _stream_with_memory_tool( +async def _run_single_agent_stream( *, - user_id: str, system_prompt: str, - user_prompt: str, -) -> str: + message: str, + context: dict[str, Any], + max_steps: int = 6, +) -> AsyncGenerator[tuple[str, Any], None]: llm = get_llm() + tools = _all_tools() + llm_with_tools = llm.bind_tools(tools) messages: list[Any] = [ SystemMessage(content=system_prompt), - HumanMessage(content=user_prompt), + HumanMessage( + content=( + f"User message:\n{message}\n\n" + f"Context:\n{json.dumps({'context': context}, ensure_ascii=True)[:3500]}" + ) + ), ] - chunks: list[str] = [] - async for chunk in llm.astream(messages): - token = _as_text(getattr(chunk, "content", "")) - if not token: - continue - chunks.append(token) + collected: list[dict[str, Any]] = [] + set_tool_result_collector(collected) + try: + for _ in range(max_steps): + response: AIMessage = await llm_with_tools.ainvoke(messages) + messages.append(response) - return "".join(chunks) + if not response.tool_calls: + async for chunk in llm.astream(messages): + token = _as_text(getattr(chunk, "content", "")) + if token: + yield "token", token + return + tool_map = {tool_def.name: tool_def for tool_def in tools} + for call in response.tool_calls: + call_id = str(call.get("id", "")) + call_name = str(call.get("name", "")) + call_args = call.get("args", {}) + logger.info( + "deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s", + call_id, + call_name, + json.dumps(call_args, ensure_ascii=True)[:800], + ) -def _synthesizer_node(floating: bool): - async def _node(state: GraphState) -> GraphState: - prompt = _build_synthesis_prompt(state, floating=floating) - system_prompt = _FLOATING_SYNTH_SYSTEM if floating else _HOME_SYNTH_SYSTEM + tool_fn = tool_map.get(call_name) + if tool_fn is None: + tool_output = f"Unknown tool: {call_name}" + else: + tool_output = await tool_fn.ainvoke(call_args) - final_response = await _stream_with_memory_tool( - user_id=str(state.get("user_id", "")), - system_prompt=system_prompt, - user_prompt=prompt, - ) + logger.info( + "deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s", + call_id, + call_name, + str(tool_output)[:1200], + ) - return {"final_response": final_response} + messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) - return _node - - -async def _apply_memory_updates(user_id: str, updates: list[MemoryUpdate], current_memory: dict[str, Any]) -> dict[str, Any]: - if not updates: - return current_memory - - new_memory = dict(current_memory) - async with async_session() as db: - memory = MemoryMiddleware(db) - for update in updates: - await memory.update_core(user_id, update.key, update.value) - new_memory[update.key] = update.value - return new_memory - -async def _orchestrator_node_home(state: GraphState) -> GraphState: - if state.get("plan"): - return {} - - user_message = str(state.get("user_message", "")) - base_context = dict(state.get("context", {})) - context = {**base_context, **state.get("memory_context", {})} - - if _needs_full_project_snapshot(user_message, floating=False): - resolved_project_id = await _resolve_project_id_from_message(user_message) - if resolved_project_id: - base_context["resolved_project_id"] = resolved_project_id - logger.info("deep_agent: resolved_project_id=%s for message=%s", resolved_project_id, user_message[:200]) - plan = _build_full_project_snapshot_plan(user_message) - else: - plan = await _plan_with_llm(user_message, context, floating=False) - - new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {})) - - return { - "plan": [task.model_dump() for task in plan.tasks], - "memory_context": new_memory, - "context": base_context, - } - - -async def _orchestrator_node_floating(state: GraphState) -> GraphState: - if state.get("plan"): - return {} - - context = {**state.get("context", {}), **state.get("memory_context", {})} - plan = await _plan_with_llm(str(state.get("user_message", "")), context, floating=True) - floating_domain = plan.floating_domain - if floating_domain is None and plan.tasks: - floating_domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"] - - new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {})) - - return { - "plan": [task.model_dump() for task in plan.tasks], - "floating_domain": floating_domain or "tasks", - "memory_context": new_memory - } - - -def _route_workers(state: GraphState) -> list[Send] | str: - plan = state.get("plan", []) - if not plan: - return "synthesizer" - - sends: list[Send] = [] - for task in plan: - worker = task.get("worker") - if worker in WORKER_CONFIG: - sends.append(Send(worker, {"task": task})) - - return sends or "synthesizer" - - -def _build_graph(*, floating: bool): - builder = StateGraph(GraphState) - - orchestrator_node = _orchestrator_node_floating if floating else _orchestrator_node_home - builder.add_node("orchestrator", orchestrator_node) - for worker in WORKER_CONFIG: - builder.add_node(worker, _worker_node(worker)) - builder.add_node("synthesizer", _synthesizer_node(floating=floating)) - - builder.add_edge(START, "orchestrator") - builder.add_conditional_edges( - "orchestrator", - _route_workers, - ["task_agent", "project_agent", "note_agent", "timeline_agent", "synthesizer"], - ) - for worker in WORKER_CONFIG: - builder.add_edge(worker, "synthesizer") - builder.add_edge("synthesizer", END) - - return builder.compile() - - -HOME_GRAPH = _build_graph(floating=False) -FLOATING_GRAPH = _build_graph(floating=True) + async for chunk in llm.astream(messages): + token = _as_text(getattr(chunk, "content", "")) + if token: + yield "token", token + finally: + clear_tool_result_collector() async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str: - if HOME_SINGLE_AGENT_TEST_MODE: - return await _run_home_single_agent(user_id, message, context) - - state = await HOME_GRAPH.ainvoke( - { - "user_id": user_id, - "user_message": message, - "context": context, - "memory_context": context, - "worker_results": [], - } + prepared_context = await _prepare_context(message, context) + return await _run_single_agent( + system_prompt=_HOME_SINGLE_AGENT_SYSTEM, + message=message, + context=prepared_context, ) - return str(state.get("final_response", "")) + async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]: - plan = await _plan_with_llm(message, context, floating=True) - domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"] - new_memory = await _apply_memory_updates(user_id, plan.memory_updates, context) - - state = await FLOATING_GRAPH.ainvoke( - { - "user_id": user_id, - "user_message": message, - "context": context, - "memory_context": new_memory, - "plan": [task.model_dump() for task in plan.tasks], - "floating_domain": domain, - "worker_results": [], - } + domain = _infer_floating_domain(message, context) + prepared_context = await _prepare_context(message, context) + response = await _run_single_agent( + system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM, + message=message, + context=prepared_context, ) - return str(state.get("final_response", "")), str(domain) + return response, domain async def run_home_stream( @@ -914,60 +298,34 @@ async def run_home_stream( message: str, context: dict[str, Any], ) -> AsyncGenerator[tuple[str, Any], None]: - if HOME_SINGLE_AGENT_TEST_MODE: - async for event in _run_home_single_agent_stream(user_id, message, context): - yield event - return + prepared_context = await _prepare_context(message, context) + async for event in _run_single_agent_stream( + system_prompt=_HOME_SINGLE_AGENT_SYSTEM, + message=message, + context=prepared_context, + ): + yield event - state_input = { - "user_id": user_id, - "user_message": message, - "context": context, - "memory_context": context, - "worker_results": [], - } - - async for event in HOME_GRAPH.astream_events(state_input, version="v2"): - kind = event["event"] - - if kind == "on_chat_model_stream": - node_name = event.get("metadata", {}).get("langgraph_node") - - if node_name == "synthesizer": - chunk = event["data"]["chunk"] - token = _as_text(getattr(chunk, "content", "")) - if token: - yield "token", token async def run_floating_stream( user_id: str, message: str, context: dict[str, Any], ) -> AsyncGenerator[tuple[str, Any], None]: - plan = await _plan_with_llm(message, context, floating=True) - domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"] + domain = _infer_floating_domain(message, context) yield "floating_domain", domain - new_memory = await _apply_memory_updates(user_id, plan.memory_updates, context) + prepared_context = await _prepare_context(message, context) + async for event in _run_single_agent_stream( + system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM, + message=message, + context=prepared_context, + ): + yield event - state_input = { - "user_id": user_id, - "user_message": message, - "context": context, - "memory_context": new_memory, - "plan": [t.model_dump() for t in plan.tasks], - "floating_domain": domain, - "worker_results": [], - } - async for event in FLOATING_GRAPH.astream_events(state_input, version="v2"): - kind = event["event"] - - if kind == "on_chat_model_stream": - node_name = event.get("metadata", {}).get("langgraph_node") - - if node_name == "synthesizer": - chunk = event["data"]["chunk"] - token = _as_text(getattr(chunk, "content", "")) - if token: - yield "token", token +async def update_core_memory(user_id: str, key: str, value: str) -> None: + """Compatibility helper kept for callers that expect explicit memory update API.""" + async with async_session() as db: + memory = MemoryMiddleware(db) + await memory.update_core(user_id, key, value) diff --git a/requirements.txt b/requirements.txt index 8202519..ea10f59 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,6 @@ langchain>=0.3.0 langchain-openai>=0.3.0 langchain-litellm>=0.1.0 litellm>=1.50.0 -langgraph>=0.4.0 pydantic>=2.10.0 pydantic-settings>=2.7.0 python-jose[cryptography]>=3.3.0 diff --git a/tests/test_deep_agent.py b/tests/test_deep_agent.py new file mode 100644 index 0000000..deddfa3 --- /dev/null +++ b/tests/test_deep_agent.py @@ -0,0 +1,81 @@ +"""Unit tests for single-agent deep_agent flows with mocked tool results.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import patch + +import pytest +from langchain_core.messages import AIMessage, ToolMessage + +from app.core.deep_agent import run_floating_stream, run_home + + +class _FakeTool: + name = "list_tasks" + + async def ainvoke(self, args): + return {"rows": [{"id": "task-1", "title": "Mock Task"}], "echo": args} + + +class _FakeLLM: + def __init__(self) -> None: + self.calls = 0 + + def bind_tools(self, _tools): + return self + + async def ainvoke(self, messages): + self.calls += 1 + if self.calls == 1: + return AIMessage( + content="", + tool_calls=[ + { + "id": "call-1", + "name": "list_tasks", + "args": {"project_id": "proj-1"}, + } + ], + ) + + tool_messages = [m for m in messages if isinstance(m, ToolMessage)] + assert tool_messages, "Expected at least one tool message" + return AIMessage(content=f"Final answer from mocked tool: {tool_messages[-1].content}") + + async def astream(self, _messages): + yield SimpleNamespace(content="stream-") + yield SimpleNamespace(content="ok") + + +@pytest.mark.asyncio +async def test_run_home_uses_mocked_tool_result(): + fake_llm = _FakeLLM() + + with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch( + "app.core.deep_agent._all_tools", return_value=[_FakeTool()] + ): + out = await run_home("user-1", "list my tasks", {}) + + assert "Final answer from mocked tool" in out + assert "Mock Task" in out + + +@pytest.mark.asyncio +async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result(): + fake_llm = _FakeLLM() + + with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch( + "app.core.deep_agent._all_tools", return_value=[_FakeTool()] + ): + events = [] + async for event in run_floating_stream( + "user-1", + "show me timeline updates", + {"scope": {"type": "timeline", "id": "tl-1"}}, + ): + events.append(event) + + assert events[0] == ("floating_domain", "timelines") + assert ("token", "stream-") in events + assert ("token", "ok") in events