|
|
|
|
@@ -5,9 +5,10 @@ from __future__ import annotations
|
|
|
|
|
import asyncio
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
import operator
|
|
|
|
|
import re
|
|
|
|
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
|
|
|
|
from typing import Any, Literal, TypedDict
|
|
|
|
|
import operator
|
|
|
|
|
from typing import Annotated, Any, Literal, TypedDict
|
|
|
|
|
|
|
|
|
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
|
|
|
|
from langchain_core.tools import tool
|
|
|
|
|
@@ -22,11 +23,14 @@ from app.agents.task_agent import TASK_SYSTEM_PROMPT, TASK_TOOLS
|
|
|
|
|
from app.agents.timeline_agent import TIMELINE_SYSTEM_PROMPT, TIMELINE_TOOLS
|
|
|
|
|
from app.core.llm import get_llm
|
|
|
|
|
from app.core.memory_middleware import MemoryMiddleware
|
|
|
|
|
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
|
|
|
|
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
|
|
|
|
from app.db import async_session
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
# Quick test switch: home requests run as one agent with all tools.
|
|
|
|
|
HOME_SINGLE_AGENT_TEST_MODE = True
|
|
|
|
|
|
|
|
|
|
WorkerName = Literal["task_agent", "project_agent", "note_agent", "timeline_agent"]
|
|
|
|
|
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]
|
|
|
|
|
|
|
|
|
|
@@ -56,6 +60,7 @@ class WorkerResult(TypedDict):
|
|
|
|
|
instruction: str
|
|
|
|
|
response: str
|
|
|
|
|
entity_ids: dict[str, list[str]]
|
|
|
|
|
facts: dict[str, Any]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OrchestratorState(TypedDict, total=False):
|
|
|
|
|
@@ -71,7 +76,7 @@ class OrchestratorState(TypedDict, total=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GraphState(OrchestratorState):
|
|
|
|
|
worker_results: list[WorkerResult]
|
|
|
|
|
worker_results: Annotated[list[WorkerResult], operator.add]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReducerState(OrchestratorState):
|
|
|
|
|
@@ -116,19 +121,21 @@ WORKER_CONFIG: dict[WorkerName, dict[str, Any]] = {
|
|
|
|
|
_HOME_ORCHESTRATOR_SYSTEM = (
|
|
|
|
|
"You are an orchestrator. Plan which workers should be invoked for the user request. "
|
|
|
|
|
"Workers: task_agent, project_agent, note_agent, timeline_agent. "
|
|
|
|
|
"Return only the workers needed."
|
|
|
|
|
"Return JSON only with keys: tasks, floating_domain, memory_updates."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
_FLOATING_ORCHESTRATOR_SYSTEM = (
|
|
|
|
|
"You are an orchestrator for floating context. Pick focused workers and set floating_domain "
|
|
|
|
|
"as one of: tasks, projects, notes, timelines."
|
|
|
|
|
"as one of: tasks, projects, notes, timelines. Return JSON only with keys: tasks, floating_domain, memory_updates."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
_HOME_SYNTH_SYSTEM = (
|
|
|
|
|
"You are the final response synthesizer. Return markdown only. "
|
|
|
|
|
"Embed inline component tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
|
|
|
|
"<note>[ids]</note>, <timeline>[ids]</timeline>, and <chart>{json}</chart>. "
|
|
|
|
|
"Only include IDs that are truly relevant to the request."
|
|
|
|
|
"Only include IDs that are truly relevant to the request. "
|
|
|
|
|
"Never invent missing values. If facts include a non-null clientId for a project, "
|
|
|
|
|
"do not claim that the project has no owner/client."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
_FLOATING_SYNTH_SYSTEM = (
|
|
|
|
|
@@ -136,6 +143,14 @@ _FLOATING_SYNTH_SYSTEM = (
|
|
|
|
|
"Return concise markdown and stay focused on the requested scope."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
_HOME_SINGLE_AGENT_SYSTEM = (
|
|
|
|
|
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines. "
|
|
|
|
|
"Always use tools for factual data retrieval before answering. "
|
|
|
|
|
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
|
|
|
|
"Return markdown and embed inline tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
|
|
|
|
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _as_text(content: Any) -> str:
|
|
|
|
|
if content is None:
|
|
|
|
|
@@ -178,7 +193,243 @@ def _fallback_plan(message: str, floating: bool) -> WorkerPlan:
|
|
|
|
|
return WorkerPlan(tasks=tasks, floating_domain=domain)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_json_object(text: str) -> dict[str, Any] | None:
|
|
|
|
|
"""Best-effort extraction of the first JSON object from model output."""
|
|
|
|
|
stripped = text.strip()
|
|
|
|
|
if not stripped:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# Common case: model returns raw JSON object.
|
|
|
|
|
try:
|
|
|
|
|
payload = json.loads(stripped)
|
|
|
|
|
if isinstance(payload, dict):
|
|
|
|
|
return payload
|
|
|
|
|
except json.JSONDecodeError:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
# Fenced JSON block fallback.
|
|
|
|
|
if "```" in stripped:
|
|
|
|
|
parts = stripped.split("```")
|
|
|
|
|
for part in parts:
|
|
|
|
|
candidate = part.strip()
|
|
|
|
|
if candidate.startswith("json"):
|
|
|
|
|
candidate = candidate[4:].strip()
|
|
|
|
|
try:
|
|
|
|
|
payload = json.loads(candidate)
|
|
|
|
|
if isinstance(payload, dict):
|
|
|
|
|
return payload
|
|
|
|
|
except json.JSONDecodeError:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _coerce_plan(payload: dict[str, Any], message: str, floating: bool) -> WorkerPlan:
|
|
|
|
|
"""Normalize loose model JSON into a validated WorkerPlan."""
|
|
|
|
|
tasks_raw = payload.get("tasks")
|
|
|
|
|
tasks: list[WorkerTask] = []
|
|
|
|
|
|
|
|
|
|
if isinstance(tasks_raw, list):
|
|
|
|
|
for item in tasks_raw:
|
|
|
|
|
if not isinstance(item, dict):
|
|
|
|
|
continue
|
|
|
|
|
worker = item.get("worker")
|
|
|
|
|
instruction = item.get("instruction")
|
|
|
|
|
if isinstance(worker, str) and worker in WORKER_CONFIG and isinstance(instruction, str):
|
|
|
|
|
tasks.append(WorkerTask(worker=worker, instruction=instruction))
|
|
|
|
|
|
|
|
|
|
if not tasks:
|
|
|
|
|
return _fallback_plan(message, floating)
|
|
|
|
|
|
|
|
|
|
domain = payload.get("floating_domain")
|
|
|
|
|
floating_domain: FloatingDomain | None = None
|
|
|
|
|
if isinstance(domain, str) and domain in {"tasks", "projects", "notes", "timelines"}:
|
|
|
|
|
floating_domain = domain # type: ignore[assignment]
|
|
|
|
|
elif floating:
|
|
|
|
|
floating_domain = WORKER_CONFIG[tasks[0].worker]["floating_domain"]
|
|
|
|
|
|
|
|
|
|
memory_updates: list[MemoryUpdate] = []
|
|
|
|
|
updates_raw = payload.get("memory_updates")
|
|
|
|
|
if isinstance(updates_raw, list):
|
|
|
|
|
for item in updates_raw:
|
|
|
|
|
if isinstance(item, dict):
|
|
|
|
|
key = item.get("key")
|
|
|
|
|
value = item.get("value")
|
|
|
|
|
if isinstance(key, str) and isinstance(value, str) and key and value:
|
|
|
|
|
memory_updates.append(MemoryUpdate(key=key, value=value))
|
|
|
|
|
|
|
|
|
|
return WorkerPlan(
|
|
|
|
|
tasks=tasks,
|
|
|
|
|
floating_domain=floating_domain,
|
|
|
|
|
memory_updates=memory_updates,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _needs_full_project_snapshot(message: str, floating: bool) -> bool:
|
|
|
|
|
"""Detect project status/update requests that should query all workers."""
|
|
|
|
|
if floating:
|
|
|
|
|
return False
|
|
|
|
|
lowered = message.lower()
|
|
|
|
|
has_project = any(k in lowered for k in ["project", "progetto", "progetto", "progetti", "progetto", "whitelist"])
|
|
|
|
|
has_status_intent = any(k in lowered for k in ["status", "stato", "aggiorn", "update", "situazione", "riepilogo", "summary"])
|
|
|
|
|
return has_project and has_status_intent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_full_project_snapshot_plan(message: str) -> WorkerPlan:
|
|
|
|
|
"""Build a deterministic all-workers plan for project status snapshots."""
|
|
|
|
|
project_hint = (
|
|
|
|
|
"Use context.context.resolved_project_id when present as project_id. "
|
|
|
|
|
"Do not pass project names as project_id."
|
|
|
|
|
)
|
|
|
|
|
return WorkerPlan(
|
|
|
|
|
tasks=[
|
|
|
|
|
WorkerTask(worker="project_agent", instruction=f"Resolve the target project from this request and return core fields including id, name, status, clientId. {project_hint} Request: {message}"),
|
|
|
|
|
WorkerTask(worker="task_agent", instruction=f"Collect tasks relevant to the project in this request; include pending/blocked highlights and IDs. {project_hint} Request: {message}"),
|
|
|
|
|
WorkerTask(worker="timeline_agent", instruction=f"Collect timeline/milestone items relevant to the project in this request; include upcoming items and IDs. {project_hint} Request: {message}"),
|
|
|
|
|
WorkerTask(worker="note_agent", instruction=f"Collect notes relevant to the project in this request; include latest useful notes and IDs. {project_hint} Request: {message}"),
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _candidate_tokens(message: str) -> list[str]:
|
|
|
|
|
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
|
|
|
|
|
return [t for t in tokens if len(t) >= 3]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _resolve_project_id_from_message(message: str) -> str | None:
|
|
|
|
|
"""Resolve likely project UUID from user message using client project list."""
|
|
|
|
|
try:
|
|
|
|
|
result = await execute_on_client(action="select", table="projects")
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
logger.warning("deep_agent: project resolve select failed: %s", exc)
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
rows = result.get("rows", [])
|
|
|
|
|
if not isinstance(rows, list) or not rows:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
tokens = _candidate_tokens(message)
|
|
|
|
|
scored: list[tuple[int, dict[str, Any]]] = []
|
|
|
|
|
for row in rows:
|
|
|
|
|
if not isinstance(row, dict):
|
|
|
|
|
continue
|
|
|
|
|
name = str(row.get("name", "")).lower()
|
|
|
|
|
score = sum(1 for token in tokens if token in name)
|
|
|
|
|
if score > 0:
|
|
|
|
|
scored.append((score, row))
|
|
|
|
|
|
|
|
|
|
if not scored:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
scored.sort(key=lambda item: item[0], reverse=True)
|
|
|
|
|
top_score = scored[0][0]
|
|
|
|
|
top_rows = [row for score, row in scored if score == top_score]
|
|
|
|
|
if len(top_rows) != 1:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
project_id = top_rows[0].get("id")
|
|
|
|
|
return project_id if isinstance(project_id, str) else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _prepare_home_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
|
|
|
|
|
"""Resolve and inject project_id hints for home flows."""
|
|
|
|
|
prepared = dict(context)
|
|
|
|
|
if _needs_full_project_snapshot(message, floating=False):
|
|
|
|
|
resolved_project_id = await _resolve_project_id_from_message(message)
|
|
|
|
|
if resolved_project_id:
|
|
|
|
|
prepared["resolved_project_id"] = resolved_project_id
|
|
|
|
|
logger.info("deep_agent: resolved_project_id=%s for message=%s", resolved_project_id, message[:200])
|
|
|
|
|
return prepared
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _all_tools() -> list[Any]:
|
|
|
|
|
tools: list[Any] = []
|
|
|
|
|
for config in WORKER_CONFIG.values():
|
|
|
|
|
tools.extend(config["tools"])
|
|
|
|
|
return tools
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _run_home_single_agent(
|
|
|
|
|
user_id: str,
|
|
|
|
|
message: str,
|
|
|
|
|
context: dict[str, Any],
|
|
|
|
|
) -> str:
|
|
|
|
|
"""Single-agent test mode: one loop with all tools."""
|
|
|
|
|
prepared_context = await _prepare_home_context(message, context)
|
|
|
|
|
|
|
|
|
|
llm = get_llm()
|
|
|
|
|
tools = _all_tools()
|
|
|
|
|
llm_with_tools = llm.bind_tools(tools)
|
|
|
|
|
messages: list[Any] = [
|
|
|
|
|
SystemMessage(content=_HOME_SINGLE_AGENT_SYSTEM),
|
|
|
|
|
HumanMessage(content=f"User message:\n{message}\n\nContext:\n{json.dumps({'context': prepared_context}, ensure_ascii=True)[:3500]}"),
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
for _ in range(6):
|
|
|
|
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
|
|
|
messages.append(response)
|
|
|
|
|
if not response.tool_calls:
|
|
|
|
|
return _as_text(response.content)
|
|
|
|
|
|
|
|
|
|
tool_map = {t.name: t for t in tools}
|
|
|
|
|
for call in response.tool_calls:
|
|
|
|
|
tool_fn = tool_map.get(call["name"])
|
|
|
|
|
if tool_fn is None:
|
|
|
|
|
tool_output = f"Unknown tool: {call['name']}"
|
|
|
|
|
else:
|
|
|
|
|
tool_output = await tool_fn.ainvoke(call.get("args", {}))
|
|
|
|
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
|
|
|
|
|
|
|
|
|
final = await llm.ainvoke(messages)
|
|
|
|
|
return _as_text(final.content)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _run_home_single_agent_stream(
|
|
|
|
|
user_id: str,
|
|
|
|
|
message: str,
|
|
|
|
|
context: dict[str, Any],
|
|
|
|
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
|
|
|
|
"""Streaming variant for single-agent home test mode."""
|
|
|
|
|
prepared_context = await _prepare_home_context(message, context)
|
|
|
|
|
|
|
|
|
|
llm = get_llm()
|
|
|
|
|
tools = _all_tools()
|
|
|
|
|
llm_with_tools = llm.bind_tools(tools)
|
|
|
|
|
messages: list[Any] = [
|
|
|
|
|
SystemMessage(content=_HOME_SINGLE_AGENT_SYSTEM),
|
|
|
|
|
HumanMessage(content=f"User message:\n{message}\n\nContext:\n{json.dumps({'context': prepared_context}, ensure_ascii=True)[:3500]}"),
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
for _ in range(6):
|
|
|
|
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
|
|
|
messages.append(response)
|
|
|
|
|
if not response.tool_calls:
|
|
|
|
|
async for chunk in llm.astream(messages):
|
|
|
|
|
token = _as_text(getattr(chunk, "content", ""))
|
|
|
|
|
if token:
|
|
|
|
|
yield "token", token
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
tool_map = {t.name: t for t in tools}
|
|
|
|
|
for call in response.tool_calls:
|
|
|
|
|
tool_fn = tool_map.get(call["name"])
|
|
|
|
|
if tool_fn is None:
|
|
|
|
|
tool_output = f"Unknown tool: {call['name']}"
|
|
|
|
|
else:
|
|
|
|
|
tool_output = await tool_fn.ainvoke(call.get("args", {}))
|
|
|
|
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
|
|
|
|
|
|
|
|
|
async for chunk in llm.astream(messages):
|
|
|
|
|
token = _as_text(getattr(chunk, "content", ""))
|
|
|
|
|
if token:
|
|
|
|
|
yield "token", token
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan:
|
|
|
|
|
if _needs_full_project_snapshot(message, floating):
|
|
|
|
|
logger.info("deep_agent: forcing full project snapshot plan for message=%s", message[:200])
|
|
|
|
|
return _build_full_project_snapshot_plan(message)
|
|
|
|
|
|
|
|
|
|
llm = get_llm()
|
|
|
|
|
system = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM
|
|
|
|
|
|
|
|
|
|
@@ -189,18 +440,34 @@ async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool)
|
|
|
|
|
}
|
|
|
|
|
messages = [
|
|
|
|
|
SystemMessage(content=system),
|
|
|
|
|
HumanMessage(content=json.dumps(prompt_payload, ensure_ascii=True)),
|
|
|
|
|
HumanMessage(
|
|
|
|
|
content=(
|
|
|
|
|
"Create a valid JSON object with this exact structure:\n"
|
|
|
|
|
'{"tasks":[{"worker":"task_agent|project_agent|note_agent|timeline_agent","instruction":"..."}],'
|
|
|
|
|
'"floating_domain":"tasks|projects|notes|timelines|null","memory_updates":[{"key":"...","value":"..."}]}\n\n'
|
|
|
|
|
"Rules:\n"
|
|
|
|
|
"- tasks must include at least one entry when possible\n"
|
|
|
|
|
"- use floating_domain only when relevant\n"
|
|
|
|
|
"- output JSON only (no markdown, no prose)\n\n"
|
|
|
|
|
f"Input:\n{json.dumps(prompt_payload, ensure_ascii=True)}"
|
|
|
|
|
)
|
|
|
|
|
),
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
structured_llm = llm.with_structured_output(WorkerPlan)
|
|
|
|
|
plan = await structured_llm.ainvoke(messages)
|
|
|
|
|
if isinstance(plan, WorkerPlan):
|
|
|
|
|
if not plan.tasks:
|
|
|
|
|
return _fallback_plan(message, floating)
|
|
|
|
|
response = await llm.ainvoke(messages)
|
|
|
|
|
payload = _extract_json_object(_as_text(response.content))
|
|
|
|
|
if payload is None:
|
|
|
|
|
raise ValueError("planner returned non-JSON output")
|
|
|
|
|
plan = _coerce_plan(payload, message, floating)
|
|
|
|
|
logger.info(
|
|
|
|
|
"deep_agent: planner produced tasks=%s floating=%s",
|
|
|
|
|
[t.worker for t in plan.tasks],
|
|
|
|
|
plan.floating_domain,
|
|
|
|
|
)
|
|
|
|
|
return plan
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
logger.warning("deep_agent: structured planner failed, using fallback: %s", exc)
|
|
|
|
|
logger.warning("deep_agent: planner failed, using fallback: %s", exc)
|
|
|
|
|
|
|
|
|
|
return _fallback_plan(message, floating)
|
|
|
|
|
|
|
|
|
|
@@ -243,6 +510,64 @@ def _extract_entity_ids(tool_results: list[dict[str, Any]]) -> dict[str, list[st
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_facts(tool_results: list[dict[str, Any]]) -> dict[str, Any]:
|
|
|
|
|
"""Extract small, structured facts for the synthesizer to avoid hallucinations."""
|
|
|
|
|
facts: dict[str, Any] = {"projects": [], "tasks": [], "notes": [], "timelines": []}
|
|
|
|
|
|
|
|
|
|
for item in tool_results:
|
|
|
|
|
table = item.get("table")
|
|
|
|
|
payload = item.get("data") or {}
|
|
|
|
|
|
|
|
|
|
rows: list[dict[str, Any]] = []
|
|
|
|
|
row = payload.get("row")
|
|
|
|
|
if isinstance(row, dict):
|
|
|
|
|
rows.append(row)
|
|
|
|
|
if isinstance(payload.get("rows"), list):
|
|
|
|
|
rows.extend([r for r in payload["rows"] if isinstance(r, dict)])
|
|
|
|
|
|
|
|
|
|
if table == "projects":
|
|
|
|
|
for r in rows:
|
|
|
|
|
facts["projects"].append(
|
|
|
|
|
{
|
|
|
|
|
"id": r.get("id"),
|
|
|
|
|
"name": r.get("name"),
|
|
|
|
|
"status": r.get("status"),
|
|
|
|
|
"clientId": r.get("clientId"),
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
elif table == "tasks":
|
|
|
|
|
for r in rows:
|
|
|
|
|
facts["tasks"].append(
|
|
|
|
|
{
|
|
|
|
|
"id": r.get("id"),
|
|
|
|
|
"title": r.get("title"),
|
|
|
|
|
"status": r.get("status"),
|
|
|
|
|
"projectId": r.get("projectId"),
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
elif table == "notes":
|
|
|
|
|
for r in rows:
|
|
|
|
|
facts["notes"].append(
|
|
|
|
|
{
|
|
|
|
|
"id": r.get("id"),
|
|
|
|
|
"title": r.get("title"),
|
|
|
|
|
"projectId": r.get("projectId"),
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
elif table == "timelines":
|
|
|
|
|
for r in rows:
|
|
|
|
|
facts["timelines"].append(
|
|
|
|
|
{
|
|
|
|
|
"id": r.get("id"),
|
|
|
|
|
"title": r.get("title"),
|
|
|
|
|
"date": r.get("date"),
|
|
|
|
|
"projectId": r.get("projectId"),
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return facts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _run_tool_loop(
|
|
|
|
|
worker: WorkerName,
|
|
|
|
|
instruction: str,
|
|
|
|
|
@@ -254,10 +579,45 @@ async def _run_tool_loop(
|
|
|
|
|
llm = get_llm()
|
|
|
|
|
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
|
|
|
|
|
|
|
|
|
resolved_project_id = None
|
|
|
|
|
ctx = context.get("context", {}) if isinstance(context, dict) else {}
|
|
|
|
|
if isinstance(ctx, dict):
|
|
|
|
|
rpid = ctx.get("resolved_project_id")
|
|
|
|
|
if isinstance(rpid, str) and rpid:
|
|
|
|
|
resolved_project_id = rpid
|
|
|
|
|
|
|
|
|
|
mandatory_tool_policy = ""
|
|
|
|
|
if resolved_project_id:
|
|
|
|
|
if worker == "project_agent":
|
|
|
|
|
mandatory_tool_policy = (
|
|
|
|
|
"MANDATORY TOOL POLICY:\n"
|
|
|
|
|
f"- You MUST call get_project(project_id=\"{resolved_project_id}\") before final answer.\n"
|
|
|
|
|
"- Optionally call list_projects afterward only if needed for disambiguation.\n\n"
|
|
|
|
|
)
|
|
|
|
|
elif worker == "task_agent":
|
|
|
|
|
mandatory_tool_policy = (
|
|
|
|
|
"MANDATORY TOOL POLICY:\n"
|
|
|
|
|
f"- You MUST call list_tasks(project_id=\"{resolved_project_id}\") before final answer.\n"
|
|
|
|
|
"- Do not use project name as project_id.\n\n"
|
|
|
|
|
)
|
|
|
|
|
elif worker == "timeline_agent":
|
|
|
|
|
mandatory_tool_policy = (
|
|
|
|
|
"MANDATORY TOOL POLICY:\n"
|
|
|
|
|
f"- You MUST call list_timelines(project_id=\"{resolved_project_id}\") before final answer.\n"
|
|
|
|
|
"- Do not use project name as project_id.\n\n"
|
|
|
|
|
)
|
|
|
|
|
elif worker == "note_agent":
|
|
|
|
|
mandatory_tool_policy = (
|
|
|
|
|
"MANDATORY TOOL POLICY:\n"
|
|
|
|
|
f"- You MUST call list_notes(project_id=\"{resolved_project_id}\") before final answer.\n"
|
|
|
|
|
"- Do not use project name as project_id.\n\n"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
messages: list[Any] = [
|
|
|
|
|
SystemMessage(content=worker_prompt),
|
|
|
|
|
HumanMessage(
|
|
|
|
|
content=(
|
|
|
|
|
mandatory_tool_policy +
|
|
|
|
|
"Worker instruction:\n"
|
|
|
|
|
f"{instruction}\n\n"
|
|
|
|
|
"Conversation context:\n"
|
|
|
|
|
@@ -278,12 +638,38 @@ async def _run_tool_loop(
|
|
|
|
|
|
|
|
|
|
tool_map = {t.name: t for t in tools}
|
|
|
|
|
for call in response.tool_calls:
|
|
|
|
|
call_id = str(call.get("id", ""))
|
|
|
|
|
call_name = str(call.get("name", ""))
|
|
|
|
|
call_args = call.get("args", {})
|
|
|
|
|
logger.info(
|
|
|
|
|
"deep_agent: worker=%s AI->Tool tool_call_id=%s tool=%s args=%s",
|
|
|
|
|
worker,
|
|
|
|
|
call_id,
|
|
|
|
|
call_name,
|
|
|
|
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
tool_fn = tool_map.get(call["name"])
|
|
|
|
|
if tool_fn is None:
|
|
|
|
|
tool_output = f"Unknown tool: {call['name']}"
|
|
|
|
|
else:
|
|
|
|
|
tool_output = await tool_fn.ainvoke(call.get("args", {}))
|
|
|
|
|
|
|
|
|
|
tool_output_text = str(tool_output)
|
|
|
|
|
logger.info(
|
|
|
|
|
"deep_agent: worker=%s Tool->AI tool_call_id=%s tool=%s output=%s",
|
|
|
|
|
worker,
|
|
|
|
|
call_id,
|
|
|
|
|
call_name,
|
|
|
|
|
tool_output_text[:1200],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
|
|
|
|
logger.info(
|
|
|
|
|
"deep_agent: worker=%s appended ToolMessage tool_call_id=%s",
|
|
|
|
|
worker,
|
|
|
|
|
call_id,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
structured_llm = llm.with_structured_output(WorkerSummary)
|
|
|
|
|
messages.append(SystemMessage(content="You have finished using tools. Summarize findings in max 3 sentences."))
|
|
|
|
|
@@ -303,11 +689,18 @@ def _worker_node(worker: WorkerName):
|
|
|
|
|
return {"worker_results": []}
|
|
|
|
|
|
|
|
|
|
instruction = str(task_payload.get("instruction") or state.get("user_message") or "")
|
|
|
|
|
logger.info("deep_agent: worker=%s start instruction=%s", worker, instruction[:240])
|
|
|
|
|
worker_context = {
|
|
|
|
|
"memory": state.get("memory_context", {}),
|
|
|
|
|
"context": state.get("context", {}),
|
|
|
|
|
}
|
|
|
|
|
response, tool_results = await _run_tool_loop(worker, instruction, worker_context)
|
|
|
|
|
logger.info(
|
|
|
|
|
"deep_agent: worker=%s complete tool_calls=%d entity_counts=%s",
|
|
|
|
|
worker,
|
|
|
|
|
len(tool_results),
|
|
|
|
|
{k: len(v) for k, v in _extract_entity_ids(tool_results).items()},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
"worker_results": [
|
|
|
|
|
@@ -316,6 +709,7 @@ def _worker_node(worker: WorkerName):
|
|
|
|
|
"instruction": instruction,
|
|
|
|
|
"response": response,
|
|
|
|
|
"entity_ids": _extract_entity_ids(tool_results),
|
|
|
|
|
"facts": _extract_facts(tool_results),
|
|
|
|
|
}
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
@@ -333,6 +727,7 @@ def _build_synthesis_prompt(state: GraphState, floating: bool) -> str:
|
|
|
|
|
"instruction": result.get("instruction"),
|
|
|
|
|
"response": result.get("response"),
|
|
|
|
|
"entity_ids": result.get("entity_ids", {}),
|
|
|
|
|
"facts": result.get("facts", {}),
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -399,14 +794,25 @@ async def _orchestrator_node_home(state: GraphState) -> GraphState:
|
|
|
|
|
if state.get("plan"):
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
context = {**state.get("context", {}), **state.get("memory_context", {})}
|
|
|
|
|
plan = await _plan_with_llm(str(state.get("user_message", "")), context, floating=False)
|
|
|
|
|
user_message = str(state.get("user_message", ""))
|
|
|
|
|
base_context = dict(state.get("context", {}))
|
|
|
|
|
context = {**base_context, **state.get("memory_context", {})}
|
|
|
|
|
|
|
|
|
|
if _needs_full_project_snapshot(user_message, floating=False):
|
|
|
|
|
resolved_project_id = await _resolve_project_id_from_message(user_message)
|
|
|
|
|
if resolved_project_id:
|
|
|
|
|
base_context["resolved_project_id"] = resolved_project_id
|
|
|
|
|
logger.info("deep_agent: resolved_project_id=%s for message=%s", resolved_project_id, user_message[:200])
|
|
|
|
|
plan = _build_full_project_snapshot_plan(user_message)
|
|
|
|
|
else:
|
|
|
|
|
plan = await _plan_with_llm(user_message, context, floating=False)
|
|
|
|
|
|
|
|
|
|
new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {}))
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
"plan": [task.model_dump() for task in plan.tasks],
|
|
|
|
|
"memory_context": new_memory
|
|
|
|
|
"memory_context": new_memory,
|
|
|
|
|
"context": base_context,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -470,6 +876,9 @@ FLOATING_GRAPH = _build_graph(floating=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
|
|
|
|
if HOME_SINGLE_AGENT_TEST_MODE:
|
|
|
|
|
return await _run_home_single_agent(user_id, message, context)
|
|
|
|
|
|
|
|
|
|
state = await HOME_GRAPH.ainvoke(
|
|
|
|
|
{
|
|
|
|
|
"user_id": user_id,
|
|
|
|
|
@@ -505,6 +914,11 @@ async def run_home_stream(
|
|
|
|
|
message: str,
|
|
|
|
|
context: dict[str, Any],
|
|
|
|
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
|
|
|
|
if HOME_SINGLE_AGENT_TEST_MODE:
|
|
|
|
|
async for event in _run_home_single_agent_stream(user_id, message, context):
|
|
|
|
|
yield event
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
state_input = {
|
|
|
|
|
"user_id": user_id,
|
|
|
|
|
"user_message": message,
|
|
|
|
|
|