fix: langfuse v4 SDK compatibility and pass user message as trace input

This commit is contained in:
Roberto Musso
2026-03-23 00:23:59 +01:00
parent 0d5fa3e569
commit 0b491b3643
11 changed files with 330 additions and 190 deletions

View File

@@ -528,7 +528,9 @@ def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) ->
return {"type": "task", "id": None, "section": None}
async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]:
async def _infer_floating_domain(
message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None,
) -> dict[str, str | None]:
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
@@ -538,10 +540,14 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
}
try:
llm = get_llm()
classifier_prompt = _get_system_prompt(
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_SYSTEM,
)
callbacks = _build_callbacks(langfuse_handler)
llm = get_llm(callbacks=callbacks)
response = await llm.ainvoke(
[
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_SYSTEM),
SystemMessage(content=classifier_prompt),
HumanMessage(
content=(
f"Message:\n{message}\n\n"
@@ -784,7 +790,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any], *, langf
async def run_floating(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> tuple[str, dict[str, str | None]]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler)
system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)
response = await _run_single_agent(
user_id=user_id,
@@ -835,7 +841,7 @@ async def run_floating_stream(
langfuse_handler: Any | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler)
yield "floating_domain", domain
system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)

View File

@@ -31,6 +31,11 @@ logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Initialise Langfuse tracing (no-op if keys are missing)
from app.tracing import init_langfuse
init_langfuse()
# Start Redis consumer in background
from app.redis_consumer import start_consumer

View File

@@ -85,52 +85,51 @@ async def _handle_home_request(user_id: str, frame: dict) -> None:
user_id, request_id, message[:200],
)
# Create Langfuse trace
trace = tracing.create_trace(
response_chunks: list[str] = []
with tracing.trace_span(
name="home_request",
user_id=user_id,
session_id=session_id,
trace_id=request_id,
input=message,
metadata={"message_preview": message[:200]},
tags=["home"],
)
langfuse_handler = tracing.get_langfuse_callback(
trace=trace, span_name="home_agent",
)
) as span:
langfuse_handler = tracing.get_langfuse_callback()
# Enrich with memory context
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id, message,
trace_id=request_id, session_id=session_id,
)
# Enrich with memory context
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id, message,
trace_id=request_id, session_id=session_id,
)
context: dict = {
"conversation_history": frame.get("conversation_history", []),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
context: dict = {
"conversation_history": frame.get("conversation_history", []),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
set_current_user(user_id)
response_chunks: list[str] = []
try:
event_stream = run_home_stream(user_id, message, context, langfuse_handler=langfuse_handler)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await _publish_frame(user_id, ws_frame.model_dump_json())
if hasattr(ws_frame, "chunk"):
response_chunks.append(ws_frame.chunk)
except Exception as exc:
logger.error("redis_consumer: home_request failed user=%s req=%s: %s", user_id, request_id, exc)
finally:
clear_current_user()
set_current_user(user_id)
try:
event_stream = run_home_stream(user_id, message, context, langfuse_handler=langfuse_handler)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await _publish_frame(user_id, ws_frame.model_dump_json())
if hasattr(ws_frame, "chunk"):
response_chunks.append(ws_frame.chunk)
except Exception as exc:
logger.error("redis_consumer: home_request failed user=%s req=%s: %s", user_id, request_id, exc)
finally:
clear_current_user()
# Link prompt and flush trace
if trace is not None:
tracing.link_prompt_to_trace(trace, "home_system")
# Link prompt and attach output preview
tracing.link_prompt_to_trace(span, "home_system")
response_text = "".join(response_chunks)
trace.update(output=response_text[:500] if response_text else None)
span.update(output=response_text[:500] if response_text else None)
tracing.flush()
# Store episode
@@ -154,52 +153,51 @@ async def _handle_floating_request(user_id: str, frame: dict) -> None:
user_id, request_id, json.dumps(scope)[:200], message[:200],
)
# Create Langfuse trace
trace = tracing.create_trace(
response_chunks: list[str] = []
with tracing.trace_span(
name="floating_request",
user_id=user_id,
session_id=session_id,
trace_id=request_id,
input=message,
metadata={"message_preview": message[:200], "scope": scope},
tags=["floating"],
)
langfuse_handler = tracing.get_langfuse_callback(
trace=trace, span_name="floating_agent",
)
) as span:
langfuse_handler = tracing.get_langfuse_callback()
# Enrich with memory context
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id, message,
trace_id=request_id, session_id=session_id,
)
# Enrich with memory context
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id, message,
trace_id=request_id, session_id=session_id,
)
context: dict = {
"scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
context: dict = {
"scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
set_current_user(user_id)
response_chunks: list[str] = []
try:
event_stream = run_floating_stream(user_id, message, context, langfuse_handler=langfuse_handler)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await _publish_frame(user_id, ws_frame.model_dump_json())
if hasattr(ws_frame, "chunk"):
response_chunks.append(ws_frame.chunk)
except Exception as exc:
logger.error("redis_consumer: floating_request failed user=%s req=%s: %s", user_id, request_id, exc)
finally:
clear_current_user()
set_current_user(user_id)
try:
event_stream = run_floating_stream(user_id, message, context, langfuse_handler=langfuse_handler)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await _publish_frame(user_id, ws_frame.model_dump_json())
if hasattr(ws_frame, "chunk"):
response_chunks.append(ws_frame.chunk)
except Exception as exc:
logger.error("redis_consumer: floating_request failed user=%s req=%s: %s", user_id, request_id, exc)
finally:
clear_current_user()
# Link prompt and flush trace
if trace is not None:
tracing.link_prompt_to_trace(trace, "floating_system")
# Link prompt and attach output preview
tracing.link_prompt_to_trace(span, "floating_system")
response_text = "".join(response_chunks)
trace.update(output=response_text[:500] if response_text else None)
span.update(output=response_text[:500] if response_text else None)
tracing.flush()
# Store episode

View File

@@ -1,137 +1,156 @@
"""Langfuse tracing & prompt management for the Chat Service.
"""Langfuse tracing & prompt management for the Chat Service (v4 SDK).
Provides:
- ``langfuse`` — singleton Langfuse client (lazy, no-op when keys are missing)
- ``create_trace()`` — start a new trace for a chat request
- ``get_langfuse_callback()`` — LangChain callback handler for a trace/span
- ``init_langfuse()`` — initialise the singleton client at startup
- ``trace_span()`` — context manager that creates a trace + span
- ``get_langfuse_callback()`` — LangChain callback handler (auto-inherits trace)
- ``get_prompt()`` — fetch a managed prompt from Langfuse by name
- ``flush()`` — ensure all events are sent before shutdown
- ``flush()`` / ``shutdown()`` — lifecycle management
All functions gracefully degrade to no-ops when Langfuse is not configured,
so the service works identically with or without observability keys.
Requires ``langfuse >= 3.0.0`` (v4 / "Fast Preview" SDK).
"""
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import Any
from shared.config import settings
logger = logging.getLogger(__name__)
# ── Lazy singleton ───────────────────────────────────────────────────────
# ── State ────────────────────────────────────────────────────────────────
_langfuse_client: Any | None = None
_langfuse_disabled: bool = False
_initialised: bool = False
_disabled: bool = False
def _is_configured() -> bool:
return bool(settings.LANGFUSE_SECRET_KEY and settings.LANGFUSE_PUBLIC_KEY)
def _get_langfuse() -> Any | None:
"""Return the Langfuse client singleton, or None if not configured."""
global _langfuse_client, _langfuse_disabled
def init_langfuse() -> None:
"""Initialise the Langfuse singleton. Call once at startup."""
global _initialised, _disabled
if _langfuse_disabled:
return None
if _langfuse_client is not None:
return _langfuse_client
if _initialised or _disabled:
return
if not _is_configured():
_langfuse_disabled = True
_disabled = True
logger.info("tracing: Langfuse keys not set — tracing disabled")
return None
return
try:
from langfuse import Langfuse
_langfuse_client = Langfuse(
Langfuse(
secret_key=settings.LANGFUSE_SECRET_KEY,
public_key=settings.LANGFUSE_PUBLIC_KEY,
host=settings.LANGFUSE_HOST,
)
_initialised = True
logger.info("tracing: Langfuse client initialised (host=%s)", settings.LANGFUSE_HOST)
return _langfuse_client
except Exception as exc:
_langfuse_disabled = True
_disabled = True
logger.warning("tracing: failed to initialise Langfuse: %s", exc)
def _get_client() -> Any | None:
"""Return the singleton Langfuse client, or *None* if disabled."""
if _disabled:
return None
if not _initialised:
init_langfuse()
if _disabled:
return None
try:
from langfuse import get_client
return get_client()
except Exception:
return None
# ── Trace lifecycle ──────────────────────────────────────────────────────
# ── Null span (no-op when Langfuse is disabled) ─────────────────────────
def create_trace(
class _NullSpan:
"""Drop-in replacement when Langfuse is disabled."""
def update(self, **_: Any) -> None: ...
def set_trace_io(self, **_: Any) -> None: ...
def score_trace(self, **_: Any) -> None: ...
# ── Trace context manager ───────────────────────────────────────────────
@contextmanager
def trace_span(
*,
name: str,
user_id: str,
session_id: str | None = None,
trace_id: str | None = None,
input: Any = None,
metadata: dict[str, Any] | None = None,
tags: list[str] | None = None,
) -> Any | None:
"""Create a Langfuse trace. Returns the trace object, or None if disabled."""
lf = _get_langfuse()
):
"""Context manager that creates a Langfuse trace/span.
Yields the span object (or a ``_NullSpan`` if Langfuse is disabled).
A ``CallbackHandler`` created inside this block auto-inherits the trace
context, so there is no need to pass trace IDs manually.
"""
lf = _get_client()
if lf is None:
return None
yield _NullSpan()
return
try:
return lf.trace(
id=trace_id,
from langfuse import Langfuse, propagate_attributes
trace_ctx: dict[str, str] = {}
if trace_id is not None:
trace_ctx["trace_id"] = Langfuse.create_trace_id(seed=trace_id)
with lf.start_as_current_observation(
as_type="span",
name=name,
user_id=user_id,
session_id=session_id,
input=input,
metadata=metadata or {},
tags=tags or [],
)
**({"trace_context": trace_ctx} if trace_ctx else {}),
) as span:
with propagate_attributes(
user_id=user_id,
session_id=session_id,
tags=tags or [],
):
yield span
except Exception as exc:
logger.warning("tracing: create_trace failed: %s", exc)
return None
logger.warning("tracing: trace_span(%s) failed: %s", name, exc)
yield _NullSpan()
# ── LangChain callback handler ──────────────────────────────────────────
def get_langfuse_callback(
*,
trace_id: str | None = None,
trace: Any | None = None,
span_name: str | None = None,
update_parent: bool = True,
) -> Any | None:
"""Return a ``CallbackHandler`` wired to an existing trace.
def get_langfuse_callback() -> Any | None:
"""Return a LangChain ``CallbackHandler`` that auto-inherits the current trace.
This handler is passed to LangChain's ``ainvoke`` / ``astream`` as a
callback so every LLM generation and tool call is automatically
captured as a nested span inside the trace.
If both *trace* and *trace_id* are given, *trace* takes precedence.
Returns None when Langfuse is disabled.
Must be called inside a ``trace_span()`` block for proper linking.
Returns *None* when Langfuse is disabled.
"""
lf = _get_langfuse()
if lf is None:
if _disabled and not _initialised:
return None
try:
from langfuse.callback import CallbackHandler
kwargs: dict[str, Any] = {
"secret_key": settings.LANGFUSE_SECRET_KEY,
"public_key": settings.LANGFUSE_PUBLIC_KEY,
"host": settings.LANGFUSE_HOST,
"update_parent": update_parent,
}
if trace is not None:
kwargs["trace_id"] = trace.id
elif trace_id is not None:
kwargs["trace_id"] = trace_id
if span_name:
kwargs["root_span"] = span_name
return CallbackHandler(**kwargs)
from langfuse.langchain import CallbackHandler
return CallbackHandler()
except Exception as exc:
logger.warning("tracing: get_langfuse_callback failed: %s", exc)
return None
@@ -152,21 +171,8 @@ def get_prompt(
Returns the compiled prompt string, or *fallback* if the prompt is not
found or Langfuse is disabled.
Parameters
----------
name : str
Prompt name as registered in Langfuse.
version : int, optional
Pin to a specific version; omit for the latest production version.
label : str, optional
Fetch by label (e.g. ``"production"``, ``"staging"``).
fallback : str, optional
Value returned when the prompt cannot be fetched.
cache_ttl_seconds : int
How long to cache the prompt locally (default 5 min).
"""
lf = _get_langfuse()
lf = _get_client()
if lf is None:
return fallback
@@ -187,20 +193,15 @@ def get_prompt(
def link_prompt_to_trace(
trace: Any,
span: Any,
prompt_name: str,
*,
version: int | None = None,
label: str | None = None,
) -> None:
"""Attach a Langfuse prompt reference to a trace/generation.
Call this *after* creating a generation on the trace to associate the
prompt that was used. The prompt object is fetched and linked so
Langfuse can display prompt→trace associations in the dashboard.
"""
lf = _get_langfuse()
if lf is None or trace is None:
"""Attach prompt metadata to a span/trace."""
lf = _get_client()
if lf is None or isinstance(span, _NullSpan):
return
try:
@@ -210,7 +211,7 @@ def link_prompt_to_trace(
if label is not None:
kwargs["label"] = label
prompt = lf.get_prompt(**kwargs)
trace.update(metadata={"prompt": {"name": prompt_name, "version": prompt.version}})
span.update(metadata={"prompt": {"name": prompt_name, "version": prompt.version}})
except Exception as exc:
logger.warning("tracing: link_prompt_to_trace(%s) failed: %s", prompt_name, exc)
@@ -226,12 +227,12 @@ def score_trace(
comment: str | None = None,
) -> None:
"""Post a score to a trace (e.g. user feedback, latency, quality)."""
lf = _get_langfuse()
lf = _get_client()
if lf is None:
return
try:
lf.score(trace_id=trace_id, name=name, value=value, comment=comment)
lf.create_score(trace_id=trace_id, name=name, value=value, comment=comment)
except Exception as exc:
logger.warning("tracing: score_trace failed: %s", exc)
@@ -240,22 +241,24 @@ def score_trace(
def flush() -> None:
"""Flush pending Langfuse events. Call this on service shutdown."""
if _langfuse_client is not None:
"""Flush pending Langfuse events."""
lf = _get_client()
if lf is not None:
try:
_langfuse_client.flush()
lf.flush()
except Exception as exc:
logger.warning("tracing: flush failed: %s", exc)
def shutdown() -> None:
"""Flush and close the Langfuse client."""
global _langfuse_client, _langfuse_disabled
if _langfuse_client is not None:
global _initialised, _disabled
lf = _get_client()
if lf is not None:
try:
_langfuse_client.flush()
_langfuse_client.shutdown()
lf.flush()
lf.shutdown()
except Exception as exc:
logger.warning("tracing: shutdown failed: %s", exc)
_langfuse_client = None
_langfuse_disabled = False
_initialised = False
_disabled = False

View File

@@ -14,4 +14,4 @@ langchain-litellm>=0.3.0
litellm>=1.50.0
openai>=1.50.0
httpx>=0.27.0
langfuse>=2.0.0
langfuse>=3.0.0