"""Redis consumer for the Batch Agent Service.

Subscribes to the ``batch:request:*`` channel pattern and dispatches each frame
by its ``type`` field:

- ``journey_start``   → ``app.journey.handle_journey_start``
- ``journey_message`` → ``app.journey.handle_journey_message``
- ``agent_trigger``   → ``app.agent_runner.run_local_agent``
  (``run_cloud_agent`` is not called from this module)

Replies and error frames are published back to the user on
``ws_out_channel(user_id)`` (``ws:out:{user_id}``) via Redis.
"""

from __future__ import annotations

import asyncio
import json
import logging
from typing import Any

from shared.redis import redis_client, batch_request_channel, ws_out_channel
import app.tracing as tracing
from app.ws_context import set_current_user, clear_current_user

logger = logging.getLogger(__name__)

# Channel pattern this consumer listens on. The user id is the part of the
# channel name after "batch:request:" (see _parse_request).
_REQUEST_PATTERN = "batch:request:*"

# Strong references to in-flight dispatch tasks. asyncio keeps only weak
# references to tasks, so a fire-and-forget task with no other reference can
# be garbage-collected before it finishes.
_background_tasks: set[asyncio.Task[None]] = set()


async def _publish_to_user(user_id: str, payload: dict[str, Any]) -> None:
    """Publish ``payload`` as JSON on the user's WS outbound channel."""
    channel = ws_out_channel(user_id)
    await redis_client.publish(channel, json.dumps(payload))


def _journey_error_frame(session_id: str, message: str) -> dict[str, Any]:
    """Build the terminal ``journey_reply`` frame sent when a journey step fails."""
    return {
        "type": "journey_reply",
        "session_id": session_id,
        "message": message,
        "done": True,
        "prompt_template": None,
    }


async def _handle_journey_start(user_id: str, data: dict[str, Any]) -> None:
    """Handle a ``journey_start`` request from the WS Gateway.

    Runs ``handle_journey_start`` inside a tracing span and publishes its reply
    to the user. On any exception a terminal ``journey_reply`` frame
    (``done=True``) carrying the error text is published instead.
    """
    # Imported lazily, presumably to avoid an import cycle / startup cost.
    from app.journey import handle_journey_start

    session_id = data.get("session_id", "")
    set_current_user(user_id)
    try:
        with tracing.trace_span(
            name="journey_start",
            user_id=user_id,
            session_id=session_id,
            input=data.get("directory", ""),
            metadata={"data_types": data.get("data_types", [])},
            tags=["journey"],
        ) as span:
            langfuse_handler = tracing.get_langfuse_callback()
            reply = await handle_journey_start(user_id, data, langfuse_handler=langfuse_handler)
            # ``or ""`` guards against an explicit ``"message": None`` in the reply.
            span.update(output=(reply.get("message") or "")[:500])
            await _publish_to_user(user_id, reply)
        tracing.flush()
    except Exception as exc:
        # logger.exception keeps the traceback (logger.error dropped it).
        logger.exception("batch-agent: journey_start failed user=%s: %s", user_id, exc)
        await _publish_to_user(
            user_id, _journey_error_frame(session_id, f"Journey setup failed: {exc}")
        )
    finally:
        clear_current_user()


async def _handle_journey_message(user_id: str, data: dict[str, Any]) -> None:
    """Handle a ``journey_message`` from the WS Gateway.

    Runs ``handle_journey_message`` inside a tracing span and publishes its
    reply to the user. On any exception a terminal ``journey_reply`` frame
    (``done=True``) carrying the error text is published instead.
    """
    from app.journey import handle_journey_message

    session_id = data.get("session_id", "")
    set_current_user(user_id)
    try:
        with tracing.trace_span(
            name="journey_message",
            user_id=user_id,
            session_id=session_id,
            input=data.get("message", "")[:200],
            tags=["journey"],
        ) as span:
            langfuse_handler = tracing.get_langfuse_callback()
            reply = await handle_journey_message(user_id, data, langfuse_handler=langfuse_handler)
            span.update(output=(reply.get("message") or "")[:500])
            await _publish_to_user(user_id, reply)
        tracing.flush()
    except Exception as exc:
        logger.exception("batch-agent: journey_message failed user=%s: %s", user_id, exc)
        await _publish_to_user(
            user_id, _journey_error_frame(session_id, f"Journey processing failed: {exc}")
        )
    finally:
        clear_current_user()


async def _handle_agent_trigger(user_id: str, data: dict[str, Any]) -> None:
    """Handle an ``agent_trigger`` request forwarded from the REST route via Redis.

    Runs ``run_local_agent`` inside a tracing span whose trace id is
    ``run_context["run_id"]``. On any exception a ``run_complete`` frame with
    ``status="error"`` echoing the run context is published to the user.
    """
    from app.agent_runner import run_local_agent

    # ``or {}`` also covers an explicit ``"run_context": null``, which would
    # otherwise raise AttributeError here, outside the try block.
    run_context = data.get("run_context") or {}
    agent_id = run_context.get("agent_id", "")
    set_current_user(user_id)
    try:
        with tracing.trace_span(
            name="agent_trigger",
            user_id=user_id,
            trace_id=run_context.get("run_id"),
            input={"agent_id": agent_id, "directory": data.get("directory", "")},
            metadata={"data_types": data.get("data_types", [])},
            tags=["batch", "agent_run"],
        ) as span:
            langfuse_handler = tracing.get_langfuse_callback()
            await run_local_agent(user_id, data, langfuse_handler=langfuse_handler)
            span.update(output={"status": "completed"})
        tracing.flush()
    except Exception as exc:
        logger.exception("batch-agent: agent_trigger failed user=%s: %s", user_id, exc)
        await _publish_to_user(user_id, {
            "type": "run_complete",
            "status": "error",
            "run_context": run_context,
        })
    finally:
        clear_current_user()


async def _dispatch(user_id: str, message_data: dict[str, Any]) -> None:
    """Route a batch request to the handler matching its ``type`` field."""
    msg_type = message_data.get("type", "")
    if msg_type == "journey_start":
        await _handle_journey_start(user_id, message_data)
    elif msg_type == "journey_message":
        await _handle_journey_message(user_id, message_data)
    elif msg_type == "agent_trigger":
        await _handle_agent_trigger(user_id, message_data)
    else:
        logger.warning("batch-agent: unknown message type %r from user=%s", msg_type, user_id)


def _parse_request(message: dict[str, Any]) -> tuple[str, dict[str, Any]] | None:
    """Extract ``(user_id, payload)`` from a ``pmessage``; return None if malformed.

    Malformed frames are logged (where useful) and skipped, so that a single
    bad publish cannot terminate the consumer loop.
    """
    channel = message["channel"]
    if isinstance(channel, bytes):
        try:
            channel = channel.decode()
        except UnicodeDecodeError:
            logger.warning("batch-agent: undecodable channel name %r", channel)
            return None

    # Extract user_id from channel: batch:request:{user_id}
    parts = channel.split(":", 2)
    if len(parts) < 3 or not parts[2]:
        return None
    user_id = parts[2]

    raw = message["data"]
    try:
        if isinstance(raw, bytes):
            raw = raw.decode()
        data = json.loads(raw)
    except ValueError:
        # Covers both json.JSONDecodeError and UnicodeDecodeError (both are
        # ValueError subclasses); previously a bad byte sequence crashed the loop.
        logger.warning("batch-agent: invalid JSON on channel %s", channel)
        return None

    if not isinstance(data, dict):
        # _dispatch calls .get() on the payload, so only JSON objects are routable.
        logger.warning("batch-agent: non-object JSON payload on channel %s", channel)
        return None
    return user_id, data


def _on_dispatch_done(task: asyncio.Task[None]) -> None:
    """Drop the finished task's strong reference and log any unhandled failure."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("batch-agent: dispatch task failed", exc_info=exc)


def _spawn_dispatch(user_id: str, data: dict[str, Any]) -> None:
    """Run ``_dispatch`` in a background task so the consumer loop never blocks."""
    task = asyncio.create_task(_dispatch(user_id, data))
    _background_tasks.add(task)
    task.add_done_callback(_on_dispatch_done)


async def start_consumer() -> None:
    """Subscribe to ``batch:request:*`` and dispatch incoming frames until cancelled.

    Each valid frame is handled in its own background task. Cancellation is
    treated as a normal shutdown: it is logged and not re-raised, so awaiting
    this coroutine after cancelling it returns normally.
    """
    pubsub = redis_client.pubsub()
    await pubsub.psubscribe(_REQUEST_PATTERN)
    logger.info("batch-agent: subscribed to %s", _REQUEST_PATTERN)
    try:
        async for message in pubsub.listen():
            if message["type"] != "pmessage":
                continue
            parsed = _parse_request(message)
            if parsed is None:
                continue
            user_id, data = parsed
            _spawn_dispatch(user_id, data)
    except asyncio.CancelledError:
        logger.info("batch-agent: consumer shutting down")
    finally:
        # Best-effort cleanup: on a dead connection this can fail too, and it
        # must not mask the exception (if any) that ended the loop.
        try:
            await pubsub.punsubscribe(_REQUEST_PATTERN)
        except Exception:
            logger.warning(
                "batch-agent: failed to unsubscribe from %s", _REQUEST_PATTERN, exc_info=True
            )