Files
api/app/core/deep_agent.py

332 lines
11 KiB
Python

"""Single-agent runners for home and floating chat contexts."""
from __future__ import annotations
import json
import logging
import re
from collections.abc import AsyncGenerator
from typing import Any, Literal
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.core.llm import get_llm
from app.core.memory_middleware import MemoryMiddleware
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
from app.db import async_session
logger = logging.getLogger(__name__)
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]
_HOME_SINGLE_AGENT_SYSTEM = (
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines. "
"Always use tools for factual data retrieval before answering. "
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
"Return markdown and embed inline tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>."
)
_FLOATING_SINGLE_AGENT_SYSTEM = (
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines. "
"Stay focused on the floating scope in context.scope and answer concisely. "
"Always use tools for factual data retrieval before answering. "
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
"Return markdown and embed inline tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>."
)
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
def _candidate_tokens(message: str) -> list[str]:
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
return [token for token in tokens if len(token) >= 3]
async def _resolve_project_id_from_message(message: str) -> str | None:
    """Pick the single project whose name best matches the words in ``message``.

    The project rows are fetched from the client with a ``select`` on the
    ``projects`` table. Each row is scored by how many message tokens occur
    as substrings of its lower-cased name.

    Returns:
        The ``id`` of the unique top-scoring row when that id is a string.
        ``None`` when the select fails, no row matches, several rows tie for
        the best score, or the id is not a string.
    """
    try:
        response = await execute_on_client(action="select", table="projects")
    except Exception as exc:
        # Resolution is best effort: log the failure and carry on without an id.
        logger.warning("deep_agent: project resolve select failed: %s", exc)
        return None

    projects = response.get("rows", [])
    if not (isinstance(projects, list) and projects):
        return None

    words = _candidate_tokens(message)
    best_score = 0
    winners: list[dict[str, Any]] = []
    for project in projects:
        if not isinstance(project, dict):
            continue
        lowered_name = str(project.get("name", "")).lower()
        hits = sum(word in lowered_name for word in words)
        if hits > best_score:
            # New best score: every earlier candidate is outranked.
            best_score, winners = hits, [project]
        elif hits == best_score and hits > 0:
            winners.append(project)

    # Only an unambiguous best match is accepted.
    if len(winners) != 1:
        return None
    candidate_id = winners[0].get("id")
    return candidate_id if isinstance(candidate_id, str) else None
def _needs_project_resolution(message: str) -> bool:
lowered = message.lower()
return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of ``context``, enriched with a resolved project id.

    The lookup only runs when the message contains a project-related keyword.
    When it yields an id, the id is stored under ``"resolved_project_id"``.
    The caller's ``context`` dict itself is never modified.
    """
    enriched = dict(context)
    if not _needs_project_resolution(message):
        return enriched

    project_id = await _resolve_project_id_from_message(message)
    if project_id:
        enriched["resolved_project_id"] = project_id
        logger.info("deep_agent: resolved_project_id=%s", project_id)
    return enriched
def _all_tools() -> list[Any]:
    """Return a new list with every task, project, note and timeline tool, in that order."""
    combined: list[Any] = []
    for tool_group in (TASK_TOOLS, PROJECT_TOOLS, NOTE_TOOLS, TIMELINE_TOOLS):
        combined.extend(tool_group)
    return combined
def _infer_floating_domain(message: str, context: dict[str, Any]) -> FloatingDomain:
scope = context.get("scope") if isinstance(context, dict) else None
if isinstance(scope, dict):
scope_type = str(scope.get("type") or "").strip().lower()
if scope_type in {"task", "tasks"}:
return "tasks"
if scope_type in {"project", "projects"}:
return "projects"
if scope_type in {"note", "notes"}:
return "notes"
if scope_type in {"timeline", "timelines"}:
return "timelines"
lowered = message.lower()
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
return "timelines"
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
return "notes"
if any(keyword in lowered for keyword in ["project", "progetto", "client"]):
return "projects"
return "tasks"
async def _run_single_agent(
    *,
    system_prompt: str,
    message: str,
    context: dict[str, Any],
    max_steps: int = 6,
) -> str:
    """Run the tool-calling loop to completion and return the answer text.

    The LLM is invoked with all tools bound. Every tool call it requests is
    executed and the result is fed back as a ``ToolMessage``. The loop ends
    as soon as a response carries no tool calls. If ``max_steps`` rounds all
    requested tools, one last call is made on the LLM without tools bound and
    its content is returned.

    Args:
        system_prompt: Content of the leading system message.
        message: The user's message.
        context: Request context; embedded in the prompt as JSON (wrapped as
            ``{"context": ...}``) and cut off at 3500 characters.
        max_steps: Maximum number of tool-enabled LLM rounds.

    Returns:
        The final response content flattened to a string.
    """
    llm = get_llm()
    tools = _all_tools()
    llm_with_tools = llm.bind_tools(tools)
    # The tool set is fixed for the whole run, so the name lookup is built once
    # instead of once per step.
    tool_map = {tool_def.name: tool_def for tool_def in tools}
    messages: list[Any] = [
        SystemMessage(content=system_prompt),
        HumanMessage(
            content=(
                f"User message:\n{message}\n\n"
                f"Context:\n{json.dumps({'context': context}, ensure_ascii=True)[:3500]}"
            )
        ),
    ]
    # Registered as the tool-result collector for the duration of the run;
    # this function never reads the list back itself.
    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        for _ in range(max_steps):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            messages.append(response)
            if not response.tool_calls:
                return _as_text(response.content)
            for call in response.tool_calls:
                call_id = str(call.get("id", ""))
                call_name = str(call.get("name", ""))
                call_args = call.get("args", {})
                logger.info(
                    "deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
                    call_id,
                    call_name,
                    json.dumps(call_args, ensure_ascii=True)[:800],
                )
                tool_fn = tool_map.get(call_name)
                if tool_fn is None:
                    # Report the unknown name back to the model instead of failing the run.
                    tool_output = f"Unknown tool: {call_name}"
                else:
                    tool_output = await tool_fn.ainvoke(call_args)
                logger.info(
                    "deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
                    call_id,
                    call_name,
                    str(tool_output)[:1200],
                )
                # Reuse the normalized id: indexing call["id"] directly raised
                # KeyError on a missing key and passed a non-string id through.
                messages.append(ToolMessage(content=str(tool_output), tool_call_id=call_id))
        # Step budget exhausted while the model was still requesting tools:
        # ask the LLM without tools bound for a closing answer.
        final = await llm.ainvoke(messages)
        return _as_text(final.content)
    finally:
        clear_tool_result_collector()
async def _run_single_agent_stream(
    *,
    system_prompt: str,
    message: str,
    context: dict[str, Any],
    max_steps: int = 6,
) -> AsyncGenerator[tuple[str, Any], None]:
    """Run the tool-calling loop and stream the final answer as token events.

    Tool rounds are executed without streaming. Once a response carries no
    tool calls, or ``max_steps`` rounds have been used up, the final answer is
    streamed from the LLM without tools bound.

    Args:
        system_prompt: Content of the leading system message.
        message: The user's message.
        context: Request context; embedded in the prompt as JSON (wrapped as
            ``{"context": ...}``) and cut off at 3500 characters.
        max_steps: Maximum number of tool-enabled LLM rounds.

    Yields:
        ``("token", text)`` for every non-empty streamed chunk.
    """
    llm = get_llm()
    tools = _all_tools()
    llm_with_tools = llm.bind_tools(tools)
    # The tool set is fixed for the whole run, so the name lookup is built once
    # instead of once per step.
    tool_map = {tool_def.name: tool_def for tool_def in tools}
    messages: list[Any] = [
        SystemMessage(content=system_prompt),
        HumanMessage(
            content=(
                f"User message:\n{message}\n\n"
                f"Context:\n{json.dumps({'context': context}, ensure_ascii=True)[:3500]}"
            )
        ),
    ]
    # Registered as the tool-result collector for the duration of the run;
    # this function never reads the list back itself.
    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        for _ in range(max_steps):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            messages.append(response)
            if not response.tool_calls:
                # NOTE(review): `response` already holds an answer here and has
                # been appended to `messages`; the stream below issues one more
                # LLM call on top of it. Confirm that this second call is
                # intended rather than emitting `response.content` directly.
                async for chunk in llm.astream(messages):
                    token = _as_text(getattr(chunk, "content", ""))
                    if token:
                        yield "token", token
                return
            for call in response.tool_calls:
                call_id = str(call.get("id", ""))
                call_name = str(call.get("name", ""))
                call_args = call.get("args", {})
                logger.info(
                    "deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
                    call_id,
                    call_name,
                    json.dumps(call_args, ensure_ascii=True)[:800],
                )
                tool_fn = tool_map.get(call_name)
                if tool_fn is None:
                    # Report the unknown name back to the model instead of failing the run.
                    tool_output = f"Unknown tool: {call_name}"
                else:
                    tool_output = await tool_fn.ainvoke(call_args)
                logger.info(
                    "deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
                    call_id,
                    call_name,
                    str(tool_output)[:1200],
                )
                # Reuse the normalized id: indexing call["id"] directly raised
                # KeyError on a missing key and passed a non-string id through.
                messages.append(ToolMessage(content=str(tool_output), tool_call_id=call_id))
        # Step budget exhausted while the model was still requesting tools:
        # stream a closing answer from the LLM without tools bound.
        async for chunk in llm.astream(messages):
            token = _as_text(getattr(chunk, "content", ""))
            if token:
                yield "token", token
    finally:
        clear_tool_result_collector()
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
    """Answer a home-chat message and return the full response text.

    ``user_id`` is accepted as part of the signature but is not used here.
    """
    enriched_context = await _prepare_context(message, context)
    answer = await _run_single_agent(
        system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )
    return answer
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
    """Answer a floating-chat message.

    ``user_id`` is accepted as part of the signature but is not used here.

    Returns:
        A ``(response_text, domain)`` pair, where ``domain`` is the label
        inferred from the message and the original (unenriched) context.
    """
    inferred_domain = _infer_floating_domain(message, context)
    enriched_context = await _prepare_context(message, context)
    answer = await _run_single_agent(
        system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )
    return answer, inferred_domain
async def run_home_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Answer a home-chat message, re-yielding the agent's stream events unchanged.

    ``user_id`` is accepted as part of the signature but is not used here.
    """
    enriched_context = await _prepare_context(message, context)
    agent_events = _run_single_agent_stream(
        system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )
    async for agent_event in agent_events:
        yield agent_event
async def run_floating_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Answer a floating-chat message as a stream of events.

    The first event is ``("floating_domain", domain)``; it is emitted before
    the context is prepared. The agent's stream events follow unchanged.
    ``user_id`` is accepted as part of the signature but is not used here.
    """
    yield "floating_domain", _infer_floating_domain(message, context)

    enriched_context = await _prepare_context(message, context)
    agent_events = _run_single_agent_stream(
        system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )
    async for agent_event in agent_events:
        yield agent_event
async def update_core_memory(user_id: str, key: str, value: str) -> None:
    """Store ``value`` under ``key`` for ``user_id`` via ``MemoryMiddleware.update_core``.

    Kept as a compatibility helper for callers that expect an explicit
    memory-update API. A fresh database session is opened for the call.
    """
    async with async_session() as session:
        await MemoryMiddleware(session).update_core(user_id, key, value)