- Langfuse SDK v4: fix prompt-to-trace linking (as_type=generation) - tracing: compile_prompt with Langfuse managed prompt fallback - journey: remove journey CLI subcommand (keep only interactive) - LLM: add service-specific llm modules for batch-agent and chat - gitignore: exclude eval private test data - config: add LANGFUSE settings to shared config
184 lines
6.7 KiB
Python
184 lines
6.7 KiB
Python
"""Redis consumer for the Batch Agent Service.
|
|
|
|
Subscribes to batch:request:* (pattern) and dispatches:
|
|
- journey_start → handle_journey_start
|
|
- journey_message → handle_journey_message
|
|
- agent_trigger → run_local_agent / run_cloud_agent
|
|
|
|
Results are published back to ws:out:{user_id} via Redis.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
from shared.redis import redis_client, batch_request_channel, ws_out_channel
|
|
|
|
import app.tracing as tracing
|
|
from shared.ws_context import set_current_user, clear_current_user
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def _publish_to_user(user_id: str, payload: dict[str, Any]) -> None:
    """Send one frame to a user's WS outbound channel.

    The channel name is derived from ``user_id`` via ``ws_out_channel`` and
    ``payload`` is serialized with ``json.dumps`` before being published
    through the shared ``redis_client``.
    """
    destination = ws_out_channel(user_id)
    frame = json.dumps(payload)
    await redis_client.publish(destination, frame)
|
|
|
|
|
|
async def _handle_journey_start(user_id: str, data: dict[str, Any]) -> None:
    """Handle a journey_start request from WS Gateway.

    Runs ``app.journey.handle_journey_start`` inside a ``journey_start`` trace
    span and publishes the resulting reply to the user's outbound channel.
    If anything in that sequence raises, the failure is logged together with
    its traceback and a terminal ``journey_reply`` frame (``done=True``)
    carrying the error text is published instead.

    Args:
        user_id: User the request belongs to; also selects the reply channel.
        data: Decoded request frame. ``session_id``, ``directory`` and
            ``data_types`` are read here; the whole dict is forwarded to the
            journey handler.
    """
    # Imported at call time rather than at module load.
    # NOTE(review): presumably avoids a circular import -- confirm.
    from app.journey import handle_journey_start

    session_id = data.get("session_id", "")
    set_current_user(user_id)
    try:
        with tracing.trace_span(
            name="journey_start",
            user_id=user_id,
            session_id=session_id,
            input=data.get("directory", ""),
            metadata={"data_types": data.get("data_types", [])},
            tags=["journey"],
        ) as span:
            langfuse_handler = tracing.get_langfuse_callback()
            reply = await handle_journey_start(
                user_id, data, langfuse_handler=langfuse_handler
            )
            tracing.link_prompt_to_trace(span, "journey_system")
            # Only a 500-character preview of the reply is attached to the span.
            span.update(output=reply.get("message", "")[:500])
        # The span is closed at this point; publish the reply, then flush tracing.
        await _publish_to_user(user_id, reply)
        tracing.flush()
    except Exception as exc:
        # logger.exception keeps the traceback, which a plain error() call drops.
        logger.exception("batch-agent: journey_start failed user=%s: %s", user_id, exc)
        await _publish_to_user(user_id, {
            "type": "journey_reply",
            "session_id": session_id,
            "message": f"Journey setup failed: {exc}",
            "done": True,
            "prompt_template": None,
        })
    finally:
        # Always undo set_current_user, whichever path was taken.
        clear_current_user()
|
|
|
|
|
|
async def _handle_journey_message(user_id: str, data: dict[str, Any]) -> None:
    """Handle a journey_message from WS Gateway.

    Runs ``app.journey.handle_journey_message`` inside a ``journey_message``
    trace span and publishes the resulting reply to the user's outbound
    channel. If anything in that sequence raises, the failure is logged
    together with its traceback and a terminal ``journey_reply`` frame
    (``done=True``) carrying the error text is published instead.

    Args:
        user_id: User the request belongs to; also selects the reply channel.
        data: Decoded request frame. ``session_id`` and ``message`` are read
            here; the whole dict is forwarded to the journey handler.
    """
    # Imported at call time rather than at module load.
    # NOTE(review): presumably avoids a circular import -- confirm.
    from app.journey import handle_journey_message

    session_id = data.get("session_id", "")
    # The trace input is only a 200-character preview. Coerce defensively so a
    # null or non-string "message" cannot raise before the handler has had a
    # chance to look at the frame.
    message_preview = str(data.get("message") or "")[:200]
    set_current_user(user_id)
    try:
        with tracing.trace_span(
            name="journey_message",
            user_id=user_id,
            session_id=session_id,
            input=message_preview,
            tags=["journey"],
        ) as span:
            langfuse_handler = tracing.get_langfuse_callback()
            reply = await handle_journey_message(
                user_id, data, langfuse_handler=langfuse_handler
            )
            tracing.link_prompt_to_trace(span, "journey_system")
            # Only a 500-character preview of the reply is attached to the span.
            span.update(output=reply.get("message", "")[:500])
        # The span is closed at this point; publish the reply, then flush tracing.
        await _publish_to_user(user_id, reply)
        tracing.flush()
    except Exception as exc:
        # logger.exception keeps the traceback, which a plain error() call drops.
        logger.exception("batch-agent: journey_message failed user=%s: %s", user_id, exc)
        await _publish_to_user(user_id, {
            "type": "journey_reply",
            "session_id": session_id,
            "message": f"Journey processing failed: {exc}",
            "done": True,
            "prompt_template": None,
        })
    finally:
        # Always undo set_current_user, whichever path was taken.
        clear_current_user()
|
|
|
|
|
|
async def _handle_agent_trigger(user_id: str, data: dict[str, Any]) -> None:
    """Handle an agent_trigger request from the REST route (forwarded via Redis).

    Runs ``app.agent_runner.run_local_agent`` inside an ``agent_trigger`` trace
    span. Nothing is published here on success; if the run raises, the failure
    is logged together with its traceback and a ``run_complete`` frame with
    ``status="error"`` is published to the user's outbound channel.

    Args:
        user_id: User the request belongs to; also selects the reply channel.
        data: Decoded request frame. ``run_context`` (with its ``agent_id`` and
            ``run_id``), ``directory`` and ``data_types`` are read here; the
            whole dict is forwarded to the agent runner.
    """
    # Imported at call time rather than at module load.
    # NOTE(review): presumably avoids a circular import -- confirm.
    from app.agent_runner import run_local_agent

    # An explicit null (or any non-object) run_context would otherwise raise on
    # .get() before the try block, so the caller would never receive an error
    # frame. Normalise it to an empty dict instead.
    run_context = data.get("run_context")
    if not isinstance(run_context, dict):
        run_context = {}
    agent_id = run_context.get("agent_id", "")
    set_current_user(user_id)
    try:
        with tracing.trace_span(
            name="agent_trigger",
            user_id=user_id,
            trace_id=run_context.get("run_id"),
            input={"agent_id": agent_id, "directory": data.get("directory", "")},
            metadata={"data_types": data.get("data_types", [])},
            tags=["batch", "agent_run"],
        ) as span:
            langfuse_handler = tracing.get_langfuse_callback()
            await run_local_agent(user_id, data, langfuse_handler=langfuse_handler)
            tracing.link_prompt_to_trace(span, "batch_processing")
            span.update(output={"status": "completed"})
        # The span is closed at this point; flush tracing.
        tracing.flush()
    except Exception as exc:
        # logger.exception keeps the traceback, which a plain error() call drops.
        logger.exception("batch-agent: agent_trigger failed user=%s: %s", user_id, exc)
        await _publish_to_user(user_id, {
            "type": "run_complete",
            "status": "error",
            "run_context": run_context,
        })
    finally:
        # Always undo set_current_user, whichever path was taken.
        clear_current_user()
|
|
|
|
|
|
async def _dispatch(user_id: str, message_data: dict[str, Any]) -> None:
    """Send a batch request to the handler registered for its ``type`` field.

    ``journey_start``, ``journey_message`` and ``agent_trigger`` frames are
    awaited on their respective handlers, ``device_online`` is only logged,
    and any other type is logged as a warning and dropped.
    """
    kind = message_data.get("type", "")

    # Compared with == (not a dict lookup) so that an unhashable "type" value
    # simply falls through to the warning below.
    routes = (
        ("journey_start", _handle_journey_start),
        ("journey_message", _handle_journey_message),
        ("agent_trigger", _handle_agent_trigger),
    )
    for route_name, handler in routes:
        if kind == route_name:
            await handler(user_id, message_data)
            return

    if kind == "device_online":
        device = message_data.get("device_id", "?")
        logger.info("batch-agent: device_online user=%s device=%s", user_id, device)
        return

    logger.warning("batch-agent: unknown message type %r from user=%s", kind, user_id)
|
|
|
|
|
|
async def start_consumer() -> None:
    """Subscribe to batch:request:* and dispatch incoming frames.

    Each ``pmessage`` is expected on a channel of the form
    ``batch:request:{user_id}`` with a JSON object as payload. Every valid
    frame is handed to ``_dispatch`` in its own task so that a slow handler
    does not block the subscription loop. Frames on malformed channels,
    undecodable or invalid JSON payloads, and payloads that are not JSON
    objects are skipped (the latter two with a warning).

    The loop runs until the coroutine is cancelled; the pattern subscription
    is removed on the way out.
    """
    pattern = "batch:request:*"
    pubsub = redis_client.pubsub()
    await pubsub.psubscribe(pattern)
    logger.info("batch-agent: subscribed to %s", pattern)

    # asyncio only keeps weak references to tasks, so a fire-and-forget task
    # can be garbage-collected before it finishes. Hold a strong reference
    # until each dispatch task is done.
    in_flight: set[asyncio.Task] = set()

    def _on_dispatch_done(task: asyncio.Task) -> None:
        """Drop the finished task and surface any exception it raised."""
        in_flight.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("batch-agent: dispatch task failed: %r", exc, exc_info=exc)

    try:
        async for message in pubsub.listen():
            # Skip subscribe/unsubscribe confirmations and other frame kinds.
            if message["type"] != "pmessage":
                continue

            channel = message["channel"]
            if isinstance(channel, bytes):
                channel = channel.decode()

            # Extract user_id from channel: batch:request:{user_id}
            parts = channel.split(":", 2)
            if len(parts) < 3:
                continue
            user_id = parts[2]

            raw = message["data"]
            try:
                # Decoding sits inside the try so that a payload which is not
                # valid UTF-8 is skipped instead of ending the consumer loop.
                if isinstance(raw, bytes):
                    raw = raw.decode()
                data = json.loads(raw)
            except (UnicodeDecodeError, json.JSONDecodeError):
                logger.warning("batch-agent: invalid JSON on channel %s", channel)
                continue

            # _dispatch calls .get() on the payload, so only JSON objects are usable.
            if not isinstance(data, dict):
                logger.warning("batch-agent: non-object JSON payload on channel %s", channel)
                continue

            # Dispatch in a separate task to avoid blocking the consumer
            task = asyncio.create_task(_dispatch(user_id, data))
            in_flight.add(task)
            task.add_done_callback(_on_dispatch_done)
    except asyncio.CancelledError:
        # NOTE(review): cancellation is swallowed here (original behaviour), so
        # awaiting this coroutine returns normally after cancel() -- confirm
        # that callers rely on this before changing it to re-raise.
        logger.info("batch-agent: consumer shutting down")
    finally:
        await pubsub.punsubscribe(pattern)
|