Compare commits
2 Commits
f7404b6f66
...
5b55f1292a
| Author | SHA1 | Date | |
|---|---|---|---|
| 5b55f1292a | |||
| 5bc9ea6cd6 |
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
@@ -9,6 +10,14 @@ from langchain_core.tools import tool
|
|||||||
from app.core.llm import embed
|
from app.core.llm import embed
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
NOTE_SYSTEM_PROMPT = (
|
NOTE_SYSTEM_PROMPT = (
|
||||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
||||||
"and delete Markdown notes in their workspace.\n\n"
|
"and delete Markdown notes in their workspace.\n\n"
|
||||||
@@ -19,6 +28,7 @@ NOTE_SYSTEM_PROMPT = (
|
|||||||
" before appending or replacing sections\n"
|
" before appending or replacing sections\n"
|
||||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
" - list_notes without project_id returns all notes; scope with project_id\n"
|
||||||
" when the user is working within a specific project\n"
|
" when the user is working within a specific project\n"
|
||||||
|
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
||||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
" - Do not fabricate note content — reflect what the user provides or what\n"
|
||||||
" is already in the note (retrieved via get_note)."
|
" is already in the note (retrieved via get_note)."
|
||||||
)
|
)
|
||||||
@@ -27,10 +37,11 @@ NOTE_SYSTEM_PROMPT = (
|
|||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
"""List notes, optionally scoped to a project by project_id."""
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="notes",
|
table="notes",
|
||||||
filters={"projectId": project_id or None},
|
filters={"projectId": normalized_project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
|
|||||||
@@ -3,12 +3,21 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
TASK_SYSTEM_PROMPT = (
|
TASK_SYSTEM_PROMPT = (
|
||||||
"You are a task management assistant for a project workspace.\n"
|
"You are a task management assistant for a project workspace.\n"
|
||||||
"You create, update, list, and track tasks and their comments.\n\n"
|
"You create, update, list, and track tasks and their comments.\n\n"
|
||||||
@@ -39,11 +48,12 @@ async def list_tasks(
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
filters={
|
filters={
|
||||||
"projectId": project_id or None,
|
"projectId": normalized_project_id or None,
|
||||||
"status": status or None,
|
"status": status or None,
|
||||||
"search": search or None,
|
"search": search or None,
|
||||||
"orderBy": order_by or None,
|
"orderBy": order_by or None,
|
||||||
@@ -205,8 +215,12 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
|||||||
table="taskComments",
|
table="taskComments",
|
||||||
data={"taskId": task_id, "author": author, "content": content},
|
data={"taskId": task_id, "author": author, "content": content},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result.get("row", {})
|
||||||
return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
|
row_author = row.get("author", author)
|
||||||
|
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
|
||||||
|
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||||
|
row_comment_id = row.get("id", "unknown")
|
||||||
|
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
|
|||||||
@@ -2,17 +2,27 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
TIMELINE_SYSTEM_PROMPT = (
|
TIMELINE_SYSTEM_PROMPT = (
|
||||||
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
||||||
"track progress on a project — they are not calendar events.\n\n"
|
"track progress on a project — they are not calendar events.\n\n"
|
||||||
"Rules:\n"
|
"Rules:\n"
|
||||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
||||||
|
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
|
||||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
||||||
" - is_approved: 0 until the user explicitly confirms; then 1\n"
|
" - is_approved: 0 until the user explicitly confirms; then 1\n"
|
||||||
@@ -25,10 +35,11 @@ TIMELINE_SYSTEM_PROMPT = (
|
|||||||
@tool
|
@tool
|
||||||
async def list_timelines(project_id: str = "") -> str:
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
"""List timelines. Provide project_id to scope to a specific project."""
|
"""List timelines. Provide project_id to scope to a specific project."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
filters={"projectId": project_id or None},
|
filters={"projectId": normalized_project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
|
|||||||
@@ -5,9 +5,10 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import operator
|
import re
|
||||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||||
from typing import Any, Literal, TypedDict
|
import operator
|
||||||
|
from typing import Annotated, Any, Literal, TypedDict
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
@@ -22,11 +23,14 @@ from app.agents.task_agent import TASK_SYSTEM_PROMPT, TASK_TOOLS
|
|||||||
from app.agents.timeline_agent import TIMELINE_SYSTEM_PROMPT, TIMELINE_TOOLS
|
from app.agents.timeline_agent import TIMELINE_SYSTEM_PROMPT, TIMELINE_TOOLS
|
||||||
from app.core.llm import get_llm
|
from app.core.llm import get_llm
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
|
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Quick test switch: home requests run as one agent with all tools.
|
||||||
|
HOME_SINGLE_AGENT_TEST_MODE = True
|
||||||
|
|
||||||
WorkerName = Literal["task_agent", "project_agent", "note_agent", "timeline_agent"]
|
WorkerName = Literal["task_agent", "project_agent", "note_agent", "timeline_agent"]
|
||||||
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]
|
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]
|
||||||
|
|
||||||
@@ -56,6 +60,7 @@ class WorkerResult(TypedDict):
|
|||||||
instruction: str
|
instruction: str
|
||||||
response: str
|
response: str
|
||||||
entity_ids: dict[str, list[str]]
|
entity_ids: dict[str, list[str]]
|
||||||
|
facts: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class OrchestratorState(TypedDict, total=False):
|
class OrchestratorState(TypedDict, total=False):
|
||||||
@@ -71,7 +76,7 @@ class OrchestratorState(TypedDict, total=False):
|
|||||||
|
|
||||||
|
|
||||||
class GraphState(OrchestratorState):
|
class GraphState(OrchestratorState):
|
||||||
worker_results: list[WorkerResult]
|
worker_results: Annotated[list[WorkerResult], operator.add]
|
||||||
|
|
||||||
|
|
||||||
class ReducerState(OrchestratorState):
|
class ReducerState(OrchestratorState):
|
||||||
@@ -116,19 +121,21 @@ WORKER_CONFIG: dict[WorkerName, dict[str, Any]] = {
|
|||||||
_HOME_ORCHESTRATOR_SYSTEM = (
|
_HOME_ORCHESTRATOR_SYSTEM = (
|
||||||
"You are an orchestrator. Plan which workers should be invoked for the user request. "
|
"You are an orchestrator. Plan which workers should be invoked for the user request. "
|
||||||
"Workers: task_agent, project_agent, note_agent, timeline_agent. "
|
"Workers: task_agent, project_agent, note_agent, timeline_agent. "
|
||||||
"Return only the workers needed."
|
"Return JSON only with keys: tasks, floating_domain, memory_updates."
|
||||||
)
|
)
|
||||||
|
|
||||||
_FLOATING_ORCHESTRATOR_SYSTEM = (
|
_FLOATING_ORCHESTRATOR_SYSTEM = (
|
||||||
"You are an orchestrator for floating context. Pick focused workers and set floating_domain "
|
"You are an orchestrator for floating context. Pick focused workers and set floating_domain "
|
||||||
"as one of: tasks, projects, notes, timelines."
|
"as one of: tasks, projects, notes, timelines. Return JSON only with keys: tasks, floating_domain, memory_updates."
|
||||||
)
|
)
|
||||||
|
|
||||||
_HOME_SYNTH_SYSTEM = (
|
_HOME_SYNTH_SYSTEM = (
|
||||||
"You are the final response synthesizer. Return markdown only. "
|
"You are the final response synthesizer. Return markdown only. "
|
||||||
"Embed inline component tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
"Embed inline component tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
||||||
"<note>[ids]</note>, <timeline>[ids]</timeline>, and <chart>{json}</chart>. "
|
"<note>[ids]</note>, <timeline>[ids]</timeline>, and <chart>{json}</chart>. "
|
||||||
"Only include IDs that are truly relevant to the request."
|
"Only include IDs that are truly relevant to the request. "
|
||||||
|
"Never invent missing values. If facts include a non-null clientId for a project, "
|
||||||
|
"do not claim that the project has no owner/client."
|
||||||
)
|
)
|
||||||
|
|
||||||
_FLOATING_SYNTH_SYSTEM = (
|
_FLOATING_SYNTH_SYSTEM = (
|
||||||
@@ -136,6 +143,14 @@ _FLOATING_SYNTH_SYSTEM = (
|
|||||||
"Return concise markdown and stay focused on the requested scope."
|
"Return concise markdown and stay focused on the requested scope."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_HOME_SINGLE_AGENT_SYSTEM = (
|
||||||
|
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines. "
|
||||||
|
"Always use tools for factual data retrieval before answering. "
|
||||||
|
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||||
|
"Return markdown and embed inline tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
||||||
|
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _as_text(content: Any) -> str:
|
def _as_text(content: Any) -> str:
|
||||||
if content is None:
|
if content is None:
|
||||||
@@ -178,7 +193,243 @@ def _fallback_plan(message: str, floating: bool) -> WorkerPlan:
|
|||||||
return WorkerPlan(tasks=tasks, floating_domain=domain)
|
return WorkerPlan(tasks=tasks, floating_domain=domain)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_object(text: str) -> dict[str, Any] | None:
|
||||||
|
"""Best-effort extraction of the first JSON object from model output."""
|
||||||
|
stripped = text.strip()
|
||||||
|
if not stripped:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Common case: model returns raw JSON object.
|
||||||
|
try:
|
||||||
|
payload = json.loads(stripped)
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
return payload
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fenced JSON block fallback.
|
||||||
|
if "```" in stripped:
|
||||||
|
parts = stripped.split("```")
|
||||||
|
for part in parts:
|
||||||
|
candidate = part.strip()
|
||||||
|
if candidate.startswith("json"):
|
||||||
|
candidate = candidate[4:].strip()
|
||||||
|
try:
|
||||||
|
payload = json.loads(candidate)
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
return payload
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_plan(payload: dict[str, Any], message: str, floating: bool) -> WorkerPlan:
|
||||||
|
"""Normalize loose model JSON into a validated WorkerPlan."""
|
||||||
|
tasks_raw = payload.get("tasks")
|
||||||
|
tasks: list[WorkerTask] = []
|
||||||
|
|
||||||
|
if isinstance(tasks_raw, list):
|
||||||
|
for item in tasks_raw:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
worker = item.get("worker")
|
||||||
|
instruction = item.get("instruction")
|
||||||
|
if isinstance(worker, str) and worker in WORKER_CONFIG and isinstance(instruction, str):
|
||||||
|
tasks.append(WorkerTask(worker=worker, instruction=instruction))
|
||||||
|
|
||||||
|
if not tasks:
|
||||||
|
return _fallback_plan(message, floating)
|
||||||
|
|
||||||
|
domain = payload.get("floating_domain")
|
||||||
|
floating_domain: FloatingDomain | None = None
|
||||||
|
if isinstance(domain, str) and domain in {"tasks", "projects", "notes", "timelines"}:
|
||||||
|
floating_domain = domain # type: ignore[assignment]
|
||||||
|
elif floating:
|
||||||
|
floating_domain = WORKER_CONFIG[tasks[0].worker]["floating_domain"]
|
||||||
|
|
||||||
|
memory_updates: list[MemoryUpdate] = []
|
||||||
|
updates_raw = payload.get("memory_updates")
|
||||||
|
if isinstance(updates_raw, list):
|
||||||
|
for item in updates_raw:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
key = item.get("key")
|
||||||
|
value = item.get("value")
|
||||||
|
if isinstance(key, str) and isinstance(value, str) and key and value:
|
||||||
|
memory_updates.append(MemoryUpdate(key=key, value=value))
|
||||||
|
|
||||||
|
return WorkerPlan(
|
||||||
|
tasks=tasks,
|
||||||
|
floating_domain=floating_domain,
|
||||||
|
memory_updates=memory_updates,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _needs_full_project_snapshot(message: str, floating: bool) -> bool:
|
||||||
|
"""Detect project status/update requests that should query all workers."""
|
||||||
|
if floating:
|
||||||
|
return False
|
||||||
|
lowered = message.lower()
|
||||||
|
has_project = any(k in lowered for k in ["project", "progetto", "progetto", "progetti", "progetto", "whitelist"])
|
||||||
|
has_status_intent = any(k in lowered for k in ["status", "stato", "aggiorn", "update", "situazione", "riepilogo", "summary"])
|
||||||
|
return has_project and has_status_intent
|
||||||
|
|
||||||
|
|
||||||
|
def _build_full_project_snapshot_plan(message: str) -> WorkerPlan:
|
||||||
|
"""Build a deterministic all-workers plan for project status snapshots."""
|
||||||
|
project_hint = (
|
||||||
|
"Use context.context.resolved_project_id when present as project_id. "
|
||||||
|
"Do not pass project names as project_id."
|
||||||
|
)
|
||||||
|
return WorkerPlan(
|
||||||
|
tasks=[
|
||||||
|
WorkerTask(worker="project_agent", instruction=f"Resolve the target project from this request and return core fields including id, name, status, clientId. {project_hint} Request: {message}"),
|
||||||
|
WorkerTask(worker="task_agent", instruction=f"Collect tasks relevant to the project in this request; include pending/blocked highlights and IDs. {project_hint} Request: {message}"),
|
||||||
|
WorkerTask(worker="timeline_agent", instruction=f"Collect timeline/milestone items relevant to the project in this request; include upcoming items and IDs. {project_hint} Request: {message}"),
|
||||||
|
WorkerTask(worker="note_agent", instruction=f"Collect notes relevant to the project in this request; include latest useful notes and IDs. {project_hint} Request: {message}"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _candidate_tokens(message: str) -> list[str]:
|
||||||
|
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
|
||||||
|
return [t for t in tokens if len(t) >= 3]
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_project_id_from_message(message: str) -> str | None:
|
||||||
|
"""Resolve likely project UUID from user message using client project list."""
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(action="select", table="projects")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("deep_agent: project resolve select failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not isinstance(rows, list) or not rows:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tokens = _candidate_tokens(message)
|
||||||
|
scored: list[tuple[int, dict[str, Any]]] = []
|
||||||
|
for row in rows:
|
||||||
|
if not isinstance(row, dict):
|
||||||
|
continue
|
||||||
|
name = str(row.get("name", "")).lower()
|
||||||
|
score = sum(1 for token in tokens if token in name)
|
||||||
|
if score > 0:
|
||||||
|
scored.append((score, row))
|
||||||
|
|
||||||
|
if not scored:
|
||||||
|
return None
|
||||||
|
|
||||||
|
scored.sort(key=lambda item: item[0], reverse=True)
|
||||||
|
top_score = scored[0][0]
|
||||||
|
top_rows = [row for score, row in scored if score == top_score]
|
||||||
|
if len(top_rows) != 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
project_id = top_rows[0].get("id")
|
||||||
|
return project_id if isinstance(project_id, str) else None
|
||||||
|
|
||||||
|
|
||||||
|
async def _prepare_home_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Resolve and inject project_id hints for home flows."""
|
||||||
|
prepared = dict(context)
|
||||||
|
if _needs_full_project_snapshot(message, floating=False):
|
||||||
|
resolved_project_id = await _resolve_project_id_from_message(message)
|
||||||
|
if resolved_project_id:
|
||||||
|
prepared["resolved_project_id"] = resolved_project_id
|
||||||
|
logger.info("deep_agent: resolved_project_id=%s for message=%s", resolved_project_id, message[:200])
|
||||||
|
return prepared
|
||||||
|
|
||||||
|
|
||||||
|
def _all_tools() -> list[Any]:
|
||||||
|
tools: list[Any] = []
|
||||||
|
for config in WORKER_CONFIG.values():
|
||||||
|
tools.extend(config["tools"])
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_home_single_agent(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
) -> str:
|
||||||
|
"""Single-agent test mode: one loop with all tools."""
|
||||||
|
prepared_context = await _prepare_home_context(message, context)
|
||||||
|
|
||||||
|
llm = get_llm()
|
||||||
|
tools = _all_tools()
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=_HOME_SINGLE_AGENT_SYSTEM),
|
||||||
|
HumanMessage(content=f"User message:\n{message}\n\nContext:\n{json.dumps({'context': prepared_context}, ensure_ascii=True)[:3500]}"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for _ in range(6):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
tool_map = {t.name: t for t in tools}
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_fn = tool_map.get(call["name"])
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call['name']}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call.get("args", {}))
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_home_single_agent_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Streaming variant for single-agent home test mode."""
|
||||||
|
prepared_context = await _prepare_home_context(message, context)
|
||||||
|
|
||||||
|
llm = get_llm()
|
||||||
|
tools = _all_tools()
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=_HOME_SINGLE_AGENT_SYSTEM),
|
||||||
|
HumanMessage(content=f"User message:\n{message}\n\nContext:\n{json.dumps({'context': prepared_context}, ensure_ascii=True)[:3500]}"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for _ in range(6):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
if not response.tool_calls:
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
|
if token:
|
||||||
|
yield "token", token
|
||||||
|
return
|
||||||
|
|
||||||
|
tool_map = {t.name: t for t in tools}
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_fn = tool_map.get(call["name"])
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call['name']}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call.get("args", {}))
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
async for chunk in llm.astream(messages):
|
||||||
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
|
if token:
|
||||||
|
yield "token", token
|
||||||
|
|
||||||
|
|
||||||
async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan:
|
async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan:
|
||||||
|
if _needs_full_project_snapshot(message, floating):
|
||||||
|
logger.info("deep_agent: forcing full project snapshot plan for message=%s", message[:200])
|
||||||
|
return _build_full_project_snapshot_plan(message)
|
||||||
|
|
||||||
llm = get_llm()
|
llm = get_llm()
|
||||||
system = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM
|
system = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM
|
||||||
|
|
||||||
@@ -189,18 +440,34 @@ async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool)
|
|||||||
}
|
}
|
||||||
messages = [
|
messages = [
|
||||||
SystemMessage(content=system),
|
SystemMessage(content=system),
|
||||||
HumanMessage(content=json.dumps(prompt_payload, ensure_ascii=True)),
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
"Create a valid JSON object with this exact structure:\n"
|
||||||
|
'{"tasks":[{"worker":"task_agent|project_agent|note_agent|timeline_agent","instruction":"..."}],'
|
||||||
|
'"floating_domain":"tasks|projects|notes|timelines|null","memory_updates":[{"key":"...","value":"..."}]}\n\n'
|
||||||
|
"Rules:\n"
|
||||||
|
"- tasks must include at least one entry when possible\n"
|
||||||
|
"- use floating_domain only when relevant\n"
|
||||||
|
"- output JSON only (no markdown, no prose)\n\n"
|
||||||
|
f"Input:\n{json.dumps(prompt_payload, ensure_ascii=True)}"
|
||||||
|
)
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
structured_llm = llm.with_structured_output(WorkerPlan)
|
response = await llm.ainvoke(messages)
|
||||||
plan = await structured_llm.ainvoke(messages)
|
payload = _extract_json_object(_as_text(response.content))
|
||||||
if isinstance(plan, WorkerPlan):
|
if payload is None:
|
||||||
if not plan.tasks:
|
raise ValueError("planner returned non-JSON output")
|
||||||
return _fallback_plan(message, floating)
|
plan = _coerce_plan(payload, message, floating)
|
||||||
return plan
|
logger.info(
|
||||||
|
"deep_agent: planner produced tasks=%s floating=%s",
|
||||||
|
[t.worker for t in plan.tasks],
|
||||||
|
plan.floating_domain,
|
||||||
|
)
|
||||||
|
return plan
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("deep_agent: structured planner failed, using fallback: %s", exc)
|
logger.warning("deep_agent: planner failed, using fallback: %s", exc)
|
||||||
|
|
||||||
return _fallback_plan(message, floating)
|
return _fallback_plan(message, floating)
|
||||||
|
|
||||||
@@ -243,6 +510,64 @@ def _extract_entity_ids(tool_results: list[dict[str, Any]]) -> dict[str, list[st
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_facts(tool_results: list[dict[str, Any]]) -> dict[str, Any]:
|
||||||
|
"""Extract small, structured facts for the synthesizer to avoid hallucinations."""
|
||||||
|
facts: dict[str, Any] = {"projects": [], "tasks": [], "notes": [], "timelines": []}
|
||||||
|
|
||||||
|
for item in tool_results:
|
||||||
|
table = item.get("table")
|
||||||
|
payload = item.get("data") or {}
|
||||||
|
|
||||||
|
rows: list[dict[str, Any]] = []
|
||||||
|
row = payload.get("row")
|
||||||
|
if isinstance(row, dict):
|
||||||
|
rows.append(row)
|
||||||
|
if isinstance(payload.get("rows"), list):
|
||||||
|
rows.extend([r for r in payload["rows"] if isinstance(r, dict)])
|
||||||
|
|
||||||
|
if table == "projects":
|
||||||
|
for r in rows:
|
||||||
|
facts["projects"].append(
|
||||||
|
{
|
||||||
|
"id": r.get("id"),
|
||||||
|
"name": r.get("name"),
|
||||||
|
"status": r.get("status"),
|
||||||
|
"clientId": r.get("clientId"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif table == "tasks":
|
||||||
|
for r in rows:
|
||||||
|
facts["tasks"].append(
|
||||||
|
{
|
||||||
|
"id": r.get("id"),
|
||||||
|
"title": r.get("title"),
|
||||||
|
"status": r.get("status"),
|
||||||
|
"projectId": r.get("projectId"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif table == "notes":
|
||||||
|
for r in rows:
|
||||||
|
facts["notes"].append(
|
||||||
|
{
|
||||||
|
"id": r.get("id"),
|
||||||
|
"title": r.get("title"),
|
||||||
|
"projectId": r.get("projectId"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif table == "timelines":
|
||||||
|
for r in rows:
|
||||||
|
facts["timelines"].append(
|
||||||
|
{
|
||||||
|
"id": r.get("id"),
|
||||||
|
"title": r.get("title"),
|
||||||
|
"date": r.get("date"),
|
||||||
|
"projectId": r.get("projectId"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return facts
|
||||||
|
|
||||||
|
|
||||||
async def _run_tool_loop(
|
async def _run_tool_loop(
|
||||||
worker: WorkerName,
|
worker: WorkerName,
|
||||||
instruction: str,
|
instruction: str,
|
||||||
@@ -254,10 +579,45 @@ async def _run_tool_loop(
|
|||||||
llm = get_llm()
|
llm = get_llm()
|
||||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
||||||
|
|
||||||
|
resolved_project_id = None
|
||||||
|
ctx = context.get("context", {}) if isinstance(context, dict) else {}
|
||||||
|
if isinstance(ctx, dict):
|
||||||
|
rpid = ctx.get("resolved_project_id")
|
||||||
|
if isinstance(rpid, str) and rpid:
|
||||||
|
resolved_project_id = rpid
|
||||||
|
|
||||||
|
mandatory_tool_policy = ""
|
||||||
|
if resolved_project_id:
|
||||||
|
if worker == "project_agent":
|
||||||
|
mandatory_tool_policy = (
|
||||||
|
"MANDATORY TOOL POLICY:\n"
|
||||||
|
f"- You MUST call get_project(project_id=\"{resolved_project_id}\") before final answer.\n"
|
||||||
|
"- Optionally call list_projects afterward only if needed for disambiguation.\n\n"
|
||||||
|
)
|
||||||
|
elif worker == "task_agent":
|
||||||
|
mandatory_tool_policy = (
|
||||||
|
"MANDATORY TOOL POLICY:\n"
|
||||||
|
f"- You MUST call list_tasks(project_id=\"{resolved_project_id}\") before final answer.\n"
|
||||||
|
"- Do not use project name as project_id.\n\n"
|
||||||
|
)
|
||||||
|
elif worker == "timeline_agent":
|
||||||
|
mandatory_tool_policy = (
|
||||||
|
"MANDATORY TOOL POLICY:\n"
|
||||||
|
f"- You MUST call list_timelines(project_id=\"{resolved_project_id}\") before final answer.\n"
|
||||||
|
"- Do not use project name as project_id.\n\n"
|
||||||
|
)
|
||||||
|
elif worker == "note_agent":
|
||||||
|
mandatory_tool_policy = (
|
||||||
|
"MANDATORY TOOL POLICY:\n"
|
||||||
|
f"- You MUST call list_notes(project_id=\"{resolved_project_id}\") before final answer.\n"
|
||||||
|
"- Do not use project name as project_id.\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
messages: list[Any] = [
|
messages: list[Any] = [
|
||||||
SystemMessage(content=worker_prompt),
|
SystemMessage(content=worker_prompt),
|
||||||
HumanMessage(
|
HumanMessage(
|
||||||
content=(
|
content=(
|
||||||
|
mandatory_tool_policy +
|
||||||
"Worker instruction:\n"
|
"Worker instruction:\n"
|
||||||
f"{instruction}\n\n"
|
f"{instruction}\n\n"
|
||||||
"Conversation context:\n"
|
"Conversation context:\n"
|
||||||
@@ -278,12 +638,38 @@ async def _run_tool_loop(
|
|||||||
|
|
||||||
tool_map = {t.name: t for t in tools}
|
tool_map = {t.name: t for t in tools}
|
||||||
for call in response.tool_calls:
|
for call in response.tool_calls:
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: worker=%s AI->Tool tool_call_id=%s tool=%s args=%s",
|
||||||
|
worker,
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
tool_fn = tool_map.get(call["name"])
|
tool_fn = tool_map.get(call["name"])
|
||||||
if tool_fn is None:
|
if tool_fn is None:
|
||||||
tool_output = f"Unknown tool: {call['name']}"
|
tool_output = f"Unknown tool: {call['name']}"
|
||||||
else:
|
else:
|
||||||
tool_output = await tool_fn.ainvoke(call.get("args", {}))
|
tool_output = await tool_fn.ainvoke(call.get("args", {}))
|
||||||
|
|
||||||
|
tool_output_text = str(tool_output)
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: worker=%s Tool->AI tool_call_id=%s tool=%s output=%s",
|
||||||
|
worker,
|
||||||
|
call_id,
|
||||||
|
call_name,
|
||||||
|
tool_output_text[:1200],
|
||||||
|
)
|
||||||
|
|
||||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: worker=%s appended ToolMessage tool_call_id=%s",
|
||||||
|
worker,
|
||||||
|
call_id,
|
||||||
|
)
|
||||||
|
|
||||||
structured_llm = llm.with_structured_output(WorkerSummary)
|
structured_llm = llm.with_structured_output(WorkerSummary)
|
||||||
messages.append(SystemMessage(content="You have finished using tools. Summarize findings in max 3 sentences."))
|
messages.append(SystemMessage(content="You have finished using tools. Summarize findings in max 3 sentences."))
|
||||||
@@ -303,11 +689,18 @@ def _worker_node(worker: WorkerName):
|
|||||||
return {"worker_results": []}
|
return {"worker_results": []}
|
||||||
|
|
||||||
instruction = str(task_payload.get("instruction") or state.get("user_message") or "")
|
instruction = str(task_payload.get("instruction") or state.get("user_message") or "")
|
||||||
|
logger.info("deep_agent: worker=%s start instruction=%s", worker, instruction[:240])
|
||||||
worker_context = {
|
worker_context = {
|
||||||
"memory": state.get("memory_context", {}),
|
"memory": state.get("memory_context", {}),
|
||||||
"context": state.get("context", {}),
|
"context": state.get("context", {}),
|
||||||
}
|
}
|
||||||
response, tool_results = await _run_tool_loop(worker, instruction, worker_context)
|
response, tool_results = await _run_tool_loop(worker, instruction, worker_context)
|
||||||
|
logger.info(
|
||||||
|
"deep_agent: worker=%s complete tool_calls=%d entity_counts=%s",
|
||||||
|
worker,
|
||||||
|
len(tool_results),
|
||||||
|
{k: len(v) for k, v in _extract_entity_ids(tool_results).items()},
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"worker_results": [
|
"worker_results": [
|
||||||
@@ -316,6 +709,7 @@ def _worker_node(worker: WorkerName):
|
|||||||
"instruction": instruction,
|
"instruction": instruction,
|
||||||
"response": response,
|
"response": response,
|
||||||
"entity_ids": _extract_entity_ids(tool_results),
|
"entity_ids": _extract_entity_ids(tool_results),
|
||||||
|
"facts": _extract_facts(tool_results),
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -333,6 +727,7 @@ def _build_synthesis_prompt(state: GraphState, floating: bool) -> str:
|
|||||||
"instruction": result.get("instruction"),
|
"instruction": result.get("instruction"),
|
||||||
"response": result.get("response"),
|
"response": result.get("response"),
|
||||||
"entity_ids": result.get("entity_ids", {}),
|
"entity_ids": result.get("entity_ids", {}),
|
||||||
|
"facts": result.get("facts", {}),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -399,14 +794,25 @@ async def _orchestrator_node_home(state: GraphState) -> GraphState:
|
|||||||
if state.get("plan"):
|
if state.get("plan"):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
context = {**state.get("context", {}), **state.get("memory_context", {})}
|
user_message = str(state.get("user_message", ""))
|
||||||
plan = await _plan_with_llm(str(state.get("user_message", "")), context, floating=False)
|
base_context = dict(state.get("context", {}))
|
||||||
|
context = {**base_context, **state.get("memory_context", {})}
|
||||||
|
|
||||||
|
if _needs_full_project_snapshot(user_message, floating=False):
|
||||||
|
resolved_project_id = await _resolve_project_id_from_message(user_message)
|
||||||
|
if resolved_project_id:
|
||||||
|
base_context["resolved_project_id"] = resolved_project_id
|
||||||
|
logger.info("deep_agent: resolved_project_id=%s for message=%s", resolved_project_id, user_message[:200])
|
||||||
|
plan = _build_full_project_snapshot_plan(user_message)
|
||||||
|
else:
|
||||||
|
plan = await _plan_with_llm(user_message, context, floating=False)
|
||||||
|
|
||||||
new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {}))
|
new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {}))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"plan": [task.model_dump() for task in plan.tasks],
|
"plan": [task.model_dump() for task in plan.tasks],
|
||||||
"memory_context": new_memory
|
"memory_context": new_memory,
|
||||||
|
"context": base_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -470,6 +876,9 @@ FLOATING_GRAPH = _build_graph(floating=True)
|
|||||||
|
|
||||||
|
|
||||||
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
||||||
|
if HOME_SINGLE_AGENT_TEST_MODE:
|
||||||
|
return await _run_home_single_agent(user_id, message, context)
|
||||||
|
|
||||||
state = await HOME_GRAPH.ainvoke(
|
state = await HOME_GRAPH.ainvoke(
|
||||||
{
|
{
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
@@ -505,6 +914,11 @@ async def run_home_stream(
|
|||||||
message: str,
|
message: str,
|
||||||
context: dict[str, Any],
|
context: dict[str, Any],
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
if HOME_SINGLE_AGENT_TEST_MODE:
|
||||||
|
async for event in _run_home_single_agent_stream(user_id, message, context):
|
||||||
|
yield event
|
||||||
|
return
|
||||||
|
|
||||||
state_input = {
|
state_input = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"user_message": message,
|
"user_message": message,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import litellm
|
import litellm
|
||||||
@@ -32,6 +33,14 @@ from app.config.settings import settings
|
|||||||
# Drop them silently instead of raising UnsupportedParamsError.
|
# Drop them silently instead of raising UnsupportedParamsError.
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
# Some provider responses include a plain dict in the `usage` field where a
|
||||||
|
# richer Pydantic model is expected. This warning is noisy but non-fatal.
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _api_key_for_model(model: str) -> str | None:
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
"""Return the most appropriate API key for the given LiteLLM model string."""
|
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||||
|
|||||||
Reference in New Issue
Block a user