Files
api/services/chat/app/redis_consumer.py
Roberto Musso d856dfd28c refactor: deduplicate shared code into shared/ module
Move duplicated files from chat + batch-agent into shared/:
- shared/ws_context.py — Redis-based tool call round-trip
- shared/llm.py — LiteLLM factory (get_llm, embed)
- shared/agents/ — 4 domain agents (task, note, project, timeline)

Update all service imports to use shared.* instead of app.*.
Delete 12 duplicated files across both services.
2026-03-23 23:01:45 +01:00

210 lines
7.3 KiB
Python

"""Redis consumer — listens for chat requests and dispatches to deep_agent.
Subscribes to a Redis pattern channel chat:request:* so it receives
requests for ALL users. Each request is processed in a separate asyncio task.
"""
from __future__ import annotations
import asyncio
import json
import logging
from uuid import uuid4
from shared.db import async_session
from shared.redis import redis_client, ws_out_channel
from app.deep_agent import run_floating_stream, run_home_stream
from app.memory_middleware import MemoryMiddleware
from app.output_formatter import StreamFormatter
from shared.ws_context import clear_current_user, set_current_user
from app import tracing
logger = logging.getLogger(__name__)
def start_consumer() -> asyncio.Task:
    """Launch the Redis consumer loop in the background.

    The loop coroutine is scheduled with ``asyncio.create_task``, so this
    must be called while an event loop is running.

    Returns:
        The task wrapping the consumer loop coroutine.
    """
    consumer_task = asyncio.create_task(_consumer_loop())
    return consumer_task
async def _consumer_loop() -> None:
    """Subscribe to ``chat:request:*`` and dispatch incoming frames.

    Polls the pub/sub connection until cancelled. Each ``pmessage`` payload
    is decoded as JSON and handed to ``_dispatch`` in its own asyncio task.

    A payload that cannot be decoded, or that does not decode to a JSON
    object, is logged and skipped so that one bad message cannot take the
    consumer down. Exceptions escaping a dispatch task are logged when the
    task finishes.

    On cancellation the loop logs and returns normally (the cancellation is
    not re-raised); the subscription is released on every exit path.
    """
    pubsub = redis_client.pubsub()
    await pubsub.psubscribe("chat:request:*")
    logger.info("redis_consumer: subscribed to chat:request:*")

    # The event loop keeps only weak references to tasks; hold strong ones
    # here so an in-flight dispatch cannot be garbage-collected mid-request.
    in_flight: set[asyncio.Task] = set()

    def _on_done(task: asyncio.Task) -> None:
        """Drop the finished task and surface any exception it raised."""
        in_flight.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error(
                "redis_consumer: dispatch task failed: %s", exc, exc_info=exc
            )

    try:
        while True:
            message = await pubsub.get_message(
                ignore_subscribe_messages=True, timeout=1.0
            )
            if message is None or message["type"] != "pmessage":
                await asyncio.sleep(0.01)
                continue

            try:
                frame = json.loads(message["data"])
            except (TypeError, ValueError) as exc:
                # ValueError covers json.JSONDecodeError and bad UTF-8 bytes.
                logger.warning(
                    "redis_consumer: dropping undecodable frame: %s", exc
                )
                continue

            if not isinstance(frame, dict):
                logger.warning(
                    "redis_consumer: dropping non-object frame of type %s",
                    type(frame).__name__,
                )
                continue

            task = asyncio.create_task(_dispatch(frame))
            in_flight.add(task)
            task.add_done_callback(_on_done)
    except asyncio.CancelledError:
        logger.info("redis_consumer: shutting down")
    finally:
        await pubsub.punsubscribe()
        await pubsub.aclose()
async def _dispatch(frame: dict) -> None:
    """Hand a chat request frame to the handler matching its ``type``.

    Frames without a truthy ``user_id`` are dropped with a warning. Frames
    whose type is neither ``home_request`` nor ``floating_request`` are
    logged at debug level and ignored.
    """
    kind = frame.get("type")
    uid = frame.get("user_id")

    # Guard: nothing can be routed or answered without a user.
    if not uid:
        logger.warning("redis_consumer: frame missing user_id: %s", frame.get("type"))
        return

    if kind == "home_request":
        await _handle_home_request(uid, frame)
        return

    if kind == "floating_request":
        await _handle_floating_request(uid, frame)
        return

    logger.debug("redis_consumer: unknown frame type %r", kind)
async def _publish_frame(user_id: str, frame_data: str) -> None:
    """Send one serialized frame to the user's outbound WebSocket channel.

    The channel name comes from ``ws_out_channel(user_id)``; the WS Gateway
    is expected to forward whatever is published there.
    """
    await redis_client.publish(ws_out_channel(user_id), frame_data)
async def _handle_home_request(user_id: str, frame: dict) -> None:
    """Process a ``home_request`` frame end to end.

    Enriches the request with memory context, runs ``run_home_stream``,
    publishes every formatted frame to the user's outbound channel, attaches
    an output preview to the trace span, and finally stores the exchange as
    an episode.

    Args:
        user_id: Identifier of the requesting user.
        frame: Decoded request frame. Reads ``request_id``, ``message``,
            ``session_id`` and ``conversation_history``; a missing or null
            value falls back to a fresh UUID, ``""`` or ``[]`` respectively.

    A failure while streaming is logged with its traceback and not
    re-raised; the episode is still stored with whatever text was produced
    before the failure.
    """
    request_id = frame.get("request_id") or str(uuid4())
    # `or` rather than a .get() default: an explicit JSON null must also
    # fall back, otherwise the slice below raises TypeError.
    message: str = frame.get("message") or ""
    session_id: str = frame.get("session_id") or str(uuid4())
    logger.info(
        "redis_consumer: home_request user=%s req=%s msg=%s",
        user_id, request_id, message[:200],
    )
    response_chunks: list[str] = []
    with tracing.trace_span(
        name="home_request",
        user_id=user_id,
        session_id=session_id,
        trace_id=request_id,
        input=message,
        metadata={"message_preview": message[:200]},
        tags=["home"],
    ) as span:
        langfuse_handler = tracing.get_langfuse_callback()
        # Enrich with memory context
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            memory_context = await memory.enrich_context(
                user_id, message,
                trace_id=request_id, session_id=session_id,
            )
        context: dict = {
            "conversation_history": frame.get("conversation_history") or [],
            "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
            **memory_context,
        }
        set_current_user(user_id)
        try:
            event_stream = run_home_stream(user_id, message, context, langfuse_handler=langfuse_handler)
            formatter = StreamFormatter(request_id=request_id)
            async for ws_frame in formatter.format(event_stream):
                await _publish_frame(user_id, ws_frame.model_dump_json())
                # Only collect real text; a None chunk would break the join.
                chunk = getattr(ws_frame, "chunk", None)
                if isinstance(chunk, str):
                    response_chunks.append(chunk)
        except Exception as exc:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("redis_consumer: home_request failed user=%s req=%s: %s", user_id, request_id, exc)
        finally:
            clear_current_user()
        # Link prompt and attach output preview
        tracing.link_prompt_to_trace(span, "home_system")
        response_text = "".join(response_chunks)
        span.update(output=response_text[:500] if response_text else None)
    tracing.flush()
    # Store episode
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        await memory.store_episode(
            user_id, session_id, message, "".join(response_chunks),
            trace_id=request_id,
        )
async def _handle_floating_request(user_id: str, frame: dict) -> None:
    """Process a ``floating_request`` frame end to end.

    Enriches the request with memory context, runs ``run_floating_stream``
    with the frame's ``scope``, publishes every formatted frame to the
    user's outbound channel, attaches an output preview to the trace span,
    and finally stores the exchange as an episode.

    Args:
        user_id: Identifier of the requesting user.
        frame: Decoded request frame. Reads ``request_id``, ``message``,
            ``session_id`` and ``scope``; a missing or null value falls
            back to a fresh UUID, ``""`` or ``{}`` respectively.

    A failure while streaming is logged with its traceback and not
    re-raised; the episode is still stored with whatever text was produced
    before the failure.
    """
    request_id = frame.get("request_id") or str(uuid4())
    # `or` rather than a .get() default: an explicit JSON null must also
    # fall back, otherwise the slice below raises TypeError.
    message: str = frame.get("message") or ""
    session_id: str = frame.get("session_id") or str(uuid4())
    scope: dict = frame.get("scope") or {}
    logger.info(
        "redis_consumer: floating_request user=%s req=%s scope=%s msg=%s",
        user_id, request_id, json.dumps(scope)[:200], message[:200],
    )
    response_chunks: list[str] = []
    with tracing.trace_span(
        name="floating_request",
        user_id=user_id,
        session_id=session_id,
        trace_id=request_id,
        input=message,
        metadata={"message_preview": message[:200], "scope": scope},
        tags=["floating"],
    ) as span:
        langfuse_handler = tracing.get_langfuse_callback()
        # Enrich with memory context
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            memory_context = await memory.enrich_context(
                user_id, message,
                trace_id=request_id, session_id=session_id,
            )
        context: dict = {
            "scope": scope,
            "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
            **memory_context,
        }
        set_current_user(user_id)
        try:
            event_stream = run_floating_stream(user_id, message, context, langfuse_handler=langfuse_handler)
            formatter = StreamFormatter(request_id=request_id)
            async for ws_frame in formatter.format(event_stream):
                await _publish_frame(user_id, ws_frame.model_dump_json())
                # Only collect real text; a None chunk would break the join.
                chunk = getattr(ws_frame, "chunk", None)
                if isinstance(chunk, str):
                    response_chunks.append(chunk)
        except Exception as exc:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("redis_consumer: floating_request failed user=%s req=%s: %s", user_id, request_id, exc)
        finally:
            clear_current_user()
        # Link prompt and attach output preview
        tracing.link_prompt_to_trace(span, "floating_system")
        response_text = "".join(response_chunks)
        span.update(output=response_text[:500] if response_text else None)
    tracing.flush()
    # Store episode
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        await memory.store_episode(
            user_id, session_id, message, "".join(response_chunks),
            trace_id=request_id,
        )