Move duplicated files from chat + batch-agent into shared/: - shared/ws_context.py — Redis-based tool call round-trip - shared/llm.py — LiteLLM factory (get_llm, embed) - shared/agents/ — 4 domain agents (task, note, project, timeline) Update all service imports to use shared.* instead of app.*. Delete 12 duplicated files across both services.
210 lines
7.3 KiB
Python
210 lines
7.3 KiB
Python
"""Redis consumer — listens for chat requests and dispatches to deep_agent.
|
|
|
|
Subscribes to a Redis pattern channel chat:request:* so it receives
|
|
requests for ALL users. Each request is processed in a separate asyncio task.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from uuid import uuid4
|
|
|
|
from shared.db import async_session
|
|
from shared.redis import redis_client, ws_out_channel
|
|
|
|
from app.deep_agent import run_floating_stream, run_home_stream
|
|
from app.memory_middleware import MemoryMiddleware
|
|
from app.output_formatter import StreamFormatter
|
|
from shared.ws_context import clear_current_user, set_current_user
|
|
from app import tracing
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def start_consumer() -> asyncio.Task:
    """Launch the Redis consumer loop in the background.

    Returns:
        The asyncio task that wraps ``_consumer_loop``.
    """
    consumer_task = asyncio.create_task(_consumer_loop())
    return consumer_task
|
|
|
|
|
|
async def _consumer_loop() -> None:
    """Subscribe to chat:request:* and dispatch incoming frames.

    Every ``pmessage`` payload is decoded as JSON and handed to ``_dispatch``
    in its own asyncio task. Payloads that cannot be decoded, or that do not
    decode to a JSON object, are logged and dropped so that one bad message
    cannot stop the consumer. On cancellation the loop logs the shutdown,
    drops the pattern subscription and closes the pubsub object.
    """
    pubsub = redis_client.pubsub()
    await pubsub.psubscribe("chat:request:*")
    logger.info("redis_consumer: subscribed to chat:request:*")

    # The event loop only keeps weak references to tasks, so hold strong
    # references here while the consumer is running; otherwise a dispatch
    # task could be garbage-collected before it finishes.
    in_flight: set[asyncio.Task] = set()

    def _on_done(task: asyncio.Task) -> None:
        """Forget a finished dispatch task and surface its failure, if any."""
        in_flight.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logger.error(
                "redis_consumer: dispatch task failed: %r", error, exc_info=error
            )

    try:
        while True:
            message = await pubsub.get_message(
                ignore_subscribe_messages=True, timeout=1.0
            )
            if message is None or message["type"] != "pmessage":
                await asyncio.sleep(0.01)
                continue

            try:
                frame = json.loads(message["data"])
            except (TypeError, ValueError) as exc:
                # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
                logger.warning("redis_consumer: dropping undecodable frame: %s", exc)
                continue

            if not isinstance(frame, dict):
                # _dispatch relies on dict access, so reject other JSON values.
                logger.warning(
                    "redis_consumer: dropping non-object frame of type %s",
                    type(frame).__name__,
                )
                continue

            task = asyncio.create_task(_dispatch(frame))
            in_flight.add(task)
            task.add_done_callback(_on_done)
    except asyncio.CancelledError:
        logger.info("redis_consumer: shutting down")
    finally:
        await pubsub.punsubscribe()
        await pubsub.aclose()
|
|
|
|
|
|
async def _dispatch(frame: dict) -> None:
    """Send a decoded chat frame to the handler that matches its ``type``.

    Frames without a truthy ``user_id`` are dropped with a warning; frames
    whose type is not recognised are ignored with a debug log entry.
    """
    kind = frame.get("type")
    owner = frame.get("user_id")

    # Nothing can be routed without knowing who the request belongs to.
    if not owner:
        logger.warning("redis_consumer: frame missing user_id: %s", kind)
        return

    if kind == "home_request":
        await _handle_home_request(owner, frame)
        return

    if kind == "floating_request":
        await _handle_floating_request(owner, frame)
        return

    logger.debug("redis_consumer: unknown frame type %r", kind)
|
|
|
|
|
|
async def _publish_frame(user_id: str, frame_data: str) -> None:
    """Push a serialized frame onto the user's outbound channel.

    The channel name is obtained from ``ws_out_channel(user_id)``
    (ws:out:{user_id}), from which the WS Gateway forwards frames.
    """
    await redis_client.publish(ws_out_channel(user_id), frame_data)
|
|
|
|
|
|
async def _handle_home_request(user_id: str, frame: dict) -> None:
    """Process a home_request — enrich with memory, run deep_agent, stream results.

    Args:
        user_id: Owner of the request; used for the outbound channel,
            tracing and memory lookups.
        frame: Decoded request frame. Reads ``request_id``, ``message``,
            ``session_id`` and ``conversation_history``; missing ids are
            replaced with fresh UUIDs.

    Errors raised while streaming are logged (with traceback) and swallowed,
    so the episode is still stored with whatever text was produced so far.
    """
    request_id = frame.get("request_id") or str(uuid4())
    # `or ""` also covers an explicit null message, which would otherwise
    # break the slicing below.
    message: str = frame.get("message") or ""
    session_id: str = frame.get("session_id") or str(uuid4())

    logger.info(
        "redis_consumer: home_request user=%s req=%s msg=%s",
        user_id, request_id, message[:200],
    )

    response_chunks: list[str] = []
    response_text = ""

    with tracing.trace_span(
        name="home_request",
        user_id=user_id,
        session_id=session_id,
        trace_id=request_id,
        input=message,
        metadata={"message_preview": message[:200]},
        tags=["home"],
    ) as span:
        langfuse_handler = tracing.get_langfuse_callback()

        # Enrich with memory context
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            memory_context = await memory.enrich_context(
                user_id, message,
                trace_id=request_id, session_id=session_id,
            )

        # memory_context is unpacked last, so its keys win on a name clash.
        context: dict = {
            "conversation_history": frame.get("conversation_history", []),
            "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
            **memory_context,
        }

        set_current_user(user_id)
        try:
            event_stream = run_home_stream(user_id, message, context, langfuse_handler=langfuse_handler)
            formatter = StreamFormatter(request_id=request_id)
            async for ws_frame in formatter.format(event_stream):
                await _publish_frame(user_id, ws_frame.model_dump_json())
                # Only text chunks are collected; a frame whose `chunk` is
                # missing or not a string must not break the join below.
                chunk = getattr(ws_frame, "chunk", None)
                if isinstance(chunk, str):
                    response_chunks.append(chunk)
        except Exception as exc:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("redis_consumer: home_request failed user=%s req=%s: %s", user_id, request_id, exc)
        finally:
            clear_current_user()

        # Link prompt and attach output preview
        tracing.link_prompt_to_trace(span, "home_system")
        response_text = "".join(response_chunks)
        span.update(output=response_text[:500] if response_text else None)

    tracing.flush()

    # Store episode
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        await memory.store_episode(
            user_id, session_id, message, response_text,
            trace_id=request_id,
        )
|
|
|
|
|
|
async def _handle_floating_request(user_id: str, frame: dict) -> None:
    """Process a floating_request — enrich with memory, run deep_agent, stream results.

    Args:
        user_id: Owner of the request; used for the outbound channel,
            tracing and memory lookups.
        frame: Decoded request frame. Reads ``request_id``, ``message``,
            ``session_id`` and ``scope``; missing ids are replaced with
            fresh UUIDs.

    Errors raised while streaming are logged (with traceback) and swallowed,
    so the episode is still stored with whatever text was produced so far.
    """
    request_id = frame.get("request_id") or str(uuid4())
    # `or` fallbacks also cover explicit nulls in the frame, which would
    # otherwise break the slicing below or leak None into the context.
    message: str = frame.get("message") or ""
    session_id: str = frame.get("session_id") or str(uuid4())
    scope: dict = frame.get("scope") or {}

    logger.info(
        "redis_consumer: floating_request user=%s req=%s scope=%s msg=%s",
        user_id, request_id, json.dumps(scope)[:200], message[:200],
    )

    response_chunks: list[str] = []
    response_text = ""

    with tracing.trace_span(
        name="floating_request",
        user_id=user_id,
        session_id=session_id,
        trace_id=request_id,
        input=message,
        metadata={"message_preview": message[:200], "scope": scope},
        tags=["floating"],
    ) as span:
        langfuse_handler = tracing.get_langfuse_callback()

        # Enrich with memory context
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            memory_context = await memory.enrich_context(
                user_id, message,
                trace_id=request_id, session_id=session_id,
            )

        # memory_context is unpacked last, so its keys win on a name clash.
        context: dict = {
            "scope": scope,
            "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
            **memory_context,
        }

        set_current_user(user_id)
        try:
            event_stream = run_floating_stream(user_id, message, context, langfuse_handler=langfuse_handler)
            formatter = StreamFormatter(request_id=request_id)
            async for ws_frame in formatter.format(event_stream):
                await _publish_frame(user_id, ws_frame.model_dump_json())
                # Only text chunks are collected; a frame whose `chunk` is
                # missing or not a string must not break the join below.
                chunk = getattr(ws_frame, "chunk", None)
                if isinstance(chunk, str):
                    response_chunks.append(chunk)
        except Exception as exc:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("redis_consumer: floating_request failed user=%s req=%s: %s", user_id, request_id, exc)
        finally:
            clear_current_user()

        # Link prompt and attach output preview
        tracing.link_prompt_to_trace(span, "floating_system")
        response_text = "".join(response_chunks)
        span.update(output=response_text[:500] if response_text else None)

    tracing.flush()

    # Store episode
    async with async_session() as db:
        memory = MemoryMiddleware(db)
        await memory.store_episode(
            user_id, session_id, message, response_text,
            trace_id=request_id,
        )
|