"""Redis consumer — listens for chat requests and dispatches to deep_agent.

Subscribes to a Redis pattern channel chat:request:* so it receives
requests for ALL users. Each request is processed in a separate asyncio task.
"""

from __future__ import annotations

import asyncio
import json
import logging
from typing import Any, Callable
from uuid import uuid4

from shared.db import async_session
from shared.redis import redis_client, ws_out_channel
from app.deep_agent import run_floating_stream, run_home_stream
from app.memory_middleware import MemoryMiddleware
from app.output_formatter import StreamFormatter
from shared.ws_context import clear_current_user, set_current_user
from app import tracing

logger = logging.getLogger(__name__)

# Strong references to in-flight dispatch tasks. The event loop only keeps
# weak references to tasks, so a fire-and-forget task may otherwise be
# garbage-collected before it finishes.
_inflight_tasks: set[asyncio.Task] = set()


def start_consumer() -> asyncio.Task:
    """Start the Redis consumer as a background asyncio task.

    Must be called while an event loop is running (``asyncio.create_task``
    requirement).

    Returns:
        The task running ``_consumer_loop``; cancel it to shut the consumer
        down.
    """
    return asyncio.create_task(_consumer_loop())


def _decode_frame(raw: Any) -> dict | None:
    """Parse a raw pub/sub payload into a frame dict.

    Returns:
        The decoded frame, or ``None`` when the payload is not valid JSON or
        does not decode to a JSON object. Invalid payloads are logged and
        dropped so that one bad message cannot stop the consumer loop.
    """
    try:
        frame = json.loads(raw)
    except (TypeError, ValueError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        logger.warning("redis_consumer: dropping undecodable frame: %s", exc)
        return None
    if not isinstance(frame, dict):
        logger.warning(
            "redis_consumer: dropping non-object frame of type %s",
            type(frame).__name__,
        )
        return None
    return frame


def _on_dispatch_done(task: asyncio.Task) -> None:
    """Release a finished dispatch task and log any exception it raised."""
    _inflight_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("redis_consumer: dispatch task failed", exc_info=exc)


def _spawn_dispatch(frame: dict) -> None:
    """Run ``_dispatch(frame)`` in its own task, tracked until completion."""
    task = asyncio.create_task(_dispatch(frame))
    _inflight_tasks.add(task)
    task.add_done_callback(_on_dispatch_done)


async def _consumer_loop() -> None:
    """Subscribe to chat:request:* and dispatch incoming frames.

    Runs until cancelled. Each valid ``pmessage`` payload is decoded and
    handed to ``_dispatch`` in a separate task; undecodable payloads are
    skipped. On exit the pattern subscription is removed and the pubsub
    object is closed.
    """
    pubsub = redis_client.pubsub()
    await pubsub.psubscribe("chat:request:*")
    logger.info("redis_consumer: subscribed to chat:request:*")

    try:
        while True:
            message = await pubsub.get_message(
                ignore_subscribe_messages=True, timeout=1.0
            )
            if message is None or message["type"] != "pmessage":
                await asyncio.sleep(0.01)
                continue

            frame = _decode_frame(message["data"])
            if frame is not None:
                _spawn_dispatch(frame)
    except asyncio.CancelledError:
        # Cancellation is the shutdown path; it is logged and absorbed so the
        # task finishes without raising (pre-existing behaviour, kept for
        # callers that await the cancelled task).
        logger.info("redis_consumer: shutting down")
    finally:
        await pubsub.punsubscribe()
        await pubsub.aclose()


async def _dispatch(frame: dict) -> None:
    """Route a chat request frame to the appropriate handler.

    Frames without a ``user_id`` are logged and dropped; frames whose
    ``type`` is neither ``home_request`` nor ``floating_request`` are
    ignored with a debug log.
    """
    frame_type = frame.get("type")
    user_id = frame.get("user_id")

    if not user_id:
        logger.warning("redis_consumer: frame missing user_id: %s", frame_type)
        return

    if frame_type == "home_request":
        await _handle_home_request(user_id, frame)
    elif frame_type == "floating_request":
        await _handle_floating_request(user_id, frame)
    else:
        logger.debug("redis_consumer: unknown frame type %r", frame_type)


async def _publish_frame(user_id: str, frame_data: str) -> None:
    """Publish a frame to ws:out:{user_id} for the WS Gateway to forward."""
    channel = ws_out_channel(user_id)
    await redis_client.publish(channel, frame_data)


async def _run_request(
    *,
    kind: str,
    prompt_name: str,
    stream_factory: Callable[..., Any],
    user_id: str,
    request_id: str,
    session_id: str,
    message: str,
    base_context: dict,
    metadata: dict,
    tags: list[str],
) -> None:
    """Shared pipeline for home and floating requests.

    Steps: open a trace span, enrich the context with memory, stream the
    agent output to the user's outbound channel, attach an output preview to
    the span, flush tracing, and store the exchange as an episode.

    Args:
        kind: Span name and log label (``home_request`` / ``floating_request``).
        prompt_name: Prompt identifier passed to ``tracing.link_prompt_to_trace``.
        stream_factory: ``run_home_stream`` or ``run_floating_stream``.
        user_id: Requesting user; also selects the outbound channel.
        request_id: Request identifier, reused as the trace id.
        session_id: Conversation session identifier.
        message: The user's message text.
        base_context: Request-specific context keys; ``_debug`` and the
            memory context are layered on top, in that order.
        metadata: Metadata attached to the trace span.
        tags: Tags attached to the trace span.
    """
    response_chunks: list[str] = []

    with tracing.trace_span(
        name=kind,
        user_id=user_id,
        session_id=session_id,
        trace_id=request_id,
        input=message,
        metadata=metadata,
        tags=tags,
    ) as span:
        langfuse_handler = tracing.get_langfuse_callback()

        # Enrich with memory context
        async with async_session() as db:
            memory_context = await MemoryMiddleware(db).enrich_context(
                user_id,
                message,
                trace_id=request_id,
                session_id=session_id,
            )

        context: dict = {
            **base_context,
            "_debug": {
                "request_id": request_id,
                "session_id": session_id,
                "user_id": user_id,
            },
            **memory_context,
        }

        set_current_user(user_id)
        try:
            event_stream = stream_factory(
                user_id, message, context, langfuse_handler=langfuse_handler
            )
            formatter = StreamFormatter(request_id=request_id)
            async for ws_frame in formatter.format(event_stream):
                await _publish_frame(user_id, ws_frame.model_dump_json())
                # Only some frame types carry text; skip missing/empty chunks
                # so the final join cannot hit a None entry.
                chunk = getattr(ws_frame, "chunk", None)
                if chunk:
                    response_chunks.append(chunk)
        except Exception as exc:
            # Best-effort: a failed stream is logged (with traceback) and the
            # partial response is still traced and stored below.
            logger.exception(
                "redis_consumer: %s failed user=%s req=%s: %s",
                kind,
                user_id,
                request_id,
                exc,
            )
        finally:
            clear_current_user()

        # Link prompt and attach output preview
        tracing.link_prompt_to_trace(span, prompt_name)
        response_text = "".join(response_chunks)
        span.update(output=response_text[:500] if response_text else None)

    tracing.flush()

    # Store episode
    async with async_session() as db:
        await MemoryMiddleware(db).store_episode(
            user_id,
            session_id,
            message,
            response_text,
            trace_id=request_id,
        )


async def _handle_home_request(user_id: str, frame: dict) -> None:
    """Process a home_request — enrich with memory, run deep_agent, stream results.

    Missing ``request_id`` / ``session_id`` values are replaced with fresh
    UUIDs; a missing or null ``message`` is treated as an empty string.
    """
    request_id = frame.get("request_id") or str(uuid4())
    # `or ""` also covers an explicit JSON null, which `.get(key, "")` would not.
    message: str = frame.get("message") or ""
    session_id: str = frame.get("session_id") or str(uuid4())

    logger.info(
        "redis_consumer: home_request user=%s req=%s msg=%s",
        user_id,
        request_id,
        message[:200],
    )

    await _run_request(
        kind="home_request",
        prompt_name="home_system",
        stream_factory=run_home_stream,
        user_id=user_id,
        request_id=request_id,
        session_id=session_id,
        message=message,
        base_context={
            "conversation_history": frame.get("conversation_history", []),
        },
        metadata={"message_preview": message[:200]},
        tags=["home"],
    )


async def _handle_floating_request(user_id: str, frame: dict) -> None:
    """Process a floating_request — enrich with memory, run deep_agent, stream results.

    Same pipeline as a home request, with the frame's ``scope`` passed to the
    agent context and recorded in the trace metadata.
    """
    request_id = frame.get("request_id") or str(uuid4())
    # `or ""` also covers an explicit JSON null, which `.get(key, "")` would not.
    message: str = frame.get("message") or ""
    session_id: str = frame.get("session_id") or str(uuid4())
    scope: dict = frame.get("scope", {})

    logger.info(
        "redis_consumer: floating_request user=%s req=%s scope=%s msg=%s",
        user_id,
        request_id,
        json.dumps(scope)[:200],
        message[:200],
    )

    await _run_request(
        kind="floating_request",
        prompt_name="floating_system",
        stream_factory=run_floating_stream,
        user_id=user_id,
        request_id=request_id,
        session_id=session_id,
        message=message,
        base_context={"scope": scope},
        metadata={"message_preview": message[:200], "scope": scope},
        tags=["floating"],
    )