"""Single-agent runners for home and floating chat contexts."""

from __future__ import annotations

import json
import logging
import re
from collections.abc import AsyncGenerator
from typing import Any, Literal

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage

from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.core.llm import get_llm
from app.core.memory_middleware import MemoryMiddleware
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
from app.db import async_session

logger = logging.getLogger(__name__)

FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]

# Maximum number of characters of serialized context forwarded to the model.
_CONTEXT_CHAR_LIMIT = 3500
# Truncation limits for tool arguments / outputs written to the log.
_LOG_ARGS_CHAR_LIMIT = 800
_LOG_OUTPUT_CHAR_LIMIT = 1200

# NOTE(review): the inline-tag names that should wrap each "[ids]" / "{json}"
# placeholder appear to be missing from the prompts below — confirm the
# intended tag format against the frontend renderer.
_HOME_SINGLE_AGENT_SYSTEM = (
    "You are the home assistant with direct access to all tools: tasks, projects, notes, timelines. "
    "Always use tools for factual data retrieval before answering. "
    "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
    "Return markdown and embed inline tags when relevant: [ids], [ids], "
    "[ids], [ids], {json}."
)

_FLOATING_SINGLE_AGENT_SYSTEM = (
    "You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines. "
    "Stay focused on the floating scope in context.scope and answer concisely. "
    "Always use tools for factual data retrieval before answering. "
    "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
    "Return markdown and embed inline tags when relevant: [ids], [ids], "
    "[ids], [ids], {json}."
)

# Word-ish tokens used for fuzzy project-name matching.
_TOKEN_RE = re.compile(r"[a-zA-Z0-9_-]+")

# Any of these substrings in the user message triggers project-id resolution.
_PROJECT_RESOLUTION_KEYWORDS = ("project", "progetto", "progetti", "whitelist")

# Explicit floating scope types (as sent in context["scope"]["type"]) -> domain.
_SCOPE_TYPE_TO_DOMAIN: dict[str, FloatingDomain] = {
    "task": "tasks",
    "tasks": "tasks",
    "project": "projects",
    "projects": "projects",
    "note": "notes",
    "notes": "notes",
    "timeline": "timelines",
    "timelines": "timelines",
}

# Keyword fallback for the floating domain, checked in order (first match wins).
_DOMAIN_KEYWORDS: tuple[tuple[FloatingDomain, tuple[str, ...]], ...] = (
    ("timelines", ("timeline", "milestone", "release", "schedule")),
    ("notes", ("note", "notes", "memo", "document")),
    ("projects", ("project", "progetto", "client")),
)


def _as_text(content: Any) -> str:
    """Flatten a LangChain message ``content`` value into plain text.

    Strings are returned unchanged; lists are concatenated from their string
    items and from the ``"text"`` field of dict items (other parts are
    skipped); ``None`` becomes ``""``; anything else goes through ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for item in content:
            if isinstance(item, str):
                parts.append(item)
            elif isinstance(item, dict):
                text = item.get("text")
                if isinstance(text, str):
                    parts.append(text)
        return "".join(parts)
    return str(content)


def _candidate_tokens(message: str) -> list[str]:
    """Return lower-cased alphanumeric/``_``/``-`` tokens of at least 3 characters."""
    return [token for token in _TOKEN_RE.findall(message.lower()) if len(token) >= 3]


async def _resolve_project_id_from_message(message: str) -> str | None:
    """Resolve a likely project id from the user message using the client's project list.

    Each project row is scored by how many message tokens occur as substrings
    of its lower-cased ``name``. The id is returned only when exactly one row
    has the best (non-zero) score and its ``id`` is a string; otherwise
    ``None``. Failures of the client round-trip are logged and yield ``None``
    so that resolution stays best-effort.
    """
    try:
        result = await execute_on_client(action="select", table="projects")
    except Exception as exc:  # best effort: never break the chat turn over resolution
        logger.warning("deep_agent: project resolve select failed: %s", exc)
        return None

    # Guard against an unexpected (non-dict) client payload instead of raising.
    rows = result.get("rows", []) if isinstance(result, dict) else []
    if not isinstance(rows, list) or not rows:
        return None

    tokens = _candidate_tokens(message)
    scored: list[tuple[int, dict[str, Any]]] = []
    for row in rows:
        if not isinstance(row, dict):
            continue
        name = str(row.get("name") or "").lower()
        score = sum(1 for token in tokens if token in name)
        if score > 0:
            scored.append((score, row))
    if not scored:
        return None

    top_score = max(score for score, _ in scored)
    top_rows = [row for score, row in scored if score == top_score]
    if len(top_rows) != 1:
        # Ambiguous match: better to resolve nothing than the wrong project.
        return None
    project_id = top_rows[0].get("id")
    return project_id if isinstance(project_id, str) else None


def _needs_project_resolution(message: str) -> bool:
    """Return True when the message mentions a project-related keyword."""
    lowered = message.lower()
    return any(keyword in lowered for keyword in _PROJECT_RESOLUTION_KEYWORDS)


async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of ``context``, enriched with ``resolved_project_id`` when found."""
    prepared = dict(context)
    if _needs_project_resolution(message):
        resolved_project_id = await _resolve_project_id_from_message(message)
        if resolved_project_id:
            prepared["resolved_project_id"] = resolved_project_id
            logger.info("deep_agent: resolved_project_id=%s", resolved_project_id)
    return prepared


def _all_tools() -> list[Any]:
    """Return every task, project, note and timeline tool available to the agent."""
    return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]


def _infer_floating_domain(message: str, context: dict[str, Any]) -> FloatingDomain:
    """Pick the floating-chat domain.

    An explicit ``context["scope"]["type"]`` wins; otherwise the message is
    checked against keyword lists (timelines, then notes, then projects), and
    ``"tasks"`` is the default.
    """
    scope = context.get("scope") if isinstance(context, dict) else None
    if isinstance(scope, dict):
        scope_type = str(scope.get("type") or "").strip().lower()
        scoped_domain = _SCOPE_TYPE_TO_DOMAIN.get(scope_type)
        if scoped_domain is not None:
            return scoped_domain

    lowered = message.lower()
    for domain, keywords in _DOMAIN_KEYWORDS:
        if any(keyword in lowered for keyword in keywords):
            return domain
    return "tasks"


def _build_initial_messages(system_prompt: str, message: str, context: dict[str, Any]) -> list[Any]:
    """Build the system + human messages that open every agent run.

    The context is serialized under a ``"context"`` key and truncated to
    ``_CONTEXT_CHAR_LIMIT`` characters; ``default=str`` keeps non-JSON values
    (e.g. UUIDs, datetimes) from crashing prompt construction.
    """
    serialized_context = json.dumps({"context": context}, ensure_ascii=True, default=str)
    return [
        SystemMessage(content=system_prompt),
        HumanMessage(
            content=(
                f"User message:\n{message}\n\n"
                f"Context:\n{serialized_context[:_CONTEXT_CHAR_LIMIT]}"
            )
        ),
    ]


async def _execute_tool_calls(tool_calls: list[Any], tool_map: dict[str, Any]) -> list[ToolMessage]:
    """Execute the model's tool calls in order and return one ToolMessage per call.

    Unknown tool names and tool exceptions are both reported back to the model
    as the tool's output (exceptions are also logged with traceback), so one
    failing tool does not abort the whole turn.
    """
    tool_messages: list[ToolMessage] = []
    for call in tool_calls:
        call_id = str(call.get("id") or "")
        call_name = str(call.get("name") or "")
        call_args = call.get("args") or {}
        logger.info(
            "deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
            call_id,
            call_name,
            json.dumps(call_args, ensure_ascii=True, default=str)[:_LOG_ARGS_CHAR_LIMIT],
        )

        tool_fn = tool_map.get(call_name)
        if tool_fn is None:
            tool_output: Any = f"Unknown tool: {call_name}"
        else:
            try:
                tool_output = await tool_fn.ainvoke(call_args)
            except Exception as exc:  # surface to the model, like an unknown tool
                logger.exception(
                    "deep_agent: tool call failed tool_call_id=%s tool=%s", call_id, call_name
                )
                tool_output = f"Tool {call_name} failed: {exc}"

        logger.info(
            "deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
            call_id,
            call_name,
            str(tool_output)[:_LOG_OUTPUT_CHAR_LIMIT],
        )
        # Use the same defensively-read id as in the logs (the raw call["id"]
        # may be missing or None).
        tool_messages.append(ToolMessage(content=str(tool_output), tool_call_id=call_id))
    return tool_messages


async def _run_single_agent(
    *,
    system_prompt: str,
    message: str,
    context: dict[str, Any],
    max_steps: int = 6,
) -> str:
    """Run the tool-calling loop and return the model's final answer as text.

    Each step invokes the tool-bound model; requested tools are executed and
    their outputs appended to the conversation. When the model answers
    without tool calls, that answer is returned. If it is still requesting
    tools after ``max_steps`` steps, one last call to the model without tools
    bound produces the answer.
    """
    llm = get_llm()
    tools = _all_tools()
    tool_map = {tool_def.name: tool_def for tool_def in tools}
    llm_with_tools = llm.bind_tools(tools)
    messages = _build_initial_messages(system_prompt, message, context)

    # Registered for the duration of the run and cleared afterwards; its
    # consumer lives in app.core.ws_context (not visible here).
    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        for _ in range(max_steps):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            messages.append(response)
            if not response.tool_calls:
                return _as_text(response.content)
            messages.extend(await _execute_tool_calls(response.tool_calls, tool_map))

        final = await llm.ainvoke(messages)
        return _as_text(final.content)
    finally:
        clear_tool_result_collector()


async def _run_single_agent_stream(
    *,
    system_prompt: str,
    message: str,
    context: dict[str, Any],
    max_steps: int = 6,
) -> AsyncGenerator[tuple[str, Any], None]:
    """Streaming variant of ``_run_single_agent`` yielding ``("token", text)`` events.

    The tool loop is identical. When the tool-bound model produces its final
    answer, that answer is emitted directly. If the step budget is exhausted,
    a final answer is streamed token by token from the model without tools.
    """
    llm = get_llm()
    tools = _all_tools()
    tool_map = {tool_def.name: tool_def for tool_def in tools}
    llm_with_tools = llm.bind_tools(tools)
    messages = _build_initial_messages(system_prompt, message, context)

    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        for _ in range(max_steps):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            messages.append(response)
            if not response.tool_calls:
                # The answer is already complete in `response`. Querying the
                # model again would send a history ending in this very answer,
                # so it could only continue or restate it (at double cost).
                answer = _as_text(response.content)
                if answer:
                    yield "token", answer
                return
            messages.extend(await _execute_tool_calls(response.tool_calls, tool_map))

        # Step budget exhausted: stream a final answer without tools bound.
        async for chunk in llm.astream(messages):
            token = _as_text(getattr(chunk, "content", ""))
            if token:
                yield "token", token
    finally:
        clear_tool_result_collector()


async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
    """Answer a home-chat message. ``user_id`` is accepted but not used here."""
    prepared_context = await _prepare_context(message, context)
    return await _run_single_agent(
        system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
        message=message,
        context=prepared_context,
    )


async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
    """Answer a floating-chat message; return ``(response, inferred_domain)``.

    ``user_id`` is accepted but not used here. The domain is inferred from
    the original (unprepared) context.
    """
    domain = _infer_floating_domain(message, context)
    prepared_context = await _prepare_context(message, context)
    response = await _run_single_agent(
        system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
        message=message,
        context=prepared_context,
    )
    return response, domain


async def run_home_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream a home-chat answer as ``("token", text)`` events."""
    prepared_context = await _prepare_context(message, context)
    async for event in _run_single_agent_stream(
        system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
        message=message,
        context=prepared_context,
    ):
        yield event


async def run_floating_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream a floating-chat answer.

    First yields ``("floating_domain", domain)``, then the ``("token", text)``
    events of the answer.
    """
    domain = _infer_floating_domain(message, context)
    yield "floating_domain", domain
    prepared_context = await _prepare_context(message, context)
    async for event in _run_single_agent_stream(
        system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
        message=message,
        context=prepared_context,
    ):
        yield event


async def update_core_memory(user_id: str, key: str, value: str) -> None:
    """Compatibility helper kept for callers that expect explicit memory update API."""
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        await memory.update_core(user_id, key, value)