fix floating stream empty responses with sanitizer-safe fallbacks
This commit is contained in:
@@ -42,6 +42,7 @@ _HOME_SINGLE_AGENT_SYSTEM = (
|
|||||||
_FLOATING_SINGLE_AGENT_SYSTEM = (
|
_FLOATING_SINGLE_AGENT_SYSTEM = (
|
||||||
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
"Stay focused on the floating scope in context.scope and answer concisely. "
|
"Stay focused on the floating scope in context.scope and answer concisely. "
|
||||||
|
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
|
||||||
"Always use tools for factual data retrieval before answering. "
|
"Always use tools for factual data retrieval before answering. "
|
||||||
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||||
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||||
@@ -221,6 +222,70 @@ def _normalize_tagged_list_lines(text: str, message: str) -> str:
|
|||||||
return "\n".join(output_lines)
|
return "\n".join(output_lines)
|
||||||
|
|
||||||
|
|
||||||
|
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
|
||||||
|
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
|
||||||
|
_FLOATING_EMPTY_FALLBACK = "No results found."
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_floating_markup_fragment(text: str) -> str:
    """Remove wrapper tags and bracketed ids from ``text``; whitespace is left as is.

    Falsy input (an empty string or ``None``) is returned unchanged.
    """
    if text:
        # Tags go first, then the bracketed ids.
        without_tags = _GENERIC_TAG_RE.sub("", text)
        return _BRACKETED_ID_RE.sub("", without_tags)
    return text
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_floating_markup(text: str) -> str:
    """Return ``text`` as plain text with no XML-like tag wrappers or bracketed ids.

    After the markup is removed, each line has runs of spaces/tabs collapsed to
    a single space and is trimmed; lines that end up empty are dropped. Falsy
    input is returned unchanged.
    """
    if not text:
        return text

    kept_lines = []
    for raw_line in _strip_floating_markup_fragment(text).splitlines():
        # Removing a tag or id can leave doubled spaces behind; squeeze them.
        squeezed = re.sub(r"[ \t]{2,}", " ", raw_line).strip()
        if squeezed:
            kept_lines.append(squeezed)
    return "\n".join(kept_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _fallback_from_raw_floating_text(raw_text: str) -> str:
    """Return a non-empty reply derived from ``raw_text``.

    The raw text (``None`` is treated as empty) has its markup stripped, runs of
    spaces/tabs collapsed and its ends trimmed. If nothing is left, the generic
    empty-result message is returned instead.
    """
    stripped = _strip_floating_markup_fragment(raw_text or "")
    collapsed = re.sub(r"[ \t]{2,}", " ", stripped).strip()
    if collapsed:
        return collapsed
    return _FLOATING_EMPTY_FALLBACK
|
||||||
|
|
||||||
|
|
||||||
|
class _FloatingStreamSanitizer:
|
||||||
|
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._pending = ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_safe_boundary(text: str) -> tuple[str, str]:
|
||||||
|
boundary = len(text)
|
||||||
|
|
||||||
|
last_lt = text.rfind("<")
|
||||||
|
if last_lt != -1 and ">" not in text[last_lt:]:
|
||||||
|
boundary = min(boundary, last_lt)
|
||||||
|
|
||||||
|
last_lb = text.rfind("[")
|
||||||
|
if last_lb != -1 and "]" not in text[last_lb:]:
|
||||||
|
boundary = min(boundary, last_lb)
|
||||||
|
|
||||||
|
if boundary == len(text):
|
||||||
|
return text, ""
|
||||||
|
return text[:boundary], text[boundary:]
|
||||||
|
|
||||||
|
def feed(self, chunk: str) -> str:
|
||||||
|
combined = f"{self._pending}{chunk}"
|
||||||
|
safe_text, self._pending = self._split_safe_boundary(combined)
|
||||||
|
return _strip_floating_markup_fragment(safe_text)
|
||||||
|
|
||||||
|
def finalize(self) -> str:
|
||||||
|
# Drop dangling unfinished wrappers at the very end.
|
||||||
|
tail = re.sub(r"<[^>\n]*$", "", self._pending)
|
||||||
|
tail = re.sub(r"\[[^\]\n]*$", "", tail)
|
||||||
|
self._pending = ""
|
||||||
|
return _strip_floating_markup_fragment(tail)
|
||||||
|
|
||||||
|
|
||||||
def _normalize_memory_label(path_or_label: str) -> str:
|
def _normalize_memory_label(path_or_label: str) -> str:
|
||||||
value = path_or_label.strip()
|
value = path_or_label.strip()
|
||||||
if value.startswith("/memories/"):
|
if value.startswith("/memories/"):
|
||||||
@@ -618,11 +683,20 @@ async def _run_single_agent_stream(
|
|||||||
messages.append(response)
|
messages.append(response)
|
||||||
|
|
||||||
if not response.tool_calls:
|
if not response.tool_calls:
|
||||||
|
emitted_any = False
|
||||||
async for chunk in llm.astream(messages):
|
async for chunk in llm.astream(messages):
|
||||||
token = _as_text(getattr(chunk, "content", ""))
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
if token:
|
if token:
|
||||||
streamed_chars += len(token)
|
streamed_chars += len(token)
|
||||||
|
emitted_any = True
|
||||||
yield "token", token
|
yield "token", token
|
||||||
|
|
||||||
|
# Some providers return final text in `response.content` but stream no chunks.
|
||||||
|
if not emitted_any:
|
||||||
|
fallback_text = _as_text(response.content)
|
||||||
|
if fallback_text:
|
||||||
|
streamed_chars += len(fallback_text)
|
||||||
|
yield "token", fallback_text
|
||||||
logger.info(
|
logger.info(
|
||||||
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
||||||
trace_id or "-",
|
trace_id or "-",
|
||||||
@@ -696,7 +770,10 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
|
|||||||
message=message,
|
message=message,
|
||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
)
|
)
|
||||||
return response, domain
|
sanitized = _strip_floating_markup(response)
|
||||||
|
if not sanitized and response:
|
||||||
|
sanitized = _fallback_from_raw_floating_text(response)
|
||||||
|
return sanitized, domain
|
||||||
|
|
||||||
|
|
||||||
async def run_home_stream(
|
async def run_home_stream(
|
||||||
@@ -732,13 +809,34 @@ async def run_floating_stream(
|
|||||||
domain = await _infer_floating_domain(message, prepared_context)
|
domain = await _infer_floating_domain(message, prepared_context)
|
||||||
yield "floating_domain", domain
|
yield "floating_domain", domain
|
||||||
|
|
||||||
|
sanitizer = _FloatingStreamSanitizer()
|
||||||
|
emitted_sanitized = False
|
||||||
|
raw_chunks: list[str] = []
|
||||||
async for event in _run_single_agent_stream(
|
async for event in _run_single_agent_stream(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
||||||
message=message,
|
message=message,
|
||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
):
|
):
|
||||||
yield event
|
event_type, data = event
|
||||||
|
if event_type != "token":
|
||||||
|
yield event
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_chunk = str(data or "")
|
||||||
|
raw_chunks.append(raw_chunk)
|
||||||
|
sanitized_chunk = sanitizer.feed(raw_chunk)
|
||||||
|
if sanitized_chunk:
|
||||||
|
emitted_sanitized = True
|
||||||
|
yield "token", sanitized_chunk
|
||||||
|
|
||||||
|
tail = sanitizer.finalize()
|
||||||
|
if tail:
|
||||||
|
emitted_sanitized = True
|
||||||
|
yield "token", tail
|
||||||
|
|
||||||
|
if not emitted_sanitized and raw_chunks:
|
||||||
|
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
|
||||||
|
|
||||||
|
|
||||||
async def update_core_memory(user_id: str, key: str, value: str) -> None:
|
async def update_core_memory(user_id: str, key: str, value: str) -> None:
|
||||||
|
|||||||
@@ -9,7 +9,13 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
from app.core.deep_agent import _infer_floating_domain, _normalize_tagged_list_lines, run_floating_stream, run_home
|
from app.core.deep_agent import (
|
||||||
|
_infer_floating_domain,
|
||||||
|
_normalize_tagged_list_lines,
|
||||||
|
run_floating,
|
||||||
|
run_floating_stream,
|
||||||
|
run_home,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class _FakeTool:
|
class _FakeTool:
|
||||||
@@ -147,3 +153,136 @@ def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_
|
|||||||
assert "<timeline>[tl-next]</timeline>" in out
|
assert "<timeline>[tl-next]</timeline>" in out
|
||||||
assert "<timeline>[tl-old]</timeline>" not in out
|
assert "<timeline>[tl-old]</timeline>" not in out
|
||||||
assert "<timeline>[tl-future]</timeline>" not in out
|
assert "<timeline>[tl-future]</timeline>" not in out
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_floating_strips_xml_like_tags_from_final_text():
    """run_floating must not return wrapper tags or bracketed ids from the agent reply."""
    fake_llm = _FakeLLM()
    raw_reply = (
        "Hai 1 task:\\n"
        "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
    )

    async def _fake_run_single_agent(**_kwargs):
        return raw_reply

    llm_patch = patch("app.core.deep_agent.get_llm", return_value=fake_llm)
    agent_patch = patch(
        "app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
    )
    with llm_patch:
        with agent_patch:
            text, _domain = await run_floating(
                "user-1",
                "quali task ho?",
                {"scope": {"type": "task"}},
            )

    # None of the markup pieces may survive in the final text.
    assert "<task>" not in text
    assert "</task>" not in text
    assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
    """Token events from run_floating_stream must carry no wrapper tags or bracketed ids."""
    fake_llm = _FakeLLM()

    async def _fake_stream(**_kwargs):
        yield "token", "Hai 1 task:\\n"
        yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"

    llm_patch = patch("app.core.deep_agent.get_llm", return_value=fake_llm)
    stream_patch = patch(
        "app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
    )
    with llm_patch:
        with stream_patch:
            events = []
            async for event in run_floating_stream(
                "user-1",
                "quali task ho?",
                {"scope": {"type": "task"}},
            ):
                events.append(event)

    # Join every token payload and inspect the text as the user would see it.
    combined = "".join(str(payload) for kind, payload in events if kind == "token")
    assert "<task>" not in combined
    assert "</task>" not in combined
    assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
    """When streaming yields no chunks, the final response content is emitted as a token."""

    class _NoChunkLLM:
        """LLM stub: first invoke requests a tool, later invokes answer; astream yields nothing."""

        def __init__(self) -> None:
            self.calls = 0

        def bind_tools(self, _tools):
            return self

        async def ainvoke(self, _messages):
            self.calls += 1
            if self.calls != 1:
                return AIMessage(content="No notes found.")
            tool_call = {
                "id": "call-1",
                "name": "list_tasks",
                "args": {},
            }
            return AIMessage(content="", tool_calls=[tool_call])

        async def astream(self, _messages):
            # The unreachable yield only serves to make this an async generator.
            if False:
                yield None

    llm_patch = patch("app.core.deep_agent.get_llm", return_value=_NoChunkLLM())
    tools_patch = patch("app.core.deep_agent._all_tools", return_value=[_FakeTool()])
    with llm_patch:
        with tools_patch:
            events = []
            async for event in run_floating_stream(
                "user-1",
                "quali sono le note?",
                {"scope": {"type": "note"}},
            ):
                events.append(event)

    first_kind = events[0][0]
    assert first_kind == "floating_domain"
    assert ("token", "No notes found.") in events
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
    """A reply made only of markup must be replaced by the generic fallback message."""
    fake_llm = _FakeLLM()
    markup_only_reply = "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"

    async def _fake_run_single_agent(**_kwargs):
        return markup_only_reply

    llm_patch = patch("app.core.deep_agent.get_llm", return_value=fake_llm)
    agent_patch = patch(
        "app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
    )
    with llm_patch:
        with agent_patch:
            text, _domain = await run_floating(
                "user-1",
                "quali task ho?",
                {"scope": {"type": "task"}},
            )

    assert text == "No results found."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
    """A stream made only of markup must end with the generic fallback token."""
    fake_llm = _FakeLLM()

    async def _fake_stream(**_kwargs):
        yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"

    llm_patch = patch("app.core.deep_agent.get_llm", return_value=fake_llm)
    stream_patch = patch(
        "app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
    )
    with llm_patch:
        with stream_patch:
            events = []
            async for event in run_floating_stream(
                "user-1",
                "quali task ho?",
                {"scope": {"type": "task"}},
            ):
                events.append(event)

    expected_event = ("token", "No results found.")
    assert expected_event in events
|
||||||
|
|||||||
Reference in New Issue
Block a user