620 lines
23 KiB
Python
620 lines
23 KiB
Python
"""Single-agent runners for home and floating chat contexts."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from datetime import date
|
|
from collections.abc import AsyncGenerator
|
|
from typing import Any, Literal
|
|
|
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
|
from langchain_core.tools import tool
|
|
|
|
from app.agents.note_agent import NOTE_TOOLS
|
|
from app.agents.project_agent import PROJECT_TOOLS
|
|
from app.agents.task_agent import TASK_TOOLS
|
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
|
from app.core.llm import get_llm
|
|
from app.core.memory_middleware import MemoryMiddleware
|
|
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
|
from app.db import async_session
|
|
|
|
# Module logger; every message emitted from this module starts with "deep_agent:".
logger = logging.getLogger(__name__)

# The domains a floating chat can be focused on; returned by _infer_floating_domain.
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]

# System prompt for the home-screen assistant. The model receives the request
# context wrapped as {"context": ...}, which is why resolved ids are referenced as
# "context.context.*". The "tag on its own line" and "upcoming timelines" rules
# are also enforced after generation by _normalize_tagged_list_lines.
_HOME_SINGLE_AGENT_SYSTEM = (
    "You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
    "Always use tools for factual data retrieval before answering. "
    "When the user asks to remember, forget, or update what you know about them, use memory tools. "
    "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
    "Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
    "<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
    "When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
    "Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
    "For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
    "For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
)

# System prompt for the floating (in-page) assistant. Its answers are not passed
# through _normalize_tagged_list_lines.
# NOTE(review): the prompt says "context.scope", but the model receives the scope
# at context.context.scope (the context is wrapped as {"context": ...}) -- confirm
# the wording is intended.
_FLOATING_SINGLE_AGENT_SYSTEM = (
    "You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
    "Stay focused on the floating scope in context.scope and answer concisely. "
    "Always use tools for factual data retrieval before answering. "
    "When the user asks to remember, forget, or update what you know about them, use memory tools. "
    "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
    "Return markdown and embed inline tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
    "<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>."
)
|
|
|
|
|
|
def _as_text(content: Any) -> str:
|
|
if content is None:
|
|
return ""
|
|
if isinstance(content, str):
|
|
return content
|
|
if isinstance(content, list):
|
|
parts: list[str] = []
|
|
for item in content:
|
|
if isinstance(item, str):
|
|
parts.append(item)
|
|
elif isinstance(item, dict):
|
|
text = item.get("text")
|
|
if isinstance(text, str):
|
|
parts.append(text)
|
|
return "".join(parts)
|
|
return str(content)
|
|
|
|
|
|
def _candidate_tokens(message: str) -> list[str]:
|
|
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
|
|
return [token for token in tokens if len(token) >= 3]
|
|
|
|
|
|
async def _resolve_project_id_from_message(message: str) -> str | None:
    """Resolve likely project UUID from user message using client project list.

    Fetches the projects via ``execute_on_client(action="select", table="projects")``
    and scores each row by how many message tokens (see ``_candidate_tokens``)
    occur as substrings of its lowercased ``name``.

    Returns:
        The ``id`` of the single best-scoring project. Returns ``None`` when the
        client call fails, the reply has no usable ``rows``, nothing matches, the
        best score is tied, or the winning row's ``id`` is not a string.
    """
    try:
        result = await execute_on_client(action="select", table="projects")
    except Exception as exc:
        # Best-effort: the resolved id is only an optional hint for the prompt,
        # so a client failure must not abort the chat turn.
        logger.warning("deep_agent: project resolve select failed: %s", exc)
        return None

    # Guard against a non-dict reply instead of crashing on .get().
    rows = result.get("rows", []) if isinstance(result, dict) else []
    if not isinstance(rows, list) or not rows:
        return None

    tokens = _candidate_tokens(message)
    if not tokens:
        return None

    scored: list[tuple[int, dict[str, Any]]] = []
    for row in rows:
        if not isinstance(row, dict):
            continue
        raw_name = row.get("name")
        # A NULL name must not become the literal "none", which would falsely
        # match tokens such as "non" or "one".
        name = "" if raw_name is None else str(raw_name).lower()
        score = sum(1 for token in tokens if token in name)
        if score > 0:
            scored.append((score, row))

    if not scored:
        return None

    top_score = max(score for score, _ in scored)
    top_rows = [row for score, row in scored if score == top_score]
    # A tie means the message is ambiguous; better no hint than a wrong one.
    if len(top_rows) != 1:
        return None

    project_id = top_rows[0].get("id")
    return project_id if isinstance(project_id, str) else None
|
|
|
|
|
|
def _needs_project_resolution(message: str) -> bool:
|
|
lowered = message.lower()
|
|
return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
|
|
|
|
|
|
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *context*, possibly enriched with a project id.

    If the message looks project-related and a unique project can be matched,
    its id is stored under ``"resolved_project_id"``. The caller's dict is
    never mutated.
    """
    enriched = dict(context)
    if not _needs_project_resolution(message):
        return enriched

    project_id = await _resolve_project_id_from_message(message)
    if project_id:
        enriched["resolved_project_id"] = project_id
        logger.info("deep_agent: resolved_project_id=%s", project_id)
    return enriched
|
|
|
|
|
|
def _all_tools() -> list[Any]:
    """Return a new list of every domain tool: tasks, projects, notes, then timelines."""
    combined: list[Any] = []
    for group in (TASK_TOOLS, PROJECT_TOOLS, NOTE_TOOLS, TIMELINE_TOOLS):
        combined.extend(group)
    return combined
|
|
|
|
|
|
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
|
debug = context.get("_debug")
|
|
if isinstance(debug, dict):
|
|
request_id = debug.get("request_id")
|
|
if isinstance(request_id, str) and request_id:
|
|
return request_id
|
|
return None
|
|
|
|
|
|
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
|
sanitized = dict(context)
|
|
sanitized.pop("_debug", None)
|
|
return sanitized
|
|
|
|
|
|
# Matches one task/timeline id tag such as "<task>[abc]</task>". The closing tag
# must repeat the opening tag name (backreference \1), and the bracketed part is
# any non-empty run of characters other than "]".
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")

# Matches a DD/MM/YYYY date: exactly two digits for day and month and four for
# the year. The named groups d/m/y are read by
# _timeline_date_in_current_month_or_future.
_TIMELINE_DMY_RE = re.compile(r"(?P<d>\d{2})/(?P<m>\d{2})/(?P<y>\d{4})")
|
|
|
|
|
|
def _is_upcoming_timeline_query(message: str) -> bool:
|
|
lowered = message.lower()
|
|
has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered
|
|
has_timeline_topic = any(
|
|
token in lowered
|
|
for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden")
|
|
)
|
|
return has_upcoming and has_timeline_topic
|
|
|
|
|
|
def _timeline_date_in_current_month_or_future(dmy: str) -> bool:
    """Decide whether the first ``DD/MM/YYYY`` date found in *dmy* should be kept.

    Despite the name, a date passes only if it is today or later AND falls in
    the current calendar month, as given by local ``date.today()``. Dates in
    later months are rejected. Text with no date, or with an impossible date
    such as ``31/02/2024``, is kept (returns True).
    """
    found = _TIMELINE_DMY_RE.search(dmy)
    if found is None:
        return True

    day, month, year = (int(found.group(key)) for key in ("d", "m", "y"))
    try:
        candidate = date(year, month, day)
    except ValueError:
        # Unparseable dates are not grounds for dropping an item.
        return True

    today = date.today()
    same_month = (candidate.year, candidate.month) == (today.year, today.month)
    return same_month and candidate >= today
|
|
|
|
|
|
def _normalize_tagged_list_lines(text: str, message: str) -> str:
    """Rewrite lines that carry ``<task>``/``<timeline>`` tags into bare tag lines.

    Any line containing at least one ``<task>[...]</task>`` or
    ``<timeline>[...]</timeline>`` tag is replaced by those tags, one per line.
    Everything else on that line (bullets, titles, dates, other tag kinds) is
    discarded. Lines without such tags pass through unchanged.

    If *message* is an "upcoming timeline" question, timeline tags on a line
    whose first DD/MM/YYYY date is in the past or beyond the current month are
    dropped.

    Falsy *text* is returned as-is. The result is joined with ``"\\n"``, so a
    trailing newline is not preserved.
    """
    if not text:
        return text

    filter_timelines = _is_upcoming_timeline_query(message)

    def _keep(tag: str, source_line: str) -> bool:
        # The date check runs against the whole source line, because the tag
        # itself carries no date.
        if not filter_timelines or "<timeline>" not in tag:
            return True
        return _timeline_date_in_current_month_or_future(source_line)

    rewritten: list[str] = []
    for line in text.splitlines():
        tags = [found.group(0) for found in _TAG_LINE_RE.finditer(line)]
        if tags:
            rewritten.extend(tag for tag in tags if _keep(tag, line))
        else:
            rewritten.append(line)
    return "\n".join(rewritten)
|
|
|
|
|
|
def _normalize_memory_label(path_or_label: str) -> str:
|
|
value = path_or_label.strip()
|
|
if value.startswith("/memories/"):
|
|
value = value[len("/memories/"):]
|
|
value = value.strip("/")
|
|
return value
|
|
|
|
|
|
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
    """Build the memory tools exposed to the agent for one user.

    Every tool is a closure over *user_id* and *trace_id*. The model only
    supplies labels, content and queries, never the user identity. Each tool
    call opens its own ``async_session()`` and delegates to ``MemoryMiddleware``.

    NOTE: the inner functions' docstrings are consumed by ``@tool`` as the tool
    descriptions shown to the model, so editing them changes agent behavior.

    Returns:
        The nine tools: core-block list/get/create/append/replace/delete,
        archival insert/search, and conversation (recall) search.
    """

    @tool
    async def memory_list_blocks() -> str:
        """List all core memory blocks currently stored for the user."""
        logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id)
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            blocks = await memory.list_core_blocks(user_id)
            if not blocks:
                return "No memory blocks found."
            # Each block is indexed with ["label"] and ["value"].
            lines = [f"- {b['label']}: {b['value']}" for b in blocks]
            return "Memory blocks:\n" + "\n".join(lines)

    @tool
    async def memory_get(path_or_label: str) -> str:
        """Get one memory block by label or /memories/<label> path."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_id or "-", user_id, label)
        # Reject empty labels (e.g. a bare "/memories/") before touching the DB.
        if not label:
            return "Invalid memory label."
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            value = await memory.get_core_block(user_id, label)
            if value is None:
                return f"Memory block '{label}' not found."
            return f"Memory block '{label}':\n{value}"

    @tool
    async def memory_create(path_or_label: str, value: str) -> str:
        """Create or overwrite a memory block value by label or /memories/<label> path."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_id or "-", user_id, label)
        if not label:
            return "Invalid memory label."
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            # Of the tools here, only this one forwards trace_id to the middleware.
            await memory.update_core(user_id, label, value, trace_id=trace_id)
            return f"Memory block '{label}' saved."

    @tool
    async def memory_append(path_or_label: str, content: str) -> str:
        """Append content to a memory block, creating it if missing."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_id or "-", user_id, label)
        if not label:
            return "Invalid memory label."
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            await memory.append_core(user_id, label, content)
            return f"Memory block '{label}' appended."

    @tool
    async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
        """Replace one exact string in a memory block."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_id or "-", user_id, label)
        if not label:
            return "Invalid memory label."
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            # A falsy result from replace_core is reported as "old string not found".
            changed = await memory.replace_core(user_id, label, old_string, new_string)
            if not changed:
                return f"No replacement made in '{label}' (old string not found)."
            return f"Memory block '{label}' updated."

    @tool
    async def memory_delete(path_or_label: str) -> str:
        """Delete a memory block by label or /memories/<label> path."""
        label = _normalize_memory_label(path_or_label)
        logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_id or "-", user_id, label)
        if not label:
            return "Invalid memory label."
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            deleted = await memory.delete_core(user_id, label)
            if not deleted:
                return f"Memory block '{label}' not found."
            return f"Memory block '{label}' deleted."

    @tool
    async def archival_memory_insert(content: str) -> str:
        """Insert a long-term archival memory entry."""
        logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_id or "-", user_id)
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            # Entries written by the agent are tagged with source="assistant".
            await memory.insert_archival(user_id, content, source="assistant")
            return "Archival memory saved."

    @tool
    async def archival_memory_search(query: str, top_k: int = 5) -> str:
        """Search long-term archival memory by semantic fallback (keyword currently)."""
        # Only the first 80 characters of the query are logged.
        logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            # NOTE(review): top_k comes straight from the model and is not bounded here.
            results = await memory.search_archival(user_id, query, top_k=top_k)
            if not results:
                return "No archival memory results found."
            lines = [f"- {item}" for item in results]
            return "Archival memory results:\n" + "\n".join(lines)

    @tool
    async def conversation_search(query: str, top_k: int = 5) -> str:
        """Search recall memory from prior episodic conversation summaries."""
        logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            # NOTE(review): top_k comes straight from the model and is not bounded here.
            results = await memory.search_recall(user_id, query, top_k=top_k)
            if not results:
                return "No recall memory results found."
            lines = [f"- {item}" for item in results]
            return "Recall memory results:\n" + "\n".join(lines)

    return [
        memory_list_blocks,
        memory_get,
        memory_create,
        memory_append,
        memory_replace,
        memory_delete,
        archival_memory_insert,
        archival_memory_search,
        conversation_search,
    ]
|
|
|
|
|
|
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
    """Return the domain tools followed by the memory tools bound to *user_id*."""
    toolset = list(_all_tools())
    toolset.extend(_memory_tools(user_id, trace_id))
    return toolset
|
|
|
|
|
|
def _infer_floating_domain(message: str, context: dict[str, Any]) -> FloatingDomain:
|
|
scope = context.get("scope") if isinstance(context, dict) else None
|
|
if isinstance(scope, dict):
|
|
scope_type = str(scope.get("type") or "").strip().lower()
|
|
if scope_type in {"task", "tasks"}:
|
|
return "tasks"
|
|
if scope_type in {"project", "projects"}:
|
|
return "projects"
|
|
if scope_type in {"note", "notes"}:
|
|
return "notes"
|
|
if scope_type in {"timeline", "timelines"}:
|
|
return "timelines"
|
|
|
|
lowered = message.lower()
|
|
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
|
|
return "timelines"
|
|
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
|
|
return "notes"
|
|
if any(keyword in lowered for keyword in ["project", "progetto", "client"]):
|
|
return "projects"
|
|
return "tasks"
|
|
|
|
|
|
def _build_initial_messages(
    system_prompt: str,
    message: str,
    model_context: dict[str, Any],
) -> list[Any]:
    """Return the opening ``[SystemMessage, HumanMessage]`` pair for an agent run.

    The context is wrapped as ``{"context": ...}`` (hence the system prompts'
    references to ``context.context.*``). Its JSON is cut to 3500 characters to
    bound prompt size, so the cut may land mid-document.
    """
    context_json = json.dumps({"context": model_context}, ensure_ascii=True)[:3500]
    return [
        SystemMessage(content=system_prompt),
        HumanMessage(
            content=(
                f"User message:\n{message}\n\n"
                f"Context:\n{context_json}"
            )
        ),
    ]


async def _execute_tool_calls(
    response: AIMessage,
    tool_map: dict[str, Any],
) -> list[ToolMessage]:
    """Execute the tool calls requested in *response*, sequentially and in order.

    Returns one ``ToolMessage`` per call, carrying the call's ``tool_call_id``.
    An unknown tool name is reported back to the model as text instead of
    raising. Exceptions raised by a tool itself propagate to the caller.
    """
    results: list[ToolMessage] = []
    for call in response.tool_calls:
        call_id = str(call.get("id", ""))
        call_name = str(call.get("name", ""))
        call_args = call.get("args", {})
        logger.info(
            "deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
            call_id,
            call_name,
            json.dumps(call_args, ensure_ascii=True)[:800],
        )

        tool_fn = tool_map.get(call_name)
        if tool_fn is None:
            tool_output = f"Unknown tool: {call_name}"
        else:
            tool_output = await tool_fn.ainvoke(call_args)

        logger.info(
            "deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
            call_id,
            call_name,
            str(tool_output)[:1200],
        )
        results.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
    return results


async def _run_single_agent(
    *,
    user_id: str,
    system_prompt: str,
    message: str,
    context: dict[str, Any],
    max_steps: int = 6,
) -> str:
    """Run the tool-calling loop and return the assistant's final text.

    Each step asks the tool-bound model for a reply. If the reply requests
    tools, they are executed and their results appended to the conversation;
    otherwise the reply's text is returned. After *max_steps* tool rounds, one
    final call to the plain (tool-free) model produces the answer (logged with
    ``fallback=1``).
    """
    trace_id = _trace_id_from_context(context)
    llm = get_llm()
    tools = _all_tools_for_user(user_id, trace_id)
    # Name -> tool lookup, built once per run rather than once per step.
    tool_map = {tool_def.name: tool_def for tool_def in tools}
    model_context = _context_for_model(context)
    logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
    llm_with_tools = llm.bind_tools(tools)
    messages = _build_initial_messages(system_prompt, message, model_context)

    tool_calls_count = 0
    # NOTE(review): this list is registered with ws_context for the duration of
    # the run; presumably tools record their results into it -- confirm there.
    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        for _ in range(max_steps):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            messages.append(response)

            if not response.tool_calls:
                final_text = _as_text(response.content)
                logger.info(
                    "deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d",
                    trace_id or "-",
                    user_id,
                    tool_calls_count,
                    len(final_text),
                )
                return final_text

            tool_calls_count += len(response.tool_calls)
            messages.extend(await _execute_tool_calls(response, tool_map))

        # Step budget exhausted: force a final answer from the tool-free model.
        final = await llm.ainvoke(messages)
        final_text = _as_text(final.content)
        logger.info(
            "deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
            trace_id or "-",
            user_id,
            tool_calls_count,
            len(final_text),
        )
        return final_text
    finally:
        clear_tool_result_collector()


async def _run_single_agent_stream(
    *,
    user_id: str,
    system_prompt: str,
    message: str,
    context: dict[str, Any],
    max_steps: int = 6,
) -> AsyncGenerator[tuple[str, Any], None]:
    """Streaming variant of ``_run_single_agent`` that yields ``("token", text)`` events.

    Tool rounds run exactly as in the non-streaming loop. Once the tool-bound
    model replies without tool calls (or *max_steps* rounds are used up), the
    final answer is produced by streaming the plain model over the conversation
    so far.
    """
    trace_id = _trace_id_from_context(context)
    llm = get_llm()
    tools = _all_tools_for_user(user_id, trace_id)
    tool_map = {tool_def.name: tool_def for tool_def in tools}
    model_context = _context_for_model(context)
    logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
    llm_with_tools = llm.bind_tools(tools)
    messages = _build_initial_messages(system_prompt, message, model_context)

    tool_calls_count = 0
    streamed_chars = 0
    collected: list[dict[str, Any]] = []
    set_tool_result_collector(collected)
    try:
        for _ in range(max_steps):
            response: AIMessage = await llm_with_tools.ainvoke(messages)

            if not response.tool_calls:
                # Do NOT append this tool-free draft before streaming. With a
                # trailing assistant message, the streamed call would either
                # continue the draft (providers that treat it as a prefill) or
                # reply a second time, instead of producing the answer.
                async for chunk in llm.astream(messages):
                    token = _as_text(getattr(chunk, "content", ""))
                    if token:
                        streamed_chars += len(token)
                        yield "token", token
                logger.info(
                    "deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
                    trace_id or "-",
                    user_id,
                    tool_calls_count,
                    streamed_chars,
                )
                return

            messages.append(response)
            tool_calls_count += len(response.tool_calls)
            messages.extend(await _execute_tool_calls(response, tool_map))

        # Step budget exhausted: stream a final answer from the tool-free model.
        async for chunk in llm.astream(messages):
            token = _as_text(getattr(chunk, "content", ""))
            if token:
                streamed_chars += len(token)
                yield "token", token
        logger.info(
            "deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
            trace_id or "-",
            user_id,
            tool_calls_count,
            streamed_chars,
        )
    finally:
        clear_tool_result_collector()
|
|
|
|
|
|
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
    """Answer a home-screen chat message.

    The context is first enriched with a resolved project id when applicable.
    The answer's task/timeline tag lines are normalized before returning.
    """
    enriched_context = await _prepare_context(message, context)
    raw_answer = await _run_single_agent(
        user_id=user_id,
        system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )
    normalized = _normalize_tagged_list_lines(raw_answer, message)
    return normalized
|
|
|
|
|
|
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
    """Answer a floating-chat message and return ``(answer, inferred_domain)``.

    The domain is inferred from the caller's original *context*, before project
    resolution. The answer is returned without tag-line normalization.
    """
    inferred_domain = _infer_floating_domain(message, context)
    answer = await _run_single_agent(
        user_id=user_id,
        system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
        message=message,
        context=await _prepare_context(message, context),
    )
    return answer, inferred_domain
|
|
|
|
|
|
async def run_home_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream a home-screen answer, emitting the normalized text as a single token.

    Non-token events are forwarded as they arrive. Token text is buffered so
    tag-line normalization sees the complete answer; the normalized text is
    then yielded once at the end, and only if it is non-empty.
    """
    enriched_context = await _prepare_context(message, context)
    buffered: list[str] = []
    agent_events = _run_single_agent_stream(
        user_id=user_id,
        system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )
    async for event in agent_events:
        kind, payload = event
        if kind == "token":
            buffered.append(str(payload or ""))
        else:
            yield event

    final_text = _normalize_tagged_list_lines("".join(buffered), message)
    if final_text:
        yield "token", final_text
|
|
|
|
|
|
async def run_floating_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream a floating-chat answer, announcing the inferred domain first.

    The first event is ``("floating_domain", domain)``. All agent events are then
    relayed unchanged, without tag-line normalization.
    """
    inferred_domain = _infer_floating_domain(message, context)
    yield "floating_domain", inferred_domain

    enriched_context = await _prepare_context(message, context)
    agent_events = _run_single_agent_stream(
        user_id=user_id,
        system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
        message=message,
        context=enriched_context,
    )
    async for agent_event in agent_events:
        yield agent_event
|
|
|
|
|
|
async def update_core_memory(user_id: str, key: str, value: str) -> None:
    """Compatibility helper: store *value* in core memory block *key* for *user_id*.

    Kept for callers that expect an explicit memory update API. It opens its
    own session and, unlike the agent's ``memory_create`` tool, passes no trace id.
    """
    async with async_session() as session:
        middleware = MemoryMiddleware(session)
        await middleware.update_core(user_id, key, value)
|