"""In-process TTL buffer for per-session LangChain message history.

Stores the full message list (including AIMessages with tool_calls and their
ToolMessages) keyed by (user_id, session_id), so agents can reconstruct
tool-call context across conversation turns without a lossy round-trip
through the wire format.

Single-process only. For multi-worker deployments, replace the _SessionBuffer
implementation with one backed by Redis (serialize LangChain messages to dicts
via messages_to_dict / messages_from_dict from langchain_core.messages).
"""

from __future__ import annotations

import time
from collections.abc import Callable
from threading import Lock

from langchain_core.messages import BaseMessage

# Sessions left untouched for longer than this many seconds are discarded
# (30 minutes).
SESSION_TTL_SECONDS = 30 * 60
# Upper bound on messages retained per session, keeping memory use bounded.
MAX_MESSAGES_PER_SESSION = 80


class _SessionBuffer:
|
|
def __init__(self) -> None:
|
|
self._store: dict[tuple[str, str], tuple[float, list[BaseMessage]]] = {}
|
|
self._lock = Lock()
|
|
|
|
def _evict_stale(self) -> None:
|
|
now = time.monotonic()
|
|
stale = [k for k, (ts, _) in self._store.items() if now - ts > SESSION_TTL_SECONDS]
|
|
for k in stale:
|
|
del self._store[k]
|
|
|
|
def get(self, user_id: str, session_id: str) -> list[BaseMessage] | None:
|
|
key = (user_id, session_id)
|
|
with self._lock:
|
|
entry = self._store.get(key)
|
|
if entry is None:
|
|
return None
|
|
ts, msgs = entry
|
|
if time.monotonic() - ts > SESSION_TTL_SECONDS:
|
|
del self._store[key]
|
|
return None
|
|
self._store[key] = (time.monotonic(), msgs)
|
|
return list(msgs)
|
|
|
|
def set(self, user_id: str, session_id: str, messages: list[BaseMessage]) -> None:
|
|
key = (user_id, session_id)
|
|
capped = messages[-MAX_MESSAGES_PER_SESSION:]
|
|
with self._lock:
|
|
self._evict_stale()
|
|
self._store[key] = (time.monotonic(), capped)
|
|
|
|
def clear(self, user_id: str, session_id: str) -> None:
|
|
with self._lock:
|
|
self._store.pop((user_id, session_id), None)
|
|
|
|
def append_system_message(self, user_id: str, session_id: str, text: str) -> None:
|
|
"""Append a synthetic system message to the buffer for the given session.
|
|
|
|
Creates the session slot if it does not yet exist. Used by the
|
|
contextual_scope_update handler to inject navigation events without
|
|
making an LLM call.
|
|
"""
|
|
from langchain_core.messages import SystemMessage # noqa: PLC0415
|
|
|
|
key = (user_id, session_id)
|
|
with self._lock:
|
|
entry = self._store.get(key)
|
|
if entry is None:
|
|
msgs: list[BaseMessage] = [SystemMessage(content=text)]
|
|
else:
|
|
_, existing = entry
|
|
msgs = list(existing) + [SystemMessage(content=text)]
|
|
capped = msgs[-MAX_MESSAGES_PER_SESSION:]
|
|
self._store[key] = (time.monotonic(), capped)
|
|
|
|
|
|
class ContextualBufferProxy:
    """Bind a ``_SessionBuffer`` to a single ``(user_id, session_id)`` pair.

    Instances are handed out by get_session_buffer(). Call sites can then write
    ``proxy.append_system_message(text)`` instead of passing the user and
    session identifiers around explicitly.
    """

    def __init__(self, buf: "_SessionBuffer", user_id: str, session_id: str) -> None:
        self._user_id, self._session_id = user_id, session_id
        self._buf = buf

    def append_system_message(self, text: str) -> None:
        """Forward *text* to the bound buffer under the stored identifiers."""
        target, uid, sid = self._buf, self._user_id, self._session_id
        target.append_system_message(uid, sid, text)


# Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py.
# Every importer of this module shares this one in-process instance; it is not
# shared across worker processes (see the module docstring).
session_buffer: _SessionBuffer = _SessionBuffer()