Move duplicated files from chat + batch-agent into shared/: - shared/ws_context.py — Redis-based tool call round-trip - shared/llm.py — LiteLLM factory (get_llm, embed) - shared/agents/ — 4 domain agents (task, note, project, timeline) Update all service imports to use shared.* instead of app.*. Delete 12 duplicated files across both services.
884 lines
33 KiB
Python
884 lines
33 KiB
Python
"""Single-agent runners for home and floating chat contexts.
|
|
|
|
Adapted from app/core/deep_agent.py for the Chat Service.
|
|
Import paths changed to use local app modules and shared/.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from datetime import date
|
|
from collections.abc import AsyncGenerator
|
|
from typing import Any, Literal
|
|
|
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
|
from langchain_core.tools import tool
|
|
|
|
from shared.agents.note_agent import NOTE_TOOLS
|
|
from shared.agents.project_agent import PROJECT_TOOLS
|
|
from shared.agents.task_agent import TASK_TOOLS
|
|
from shared.agents.timeline_agent import TIMELINE_TOOLS
|
|
from shared.llm import get_llm
|
|
from app.memory_middleware import MemoryMiddleware
|
|
from shared.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
|
from app import tracing
|
|
from shared.db import async_session
|
|
|
|
logger = logging.getLogger(__name__)

# Domain "type" values reported for floating requests. "node" is the value this
# module emits for note-scoped requests (rule-based inference maps scope types
# "note"/"notes" to it, and the classifier prompt lists it).
# NOTE(review): the spelling looks like a wire-protocol value — confirm with the
# frontend before renaming it to "note".
FloatingDomainType = Literal["task", "timeline", "project", "node"]

# Project sub-section a floating request targets; only kept for "project" domains.
FloatingDomainSection = Literal["task", "timeline", "note"]

# Built-in fallback for the Langfuse prompt "home_system". The agent receives its
# context wrapped as {"context": ...}, hence the "context.context.*" path. The
# tag-per-line rules are enforced again after generation by
# _normalize_tagged_list_lines.
_HOME_SINGLE_AGENT_SYSTEM = (
    "You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
    "Always use tools for factual data retrieval before answering. "
    "When the user asks to remember, forget, or update what you know about them, use memory tools. "
    "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
    "Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
    "<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
    "When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
    "Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
    "For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
    "For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
)

# Built-in fallback for the Langfuse prompt "floating_system". The output is also
# stripped of markup after generation (_strip_floating_markup / the stream
# sanitizer), so the "plain text only" rule is not enforced by the prompt alone.
_FLOATING_SINGLE_AGENT_SYSTEM = (
    "You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
    "Stay focused on the floating scope in context.scope and answer concisely. "
    "Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
    "Always use tools for factual data retrieval before answering. "
    "When the user asks to remember, forget, or update what you know about them, use memory tools. "
    "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
)

# Built-in fallback for the Langfuse prompt "floating_domain_classifier". The
# reply is parsed by _parse_json_object and coerced by _normalize_domain_payload.
_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = (
    "You are a strict domain classifier for websocket floating requests. "
    "Return ONLY a JSON object with keys: type, id, section. "
    "Allowed type values: task, timeline, project, node. "
    "Allowed section values: task, timeline, note, or null. "
    "Rules: infer from user message intent first; do not blindly trust scope.type. "
    "If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
    "If project id is unknown but context.resolved_project_id exists, use it as id. "
    "If id is unknown, use null. "
    "No markdown, no prose, JSON only."
)
|
|
|
|
|
|
def _as_text(content: Any) -> str:
|
|
if content is None:
|
|
return ""
|
|
if isinstance(content, str):
|
|
return content
|
|
if isinstance(content, list):
|
|
parts: list[str] = []
|
|
for item in content:
|
|
if isinstance(item, str):
|
|
parts.append(item)
|
|
elif isinstance(item, dict):
|
|
text = item.get("text")
|
|
if isinstance(text, str):
|
|
parts.append(text)
|
|
return "".join(parts)
|
|
return str(content)
|
|
|
|
|
|
def _candidate_tokens(message: str) -> list[str]:
|
|
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
|
|
return [token for token in tokens if len(token) >= 3]
|
|
|
|
|
|
async def _resolve_project_id_from_message(message: str) -> str | None:
    """Guess the project a message refers to from the client's project list.

    Fetches ``projects`` rows via ``execute_on_client`` and scores each dict
    row by how many message tokens (``_candidate_tokens``) occur as substrings
    of its lowercased ``name``. The ``id`` is returned only when exactly one
    row holds the highest positive score and that id is a string. A failed
    fetch, no rows, no match, or a tie all yield ``None``.
    """
    try:
        result = await execute_on_client(action="select", table="projects")
    except Exception as exc:
        logger.warning("deep_agent: project resolve select failed: %s", exc)
        return None

    rows = result.get("rows", [])
    if not (isinstance(rows, list) and rows):
        return None

    tokens = _candidate_tokens(message)
    best_score = 0
    best_rows: list[dict[str, Any]] = []
    for row in rows:
        if not isinstance(row, dict):
            continue
        name = str(row.get("name", "")).lower()
        score = sum(token in name for token in tokens)
        if score > best_score:
            best_score = score
            best_rows = [row]
        elif score == best_score and score > 0:
            best_rows.append(row)

    # A tie is ambiguous, and no positive score means no match: resolve nothing.
    if len(best_rows) != 1:
        return None

    project_id = best_rows[0].get("id")
    return project_id if isinstance(project_id, str) else None
|
|
|
|
|
|
def _needs_project_resolution(message: str) -> bool:
|
|
lowered = message.lower()
|
|
return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
|
|
|
|
|
|
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *context*, adding ``resolved_project_id`` when one is found.

    The project lookup (a client round-trip) is only attempted for messages
    that look project-related. The caller's dict is never mutated.
    """
    prepared = dict(context)
    if not _needs_project_resolution(message):
        return prepared

    resolved_project_id = await _resolve_project_id_from_message(message)
    if resolved_project_id:
        prepared["resolved_project_id"] = resolved_project_id
        logger.info("deep_agent: resolved_project_id=%s", resolved_project_id)
    return prepared
|
|
|
|
|
|
def _all_tools() -> list[Any]:
    """Return a new list of the shared domain tools: task, project, note, then timeline."""
    tools: list[Any] = []
    for group in (TASK_TOOLS, PROJECT_TOOLS, NOTE_TOOLS, TIMELINE_TOOLS):
        tools.extend(group)
    return tools
|
|
|
|
|
|
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
|
debug = context.get("_debug")
|
|
if isinstance(debug, dict):
|
|
request_id = debug.get("request_id")
|
|
if isinstance(request_id, str) and request_id:
|
|
return request_id
|
|
return None
|
|
|
|
|
|
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
|
sanitized = dict(context)
|
|
sanitized.pop("_debug", None)
|
|
return sanitized
|
|
|
|
|
|
# One task/timeline id tag, e.g. "<task>[abc]</task>"; the \1 backreference makes
# the closing tag match the opening one.
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")

# A dd/mm/yyyy date anywhere in a line, with named groups d, m and y.
_TIMELINE_DMY_RE = re.compile(r"(?P<d>\d{2})/(?P<m>\d{2})/(?P<y>\d{4})")
|
|
|
|
|
|
def _is_upcoming_timeline_query(message: str) -> bool:
|
|
lowered = message.lower()
|
|
has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered
|
|
has_timeline_topic = any(
|
|
token in lowered
|
|
for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden")
|
|
)
|
|
return has_upcoming and has_timeline_topic
|
|
|
|
|
|
def _timeline_date_in_current_month_or_future(dmy: str) -> bool:
    """Check whether the first dd/mm/yyyy date in *dmy* is still upcoming this month.

    Despite the name, a date passes only if it is today or later AND falls in
    the current calendar month (``date.today()``, local time). This matches the
    home prompt's "only future items in the current month" rule. Text with no
    date, or with an impossible date such as 31/02, passes so that it is never
    filtered out.
    """
    found = _TIMELINE_DMY_RE.search(dmy)
    if found is None:
        return True
    try:
        parsed = date(int(found["y"]), int(found["m"]), int(found["d"]))
    except ValueError:
        return True

    today = date.today()
    if (parsed.year, parsed.month) != (today.year, today.month):
        return False
    return parsed >= today
|
|
|
|
|
|
def _normalize_tagged_list_lines(text: str, message: str) -> str:
    """Put every task/timeline id tag of a home answer on its own line.

    Lines without a ``<task>``/``<timeline>`` tag are kept verbatim. A line
    containing tags is replaced by the tags alone, one per output line, and any
    other text on it (bullets, numbering, titles, dates) is dropped.

    For "upcoming events" questions (``_is_upcoming_timeline_query``), a
    timeline tag is also removed when its line carries a date that is past or
    outside the current month.

    Falsy *text* is returned unchanged. Otherwise the lines are re-joined with
    single newlines, so a trailing newline is not preserved.
    """
    if not text:
        return text

    filter_stale_timelines = _is_upcoming_timeline_query(message)
    kept: list[str] = []

    for line in text.splitlines():
        tags = [found.group(0) for found in _TAG_LINE_RE.finditer(line)]
        if not tags:
            kept.append(line)
            continue
        # A clean single-tag line and a cluttered line get the same treatment:
        # only the tags survive, each subject to the upcoming-timeline filter.
        for tag_text in tags:
            is_stale_timeline = (
                filter_stale_timelines
                and "<timeline>" in tag_text
                and not _timeline_date_in_current_month_or_future(line)
            )
            if not is_stale_timeline:
                kept.append(tag_text)

    return "\n".join(kept)
|
|
|
|
|
|
# Opening or closing markup tags (any case) that floating answers must not contain.
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)

# A bracketed id: "[...]" holding 8+ id-like characters. The hex alternative is a
# subset of the second one. NOTE(review): this also removes ordinary bracketed
# words of 8+ letters, e.g. "[optional]".
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")

# Shown when sanitising a floating answer leaves nothing to display.
_FLOATING_EMPTY_FALLBACK = "No results found."
|
|
|
|
|
|
def _strip_floating_markup_fragment(text: str) -> str:
    """Remove floating-forbidden markup from *text* without touching whitespace.

    Tags are removed before bracketed ids, so ``<task>[0123abcd]</task>``
    vanishes completely. Falsy input is returned as-is.
    """
    if not text:
        return text
    without_tags = _GENERIC_TAG_RE.sub("", text)
    without_ids = _BRACKETED_ID_RE.sub("", without_tags)
    return without_ids
|
|
|
|
|
|
def _strip_floating_markup(text: str) -> str:
    """Turn a floating answer into tidy plain text.

    Markup is removed, runs of 2+ spaces/tabs collapse to one space, each line
    is trimmed, and lines left empty are dropped. Falsy input is returned as-is.
    """
    if not text:
        return text

    tidied: list[str] = []
    for raw_line in _strip_floating_markup_fragment(text).splitlines():
        line = re.sub(r"[ \t]{2,}", " ", raw_line).strip()
        if line:
            tidied.append(line)
    return "\n".join(tidied)
|
|
|
|
|
|
def _fallback_from_raw_floating_text(raw_text: str) -> str:
    """Best-effort, never-empty plain-text rendering of a raw floating answer.

    Markup is stripped, space/tab runs collapse, and the whole string is
    trimmed. If nothing is left (including for ``None``/``""``), the result is
    ``_FLOATING_EMPTY_FALLBACK``.
    """
    stripped = _strip_floating_markup_fragment(raw_text or "")
    collapsed = re.sub(r"[ \t]{2,}", " ", stripped).strip()
    if collapsed:
        return collapsed
    return _FLOATING_EMPTY_FALLBACK
|
|
|
|
|
|
class _FloatingStreamSanitizer:
|
|
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
|
|
|
|
def __init__(self) -> None:
|
|
self._pending = ""
|
|
|
|
@staticmethod
|
|
def _split_safe_boundary(text: str) -> tuple[str, str]:
|
|
boundary = len(text)
|
|
|
|
last_lt = text.rfind("<")
|
|
if last_lt != -1 and ">" not in text[last_lt:]:
|
|
boundary = min(boundary, last_lt)
|
|
|
|
last_lb = text.rfind("[")
|
|
if last_lb != -1 and "]" not in text[last_lb:]:
|
|
boundary = min(boundary, last_lb)
|
|
|
|
if boundary == len(text):
|
|
return text, ""
|
|
return text[:boundary], text[boundary:]
|
|
|
|
def feed(self, chunk: str) -> str:
|
|
combined = f"{self._pending}{chunk}"
|
|
safe_text, self._pending = self._split_safe_boundary(combined)
|
|
return _strip_floating_markup_fragment(safe_text)
|
|
|
|
def finalize(self) -> str:
|
|
tail = re.sub(r"<[^>\n]*$", "", self._pending)
|
|
tail = re.sub(r"\[[^\]\n]*$", "", tail)
|
|
self._pending = ""
|
|
return _strip_floating_markup_fragment(tail)
|
|
|
|
|
|
def _normalize_memory_label(path_or_label: str) -> str:
|
|
value = path_or_label.strip()
|
|
if value.startswith("/memories/"):
|
|
value = value[len("/memories/"):]
|
|
value = value.strip("/")
|
|
return value
|
|
|
|
|
|
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
    """Build the memory tools exposed to the agent, bound to one user.

    Every tool is a langchain_core ``@tool`` closure over *user_id* and
    *trace_id*, so the model never supplies the user id itself. Each call
    opens its own ``async_session()`` and goes through ``MemoryMiddleware``.

    The ``memory_*`` tools address a core-memory block by label or by a
    ``/memories/<label>`` path (normalised with ``_normalize_memory_label``).
    ``archival_memory_*`` work on archival memory, and ``conversation_search``
    searches recall memory.

    NOTE: with ``@tool``, each function name is the tool name, its signature is
    the argument schema, and its docstring is the description sent to the
    model. Editing any of them changes what the model sees.
    """

    @tool
    async def memory_list_blocks() -> str:
        """List all core memory blocks currently stored for the user."""
        logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id)
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            blocks = await memory.list_core_blocks(user_id)
        if not blocks:
            return "No memory blocks found."
        # presumably each block is a mapping with "label" and "value" keys — confirm against MemoryMiddleware
        lines = [f"- {b['label']}: {b['value']}" for b in blocks]
        return "Memory blocks:\n" + "\n".join(lines)

    @tool
    async def memory_get(path_or_label: str) -> str:
        """Get one memory block by label or /memories/<label> path."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_id or "-", user_id, label)
        if not label:
            return "Invalid memory label."
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            value = await memory.get_core_block(user_id, label)
        if value is None:
            return f"Memory block '{label}' not found."
        return f"Memory block '{label}':\n{value}"

    @tool
    async def memory_create(path_or_label: str, value: str) -> str:
        """Create or overwrite a memory block value by label or /memories/<label> path."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_id or "-", user_id, label)
        if not label:
            return "Invalid memory label."
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            # NOTE(review): only this tool forwards trace_id to MemoryMiddleware;
            # confirm whether append/replace/delete should as well.
            await memory.update_core(user_id, label, value, trace_id=trace_id)
        return f"Memory block '{label}' saved."

    @tool
    async def memory_append(path_or_label: str, content: str) -> str:
        """Append content to a memory block, creating it if missing."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_id or "-", user_id, label)
        if not label:
            return "Invalid memory label."
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            await memory.append_core(user_id, label, content)
        return f"Memory block '{label}' appended."

    @tool
    async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
        """Replace one exact string in a memory block."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_id or "-", user_id, label)
        if not label:
            return "Invalid memory label."
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            # Falsy result is reported to the model as "old string not found".
            changed = await memory.replace_core(user_id, label, old_string, new_string)
        if not changed:
            return f"No replacement made in '{label}' (old string not found)."
        return f"Memory block '{label}' updated."

    @tool
    async def memory_delete(path_or_label: str) -> str:
        """Delete a memory block by label or /memories/<label> path."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_id or "-", user_id, label)
        if not label:
            return "Invalid memory label."
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            deleted = await memory.delete_core(user_id, label)
        if not deleted:
            return f"Memory block '{label}' not found."
        return f"Memory block '{label}' deleted."

    @tool
    async def archival_memory_insert(content: str) -> str:
        """Insert a long-term archival memory entry."""
        logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_id or "-", user_id)
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            # Entries written by the agent are tagged source="assistant".
            await memory.insert_archival(user_id, content, source="assistant")
        return "Archival memory saved."

    @tool
    async def archival_memory_search(query: str, top_k: int = 5) -> str:
        """Search long-term archival memory by semantic fallback (keyword currently)."""
        # Only the first 80 characters of the query are logged.
        logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            # NOTE(review): top_k comes from the model unbounded here; confirm
            # search_archival caps it.
            results = await memory.search_archival(user_id, query, top_k=top_k)
        if not results:
            return "No archival memory results found."
        lines = [f"- {item}" for item in results]
        return "Archival memory results:\n" + "\n".join(lines)

    @tool
    async def conversation_search(query: str, top_k: int = 5) -> str:
        """Search recall memory from prior episodic conversation summaries."""
        logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            results = await memory.search_recall(user_id, query, top_k=top_k)
        if not results:
            return "No recall memory results found."
        lines = [f"- {item}" for item in results]
        return "Recall memory results:\n" + "\n".join(lines)

    return [
        memory_list_blocks,
        memory_get,
        memory_create,
        memory_append,
        memory_replace,
        memory_delete,
        archival_memory_insert,
        archival_memory_search,
        conversation_search,
    ]
|
|
|
|
|
|
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
    """Return a new list: the shared domain tools followed by this user's memory tools."""
    domain_tools = _all_tools()
    memory_tools = _memory_tools(user_id, trace_id)
    return domain_tools + memory_tools
|
|
|
|
|
|
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
|
|
lowered = message.lower()
|
|
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
|
|
return "timeline"
|
|
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
|
|
return "task"
|
|
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
|
|
return "note"
|
|
return None
|
|
|
|
|
|
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
|
|
type_raw = str(payload.get("type") or "").strip().lower()
|
|
domain_type: FloatingDomainType = "task"
|
|
if type_raw in {"task", "timeline", "project", "node"}:
|
|
domain_type = type_raw
|
|
|
|
id_value = payload.get("id")
|
|
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
|
|
if domain_type == "project" and not domain_id:
|
|
domain_id = fallback_id
|
|
|
|
section_raw = payload.get("section")
|
|
section: FloatingDomainSection | None = None
|
|
if isinstance(section_raw, str):
|
|
section_candidate = section_raw.strip().lower()
|
|
if section_candidate in {"task", "timeline", "note"}:
|
|
section = section_candidate
|
|
|
|
if domain_type != "project":
|
|
section = None
|
|
|
|
return {
|
|
"type": domain_type,
|
|
"id": domain_id,
|
|
"section": section,
|
|
}
|
|
|
|
|
|
def _parse_json_object(text: str) -> dict[str, Any] | None:
|
|
raw = text.strip()
|
|
if not raw:
|
|
return None
|
|
try:
|
|
parsed = json.loads(raw)
|
|
return parsed if isinstance(parsed, dict) else None
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
|
if not match:
|
|
return None
|
|
try:
|
|
parsed = json.loads(match.group(0))
|
|
except json.JSONDecodeError:
|
|
return None
|
|
return parsed if isinstance(parsed, dict) else None
|
|
|
|
|
|
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
    """Keyword/scope fallback used when the LLM domain classifier is unavailable.

    Precedence, highest first:

    1. An explicit ``context["scope"]["type"]`` (task/project/note/timeline,
       singular or plural).
    2. A project keyword in the message, or a resolved project id.
    3. The detected section (timeline, then note).
    4. Default ``task``.

    Only project domains carry a ``section``.
    """
    is_dict = isinstance(context, dict)
    section = _detect_domain_section(message)
    scope = context.get("scope") if is_dict else None
    resolved = context.get("resolved_project_id") if is_dict else None
    project_id = resolved if isinstance(resolved, str) and resolved else None

    if isinstance(scope, dict):
        scope_type = str(scope.get("type") or "").strip().lower()
        raw_scope_id = scope.get("id")
        scope_id = raw_scope_id if isinstance(raw_scope_id, str) and raw_scope_id else None

        if scope_type in ("task", "tasks"):
            return {"type": "task", "id": scope_id, "section": None}
        if scope_type in ("project", "projects"):
            return {"type": "project", "id": scope_id or project_id, "section": section}
        if scope_type in ("note", "notes"):
            return {"type": "node", "id": scope_id, "section": None}
        if scope_type in ("timeline", "timelines"):
            return {"type": "timeline", "id": scope_id, "section": None}

    lowered = message.lower()
    mentions_project = any(keyword in lowered for keyword in ("project", "progetto", "client"))
    if mentions_project or project_id:
        return {"type": "project", "id": project_id, "section": section}
    if section == "timeline":
        return {"type": "timeline", "id": None, "section": None}
    if section == "note":
        return {"type": "node", "id": None, "section": None}
    return {"type": "task", "id": None, "section": None}
|
|
|
|
|
|
async def _infer_floating_domain(
    message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None,
) -> dict[str, str | None]:
    """Classify which domain a floating request targets: LLM first, rules as fallback.

    The classifier sees only the message, ``context["scope"]`` (when it is a
    dict) and the resolved project id. Any classifier error, or a reply with no
    JSON object, falls back to ``_infer_floating_domain_rule_based``.
    """
    raw_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
    project_id = raw_project_id if isinstance(raw_project_id, str) and raw_project_id else None
    scope = context.get("scope")
    classifier_context = {
        "scope": scope if isinstance(scope, dict) else None,
        "resolved_project_id": project_id,
    }

    try:
        prompt = _get_system_prompt("floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_SYSTEM)
        classifier = get_llm(callbacks=_build_callbacks(langfuse_handler))
        # json.dumps stays inside the try: unserialisable scope data falls back to rules.
        user_turn = (
            f"Message:\n{message}\n\n"
            f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
        )
        reply = await classifier.ainvoke([SystemMessage(content=prompt), HumanMessage(content=user_turn)])
        parsed = _parse_json_object(_as_text(reply.content))
        if parsed is None:
            logger.warning("deep_agent: floating_domain classifier returned non-json output")
        else:
            domain = _normalize_domain_payload(parsed, project_id)
            logger.info(
                "deep_agent: floating_domain_classified type=%s id=%s section=%s",
                domain.get("type"),
                domain.get("id"),
                domain.get("section"),
            )
            return domain
    except Exception as exc:
        logger.warning("deep_agent: floating_domain classifier failed: %s", exc)

    return _infer_floating_domain_rule_based(message, context)
|
|
|
|
|
|
def _get_system_prompt(langfuse_name: str, fallback: str) -> str:
    """Return the Langfuse-managed prompt *langfuse_name*, or *fallback* when none exists."""
    managed = tracing.get_prompt(langfuse_name, fallback=None)
    if managed is None:
        return fallback
    return managed
|
|
|
|
|
|
def _build_callbacks(langfuse_handler: Any | None) -> list[Any] | None:
|
|
"""Return a callbacks list if a Langfuse handler is available."""
|
|
if langfuse_handler is None:
|
|
return None
|
|
return [langfuse_handler]
|
|
|
|
|
|
async def _run_single_agent(
    *,
    user_id: str,
    system_prompt: str,
    message: str,
    context: dict[str, Any],
    max_steps: int = 6,
    langfuse_handler: Any | None = None,
) -> str:
    """Run the tool-calling loop to completion and return the final answer text.

    The model, with all domain and memory tools bound, is invoked up to
    *max_steps* times. After each reply that contains tool calls, those tools
    run and their outputs are fed back as ``ToolMessage``s. The first reply
    without tool calls is returned. If the step budget runs out, one last call
    is made without tools bound so the model must answer with what it has.

    A tool that raises does not abort the run: the error is logged and reported
    to the model as that call's output, so it can correct its arguments or
    explain the failure. The ws_context tool-result collector is installed for
    the run and always cleared afterwards.
    """
    trace_id = _trace_id_from_context(context)
    callbacks = _build_callbacks(langfuse_handler)
    llm = get_llm(callbacks=callbacks)
    tools = _all_tools_for_user(user_id, trace_id)
    # The tool set is fixed for the whole run, so the lookup table is built once.
    tool_map = {tool_def.name: tool_def for tool_def in tools}
    model_context = _context_for_model(context)
    logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
    llm_with_tools = llm.bind_tools(tools)
    messages: list[Any] = [
        SystemMessage(content=system_prompt),
        HumanMessage(
            content=(
                f"User message:\n{message}\n\n"
                f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
            )
        ),
    ]

    tool_calls_count = 0
    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        for _ in range(max_steps):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            messages.append(response)

            if not response.tool_calls:
                final_text = _as_text(response.content)
                logger.info(
                    "deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d",
                    trace_id or "-",
                    user_id,
                    tool_calls_count,
                    len(final_text),
                )
                return final_text

            for call in response.tool_calls:
                tool_calls_count += 1
                # Same id for logging and for the ToolMessage that answers this call.
                call_id = str(call.get("id") or "")
                call_name = str(call.get("name", ""))
                call_args = call.get("args", {})
                logger.info(
                    "deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
                    call_id,
                    call_name,
                    json.dumps(call_args, ensure_ascii=True)[:800],
                )

                tool_fn = tool_map.get(call_name)
                if tool_fn is None:
                    tool_output = f"Unknown tool: {call_name}"
                else:
                    try:
                        tool_output = await tool_fn.ainvoke(call_args)
                    except Exception as exc:
                        # Argument validation or backend failures become tool output the
                        # model can react to; the exception text itself is only logged.
                        logger.exception(
                            "deep_agent: tool_failed tool_call_id=%s tool=%s", call_id, call_name,
                        )
                        tool_output = f"Error: tool '{call_name}' failed ({type(exc).__name__})."

                logger.info(
                    "deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
                    call_id,
                    call_name,
                    str(tool_output)[:1200],
                )

                messages.append(ToolMessage(content=str(tool_output), tool_call_id=call_id))

        # Step budget exhausted: force a plain answer without tools bound.
        final = await llm.ainvoke(messages)
        final_text = _as_text(final.content)
        logger.info(
            "deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
            trace_id or "-",
            user_id,
            tool_calls_count,
            len(final_text),
        )
        return final_text
    finally:
        clear_tool_result_collector()
|
|
|
|
|
|
async def _run_single_agent_stream(
    *,
    user_id: str,
    system_prompt: str,
    message: str,
    context: dict[str, Any],
    max_steps: int = 6,
    langfuse_handler: Any | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
    """Run the tool-calling loop and yield the answer as ``("token", text)`` events.

    Tool-calling steps behave exactly like ``_run_single_agent``. When the model
    replies without tool calls, that reply is the final answer and is yielded
    as-is in one token event (skipped if empty). It is not regenerated: the old
    approach streamed a second completion with the answer already appended, so
    the model produced a continuation or a different reply, at twice the cost,
    and after the full answer had already been computed. Only when *max_steps*
    is exhausted is a last tool-less completion streamed chunk by chunk.

    A tool that raises is reported to the model as that call's output instead of
    aborting the stream. The ws_context tool-result collector is installed for
    the run and always cleared afterwards.
    """
    trace_id = _trace_id_from_context(context)
    callbacks = _build_callbacks(langfuse_handler)
    llm = get_llm(callbacks=callbacks)
    tools = _all_tools_for_user(user_id, trace_id)
    # The tool set is fixed for the whole run, so the lookup table is built once.
    tool_map = {tool_def.name: tool_def for tool_def in tools}
    model_context = _context_for_model(context)
    logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
    llm_with_tools = llm.bind_tools(tools)
    messages: list[Any] = [
        SystemMessage(content=system_prompt),
        HumanMessage(
            content=(
                f"User message:\n{message}\n\n"
                f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
            )
        ),
    ]

    tool_calls_count = 0
    streamed_chars = 0
    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        for _ in range(max_steps):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            messages.append(response)

            if not response.tool_calls:
                # This reply is the final answer; emit it instead of asking again.
                final_text = _as_text(response.content)
                if final_text:
                    streamed_chars += len(final_text)
                    yield "token", final_text
                logger.info(
                    "deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
                    trace_id or "-",
                    user_id,
                    tool_calls_count,
                    streamed_chars,
                )
                return

            for call in response.tool_calls:
                tool_calls_count += 1
                # Same id for logging and for the ToolMessage that answers this call.
                call_id = str(call.get("id") or "")
                call_name = str(call.get("name", ""))
                call_args = call.get("args", {})
                logger.info(
                    "deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
                    call_id,
                    call_name,
                    json.dumps(call_args, ensure_ascii=True)[:800],
                )

                tool_fn = tool_map.get(call_name)
                if tool_fn is None:
                    tool_output = f"Unknown tool: {call_name}"
                else:
                    try:
                        tool_output = await tool_fn.ainvoke(call_args)
                    except Exception as exc:
                        # Argument validation or backend failures become tool output the
                        # model can react to; the exception text itself is only logged.
                        logger.exception(
                            "deep_agent: tool_failed tool_call_id=%s tool=%s", call_id, call_name,
                        )
                        tool_output = f"Error: tool '{call_name}' failed ({type(exc).__name__})."

                logger.info(
                    "deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
                    call_id,
                    call_name,
                    str(tool_output)[:1200],
                )

                messages.append(ToolMessage(content=str(tool_output), tool_call_id=call_id))

        # Step budget exhausted: stream a plain answer without tools bound.
        async for chunk in llm.astream(messages):
            token = _as_text(getattr(chunk, "content", ""))
            if token:
                streamed_chars += len(token)
                yield "token", token
        logger.info(
            "deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
            trace_id or "-",
            user_id,
            tool_calls_count,
            streamed_chars,
        )
    finally:
        clear_tool_result_collector()
|
|
|
|
|
|
async def run_home(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> str:
    """Answer a home-page chat message and return normalised markdown.

    The context is enriched with project resolution, the agent runs with the
    ``home_system`` prompt (Langfuse-managed or built-in), and task/timeline
    tags in the reply are moved onto their own lines.
    """
    enriched = await _prepare_context(message, context)
    raw_answer = await _run_single_agent(
        user_id=user_id,
        system_prompt=_get_system_prompt("home_system", _HOME_SINGLE_AGENT_SYSTEM),
        message=message,
        context=enriched,
        langfuse_handler=langfuse_handler,
    )
    return _normalize_tagged_list_lines(raw_answer, message)
|
|
|
|
|
|
async def run_floating(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> tuple[str, dict[str, str | None]]:
    """Answer a floating-widget message and return ``(plain_text_answer, domain)``.

    The domain (type/id/section) is classified before the agent runs. Markup is
    stripped from the answer. If stripping leaves nothing although the raw
    answer was non-empty, a best-effort fallback text is returned instead.
    """
    enriched = await _prepare_context(message, context)
    domain = await _infer_floating_domain(message, enriched, langfuse_handler=langfuse_handler)
    raw_answer = await _run_single_agent(
        user_id=user_id,
        system_prompt=_get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM),
        message=message,
        context=enriched,
        langfuse_handler=langfuse_handler,
    )
    plain = _strip_floating_markup(raw_answer)
    if plain or not raw_answer:
        return plain, domain
    return _fallback_from_raw_floating_text(raw_answer), domain
|
|
|
|
|
|
async def run_home_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
    *,
    langfuse_handler: Any | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream a home-page answer as ``(event_type, data)`` tuples.

    Non-token events from the agent are forwarded as they arrive. Token text is
    buffered, because tag normalisation works on whole lines. It is emitted once
    at the end as a single normalised ``"token"`` event, or not at all if the
    normalised text is empty.
    """
    enriched = await _prepare_context(message, context)
    prompt = _get_system_prompt("home_system", _HOME_SINGLE_AGENT_SYSTEM)
    buffered: list[str] = []
    agent_events = _run_single_agent_stream(
        user_id=user_id,
        system_prompt=prompt,
        message=message,
        context=enriched,
        langfuse_handler=langfuse_handler,
    )
    async for event in agent_events:
        kind, payload = event
        if kind == "token":
            buffered.append(str(payload or ""))
        else:
            yield event

    normalized = _normalize_tagged_list_lines("".join(buffered), message)
    if normalized:
        yield "token", normalized
|
|
|
|
|
|
async def run_floating_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
    *,
    langfuse_handler: Any | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream a floating-widget answer as ``(event_type, data)`` tuples.

    Emits ``("floating_domain", domain)`` first. After that, non-token agent
    events are forwarded and token text is emitted once it is sanitised and
    safe to show. If sanitising suppressed every token although the agent
    produced some, one best-effort fallback token is emitted at the end.
    """
    enriched = await _prepare_context(message, context)
    domain = await _infer_floating_domain(message, enriched, langfuse_handler=langfuse_handler)
    yield "floating_domain", domain

    prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)
    sanitizer = _FloatingStreamSanitizer()
    raw_parts: list[str] = []
    emitted = False

    agent_events = _run_single_agent_stream(
        user_id=user_id,
        system_prompt=prompt,
        message=message,
        context=enriched,
        langfuse_handler=langfuse_handler,
    )
    async for event in agent_events:
        kind, payload = event
        if kind != "token":
            yield event
            continue
        piece = str(payload or "")
        raw_parts.append(piece)
        clean = sanitizer.feed(piece)
        if clean:
            emitted = True
            yield "token", clean

    remainder = sanitizer.finalize()
    if remainder:
        emitted = True
        yield "token", remainder

    if raw_parts and not emitted:
        yield "token", _fallback_from_raw_floating_text("".join(raw_parts))
|
|
|
|
|
|
async def update_core_memory(user_id: str, key: str, value: str) -> None:
    """Write *value* into core memory block *key* for *user_id*.

    Compatibility helper for callers that expect an explicit memory-update API.
    It opens its own DB session for the single write.
    """
    async with async_session() as session:
        await MemoryMiddleware(session).update_core(user_id, key, value)
|