diff --git a/app/api/routes/device_ws.py b/app/api/routes/device_ws.py index 1257e13..b1d2e6f 100644 --- a/app/api/routes/device_ws.py +++ b/app/api/routes/device_ws.py @@ -223,10 +223,11 @@ async def _handle_home_request( # ── Memory: enrich context before LLM call ──────────────────────── async with async_session() as db: memory = MemoryMiddleware(db) - memory_context = await memory.enrich_context(user_id, message) + memory_context = await memory.enrich_context(user_id, message, trace_id=request_id) context: dict = { "conversation_history": frame.get("conversation_history", []), + "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, **memory_context, } @@ -253,7 +254,7 @@ async def _handle_home_request( async with async_session() as db: memory = MemoryMiddleware(db) await memory.store_episode( - user_id, session_id, message, "".join(response_chunks) + user_id, session_id, message, "".join(response_chunks), trace_id=request_id ) @@ -271,9 +272,13 @@ async def _handle_floating_request( # ── Memory: enrich context before LLM call ──────────────────────── async with async_session() as db: memory = MemoryMiddleware(db) - memory_context = await memory.enrich_context(user_id, message) + memory_context = await memory.enrich_context(user_id, message, trace_id=request_id) - context: dict = {"scope": scope, **memory_context} + context: dict = { + "scope": scope, + "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, + **memory_context, + } executor = await _make_ws_executor(websocket, user_id) set_client_executor(executor) @@ -297,7 +302,7 @@ async def _handle_floating_request( async with async_session() as db: memory = MemoryMiddleware(db) await memory.store_episode( - user_id, session_id, message, "".join(response_chunks) + user_id, session_id, message, "".join(response_chunks), trace_id=request_id ) diff --git a/app/core/deep_agent.py b/app/core/deep_agent.py index 22559a4..6f3fcd4 100644 --- a/app/core/deep_agent.py +++ b/app/core/deep_agent.py @@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator from typing import Any, Literal from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.tools import tool from app.agents.note_agent import NOTE_TOOLS from app.agents.project_agent import PROJECT_TOOLS @@ -24,17 +25,19 @@ logger = logging.getLogger(__name__) FloatingDomain = Literal["tasks", "projects", "notes", "timelines"] _HOME_SINGLE_AGENT_SYSTEM = ( - "You are the home assistant with direct access to all tools: tasks, projects, notes, timelines. " + "You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. " "Always use tools for factual data retrieval before answering. " + "When the user asks to remember, forget, or update what you know about them, use memory tools. " "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. " "Return markdown and embed inline tags when relevant: [ids], [ids], " "[ids], [ids], {json}." ) _FLOATING_SINGLE_AGENT_SYSTEM = ( - "You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines. " + "You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. " "Stay focused on the floating scope in context.scope and answer concisely. " "Always use tools for factual data retrieval before answering. " + "When the user asks to remember, forget, or update what you know about them, use memory tools. " "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. " "Return markdown and embed inline tags when relevant: [ids], [ids], " "[ids], [ids], {json}." @@ -118,6 +121,158 @@ def _all_tools() -> list[Any]: return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS] +def _trace_id_from_context(context: dict[str, Any]) -> str | None: + debug = context.get("_debug") + if isinstance(debug, dict): + request_id = debug.get("request_id") + if isinstance(request_id, str) and request_id: + return request_id + return None + + +def _context_for_model(context: dict[str, Any]) -> dict[str, Any]: + sanitized = dict(context) + sanitized.pop("_debug", None) + return sanitized + + +def _normalize_memory_label(path_or_label: str) -> str: + value = path_or_label.strip() + if value.startswith("/memories/"): + value = value[len("/memories/"):] + value = value.strip("/") + return value + + +def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]: + @tool + async def memory_list_blocks() -> str: + """List all core memory blocks currently stored for the user.""" + logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id) + async with async_session() as db: + memory = MemoryMiddleware(db) + blocks = await memory.list_core_blocks(user_id) + if not blocks: + return "No memory blocks found." + lines = [f"- {b['label']}: {b['value']}" for b in blocks] + return "Memory blocks:\n" + "\n".join(lines) + + @tool + async def memory_get(path_or_label: str) -> str: + """Get one memory block by label or /memories/