332 lines
11 KiB
Python
332 lines
11 KiB
Python
"""Single-agent runners for home and floating chat contexts."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from collections.abc import AsyncGenerator
|
|
from typing import Any, Literal
|
|
|
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
|
|
|
from app.agents.note_agent import NOTE_TOOLS
|
|
from app.agents.project_agent import PROJECT_TOOLS
|
|
from app.agents.task_agent import TASK_TOOLS
|
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
|
from app.core.llm import get_llm
|
|
from app.core.memory_middleware import MemoryMiddleware
|
|
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
|
from app.db import async_session
|
|
|
|
logger = logging.getLogger(__name__)

# Domains a floating-chat request can be attributed to (see _infer_floating_domain).
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]

# Instructions shared verbatim by both single-agent system prompts below.
# The JSON handed to the model is wrapped as {"context": ...}, hence the
# "context.context.resolved_project_id" path.
_SHARED_AGENT_RULES = (
    "Always use tools for factual data retrieval before answering. "
    "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
    "Return markdown and embed inline tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
    "<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>."
)

# System prompt used by run_home / run_home_stream.
_HOME_SINGLE_AGENT_SYSTEM = (
    "You are the home assistant with direct access to all tools: tasks, projects, notes, timelines. "
    + _SHARED_AGENT_RULES
)

# System prompt used by run_floating / run_floating_stream; adds a
# scope-focus and concision instruction ahead of the shared rules.
_FLOATING_SINGLE_AGENT_SYSTEM = (
    "You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines. "
    "Stay focused on the floating scope in context.scope and answer concisely. "
    + _SHARED_AGENT_RULES
)
|
|
|
|
|
|
def _as_text(content: Any) -> str:
    """Flatten LLM message content into plain text.

    - ``None`` becomes ``""``.
    - A string is returned unchanged.
    - A list has its string items and the ``"text"`` values of its dict items
      concatenated; any other item is skipped.
    - Anything else is converted with ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    pieces: list[str] = []
    for part in content:
        # Dict parts carry their text under "text"; other parts are used as-is
        # and only kept if they turn out to be strings.
        candidate = part.get("text") if isinstance(part, dict) else part
        if isinstance(candidate, str):
            pieces.append(candidate)
    return "".join(pieces)
|
|
|
|
|
|
def _candidate_tokens(message: str) -> list[str]:
    """Split *message* into lower-cased ASCII word tokens of length >= 3.

    Tokens are runs of ASCII letters, digits, ``_`` and ``-``; shorter runs are
    dropped so tiny words do not dominate name matching.
    """
    words = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
    return list(filter(lambda word: len(word) >= 3, words))
|
|
|
|
|
|
async def _resolve_project_id_from_message(message: str) -> str | None:
    """Resolve the likely project id referenced by *message*.

    Fetches the project list via ``execute_on_client(action="select",
    table="projects")``. Each project is scored by how many message tokens (see
    ``_candidate_tokens``) occur as substrings of its lower-cased name.

    Returns:
        The ``id`` of the single best-scoring project, or ``None``. ``None`` is
        returned when the lookup fails, the response is malformed, nothing
        matches, the best score is tied, or the winner's id is not a string.
        Resolution is best-effort and never raises on bad client data.
    """
    try:
        result = await execute_on_client(action="select", table="projects")
    except Exception as exc:
        logger.warning("deep_agent: project resolve select failed: %s", exc)
        return None

    # Guard the response shape too: a malformed payload must not turn this
    # best-effort hint into a failed request.
    rows = result.get("rows", []) if isinstance(result, dict) else None
    if not isinstance(rows, list) or not rows:
        return None

    tokens = _candidate_tokens(message)
    scored: list[tuple[int, dict[str, Any]]] = []
    for row in rows:
        if not isinstance(row, dict):
            continue
        # `or ""` stops a null name from becoming the literal "none", which
        # tokens such as "non"/"none" would otherwise match.
        name = str(row.get("name") or "").lower()
        score = sum(1 for token in tokens if token in name)
        if score > 0:
            scored.append((score, row))

    if not scored:
        return None

    top_score = max(score for score, _ in scored)
    top_rows = [row for score, row in scored if score == top_score]
    if len(top_rows) != 1:
        # Ambiguous: several projects match equally well.
        return None

    project_id = top_rows[0].get("id")
    return project_id if isinstance(project_id, str) else None
|
|
|
|
|
|
def _needs_project_resolution(message: str) -> bool:
    """Return True if *message* mentions projects (English/Italian) or a whitelist.

    Matching is a case-insensitive substring test, so e.g. "projects" also
    triggers via "project".
    """
    text = message.lower()
    for trigger in ("project", "progetto", "progetti", "whitelist"):
        if trigger in text:
            return True
    return False
|
|
|
|
|
|
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *context*, enriched with ``resolved_project_id``.

    The project lookup only runs when ``_needs_project_resolution`` flags the
    message. The key is added only when ``_resolve_project_id_from_message``
    returns a non-empty id. The caller's dict is never mutated.
    """
    enriched = dict(context)
    if not _needs_project_resolution(message):
        return enriched

    project_id = await _resolve_project_id_from_message(message)
    if project_id:
        enriched["resolved_project_id"] = project_id
        logger.info("deep_agent: resolved_project_id=%s", project_id)
    return enriched
|
|
|
|
|
|
def _all_tools() -> list[Any]:
    """Return a new list of every agent tool: task, project, note, then timeline tools."""
    combined: list[Any] = []
    for tool_group in (TASK_TOOLS, PROJECT_TOOLS, NOTE_TOOLS, TIMELINE_TOOLS):
        combined.extend(tool_group)
    return combined
|
|
|
|
|
|
def _infer_floating_domain(message: str, context: dict[str, Any]) -> FloatingDomain:
    """Guess which domain a floating-chat request is about.

    An explicit ``context["scope"]["type"]`` (singular or plural, any case or
    surrounding whitespace) wins. Otherwise keyword substrings in *message* are
    checked in priority order: timelines, then notes, then projects. The
    fallback is ``"tasks"``.
    """
    scope = context.get("scope") if isinstance(context, dict) else None
    if isinstance(scope, dict):
        scope_kind = str(scope.get("type") or "").strip().lower()
        for domain in ("tasks", "projects", "notes", "timelines"):
            # Accept both the plural domain name and its singular form.
            if scope_kind in (domain, domain[:-1]):
                return domain

    text = message.lower()
    # Order matters: the first rule with any matching keyword wins.
    keyword_rules = (
        ("timelines", ("timeline", "milestone", "release", "schedule")),
        ("notes", ("note", "notes", "memo", "document")),
        ("projects", ("project", "progetto", "client")),
    )
    for domain, keywords in keyword_rules:
        if any(keyword in text for keyword in keywords):
            return domain
    return "tasks"
|
|
|
|
|
|
def _build_agent_messages(
    system_prompt: str,
    message: str,
    context: dict[str, Any],
) -> list[Any]:
    """Return the opening ``[SystemMessage, HumanMessage]`` pair for an agent run.

    The context is wrapped as ``{"context": context}``, matching the
    ``context.context...`` path named in the system prompts. Its JSON is cut to
    3500 characters to bound prompt size.
    """
    context_json = json.dumps({"context": context}, ensure_ascii=True)[:3500]
    return [
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"User message:\n{message}\n\nContext:\n{context_json}"),
    ]


async def _execute_tool_calls(
    response: AIMessage,
    tool_map: dict[str, Any],
) -> list[ToolMessage]:
    """Execute every tool call in *response* and return one ToolMessage per call.

    An unknown tool name or a tool that raises is reported back to the model as
    that call's output, instead of aborting the whole run. This lets the model
    retry with other arguments or answer without the data. ``asyncio``
    cancellation is a BaseException and still propagates.
    """
    tool_messages: list[ToolMessage] = []
    for call in response.tool_calls:
        # Normalize the id once and use it for both logging and the ToolMessage,
        # so a missing/None id cannot raise KeyError or desync the two.
        call_id = str(call.get("id") or "")
        call_name = str(call.get("name", ""))
        call_args = call.get("args", {})
        logger.info(
            "deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
            call_id,
            call_name,
            json.dumps(call_args, ensure_ascii=True)[:800],
        )

        tool_fn = tool_map.get(call_name)
        if tool_fn is None:
            tool_output = f"Unknown tool: {call_name}"
        else:
            try:
                tool_output = await tool_fn.ainvoke(call_args)
            except Exception as exc:
                # Surface the failure to the model (and the logs) rather than
                # failing the entire chat turn on one bad tool invocation.
                logger.warning(
                    "deep_agent: tool %s failed: %s", call_name, exc, exc_info=True
                )
                tool_output = f"Tool {call_name} failed: {exc}"

        logger.info(
            "deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
            call_id,
            call_name,
            str(tool_output)[:1200],
        )
        tool_messages.append(ToolMessage(content=str(tool_output), tool_call_id=call_id))
    return tool_messages


async def _run_single_agent(
    *,
    system_prompt: str,
    message: str,
    context: dict[str, Any],
    max_steps: int = 6,
) -> str:
    """Run the tool-calling loop and return the model's final answer text.

    Each step asks the tool-bound model for a reply. If the reply requests
    tools, they are executed and their outputs appended before the next step.
    If the model is still calling tools after ``max_steps`` rounds, the model
    without tools bound is asked for a final answer.

    A fresh result collector is registered with ``set_tool_result_collector``
    for the duration of the run and cleared on exit. Presumably tools fill it;
    see ``app.core.ws_context``.
    """
    llm = get_llm()
    tools = _all_tools()
    tool_map = {tool_def.name: tool_def for tool_def in tools}  # built once per run
    llm_with_tools = llm.bind_tools(tools)
    messages = _build_agent_messages(system_prompt, message, context)

    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        for _ in range(max_steps):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            if not response.tool_calls:
                return _as_text(response.content)
            messages.append(response)
            messages.extend(await _execute_tool_calls(response, tool_map))

        # Step budget exhausted: force a plain-text answer.
        final = await llm.ainvoke(messages)
        return _as_text(final.content)
    finally:
        clear_tool_result_collector()


async def _run_single_agent_stream(
    *,
    system_prompt: str,
    message: str,
    context: dict[str, Any],
    max_steps: int = 6,
) -> AsyncGenerator[tuple[str, Any], None]:
    """Streaming variant of ``_run_single_agent``; yields ``("token", text)`` events.

    Tool rounds run exactly as in the non-streaming runner. Once the model stops
    requesting tools (or ``max_steps`` is exhausted), the final answer is
    streamed from the model without tools bound.
    """
    llm = get_llm()
    tools = _all_tools()
    tool_map = {tool_def.name: tool_def for tool_def in tools}  # built once per run
    llm_with_tools = llm.bind_tools(tools)
    messages = _build_agent_messages(system_prompt, message, context)

    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        for _ in range(max_steps):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            if not response.tool_calls:
                # Do NOT append `response` before streaming. With the model's own
                # answer as the last turn, the model continues or repeats that
                # answer instead of producing it. The answer is regenerated,
                # token by token, from the history that led to it.
                break
            messages.append(response)
            messages.extend(await _execute_tool_calls(response, tool_map))

        async for chunk in llm.astream(messages):
            token = _as_text(getattr(chunk, "content", ""))
            if token:
                yield "token", token
    finally:
        clear_tool_result_collector()
|
|
|
|
|
|
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
    """Answer a home-chat message with the single tool-using agent.

    ``user_id`` is accepted for interface compatibility but is not used here.
    """
    enriched_context = await _prepare_context(message, context)
    answer = await _run_single_agent(
        system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )
    return answer
|
|
|
|
|
|
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
    """Answer a floating-chat message and report its inferred domain.

    Returns ``(answer, domain)``. The domain comes from
    ``_infer_floating_domain`` and is only returned to the caller; it does not
    restrict which tools the agent may use. ``user_id`` is unused here.
    """
    inferred_domain = _infer_floating_domain(message, context)
    enriched_context = await _prepare_context(message, context)
    answer = await _run_single_agent(
        system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )
    return answer, inferred_domain
|
|
|
|
|
|
async def run_home_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream a home-chat answer, re-yielding the agent's ``(kind, payload)`` events.

    ``user_id`` is accepted for interface compatibility but is not used here.
    """
    enriched_context = await _prepare_context(message, context)
    agent_events = _run_single_agent_stream(
        system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )
    async for agent_event in agent_events:
        yield agent_event
|
|
|
|
|
|
async def run_floating_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream a floating-chat answer.

    The first event is ``("floating_domain", domain)``. It is emitted before any
    context preparation, so the client learns the domain immediately. The
    agent's events follow. ``user_id`` is unused here.
    """
    inferred_domain = _infer_floating_domain(message, context)
    yield "floating_domain", inferred_domain

    enriched_context = await _prepare_context(message, context)
    agent_events = _run_single_agent_stream(
        system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )
    async for agent_event in agent_events:
        yield agent_event
|
|
|
|
|
|
async def update_core_memory(user_id: str, key: str, value: str) -> None:
    """Update one core-memory entry for *user_id* (compatibility wrapper).

    Opens a new ``async_session`` and delegates to
    ``MemoryMiddleware.update_core``. It is kept for callers that expect an
    explicit memory-update API.
    """
    async with async_session() as session:
        await MemoryMiddleware(session).update_core(user_id, key, value)
|