3 Commits

Author SHA1 Message Date
Roberto
70c19d3064 chore(contextual): purge residual floating WsFrame defs + output_formatter branch
After M6.5 deletion of run_floating_stream and the frame dispatch,
WsFrameType.floating_request/floating_domain, WsFloatingRequest,
WsFloatingDomain, WsFloatingScope, WsDomain, and the StreamFormatter's
floating_domain branch were left as dead protocol surface. Remove them,
along with the corresponding test cases in test_schemas_v3.py and
test_output_formatter.py.
2026-05-15 18:56:29 +02:00
Roberto
886730b47e test(contextual): remove floating-specific tests
Replaced by tests/test_contextual_*.py in M3.
No dedicated test_floating_*.py files existed; floating test
functions were embedded in test_deep_agent.py and test_ws_unified.py
and have been removed from those files.
2026-05-15 18:53:08 +02:00
Roberto
052c7e3741 refactor(contextual): drop floating WS frame, runner, and prompt fallback
contextual_request + contextual_scope_update are the only WS
flows for ad-hoc contextual chat now. Floating system prompt
constant removed; Langfuse 'floating_system' is deleted in a
separate manual step. Also removes floating-agent LLM slot from
llm.py and the associated LLM_MODEL_FLOATING_AGENT setting entry.
2026-05-15 18:53:01 +02:00
10 changed files with 11 additions and 835 deletions

View File

@@ -44,7 +44,7 @@ from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs from app.core.agent_runner import trigger_pending_runs
from app.core.agent_session_buffer import session_buffer from app.core.agent_session_buffer import session_buffer
from app.core.brief_agent import run_home_brief, run_project_brief from app.core.brief_agent import run_home_brief, run_project_brief
from app.core.deep_agent import run_contextual_stream, run_floating_stream, run_home_stream, run_task_brief_research_stream from app.core.deep_agent import run_contextual_stream, run_home_stream, run_task_brief_research_stream
from app.core.output_formatter import extract_canvas_block from app.core.output_formatter import extract_canvas_block
from app.core.device_manager import device_manager from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware from app.core.memory_middleware import MemoryMiddleware
@@ -161,11 +161,6 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
_handle_home_request(websocket, user_id, frame) _handle_home_request(websocket, user_id, frame)
) )
elif frame_type == WsFrameType.floating_request:
asyncio.create_task(
_handle_floating_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.brief_request: elif frame_type == WsFrameType.brief_request:
asyncio.create_task( asyncio.create_task(
_handle_brief_request(websocket, user_id, frame) _handle_brief_request(websocket, user_id, frame)
@@ -301,76 +296,6 @@ async def _handle_home_request(
) )
async def _handle_floating_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
scope: dict = frame.get("scope", {})
logger.info(
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
user_id,
request_id,
session_id,
json.dumps(scope, ensure_ascii=True)[:200],
message[:200],
)
# ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"conversation_history": frame.get("conversation_history", []),
"scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
"format_prefs": frame.get("format_prefs"),
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_floating_stream(user_id, message, context)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
logger.error(
"device_ws: floating_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
finally:
clear_client_executor()
# ── Memory: store episode after response ──────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
)
logger.info(
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
len("".join(response_chunks)),
)
# ── v8 Contextual Sidebar Handlers ─────────────────────────────────── # ── v8 Contextual Sidebar Handlers ───────────────────────────────────

View File

@@ -23,9 +23,8 @@ class Settings(BaseSettings):
LLM_EMBED_MODEL: str = "text-embedding-3-small" LLM_EMBED_MODEL: str = "text-embedding-3-small"
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL. # Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
LLM_MODEL_CLASSIFIER: str = "" # _infer_floating_domain (intent routing) LLM_MODEL_CLASSIFIER: str = "" # classifier (intent routing, future use)
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream) LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner) LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner) LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs) LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)

View File

@@ -1,4 +1,4 @@
"""Single-agent runners for home and floating chat contexts.""" """Single-agent runners for home and contextual chat contexts."""
from __future__ import annotations from __future__ import annotations
@@ -7,7 +7,7 @@ import logging
import re import re
from datetime import date from datetime import date
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any, Literal from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool from langchain_core.tools import tool
@@ -29,9 +29,6 @@ logger = logging.getLogger(__name__)
MAX_HISTORY_TURNS = 20 MAX_HISTORY_TURNS = 20
FloatingDomainType = Literal["task", "timeline", "project", "node"]
FloatingDomainSection = Literal["task", "timeline", "note"]
# Mapping of core-memory language values to natural-language names for prompts. # Mapping of core-memory language values to natural-language names for prompts.
_LANGUAGE_NAMES: dict[str, str] = { _LANGUAGE_NAMES: dict[str, str] = {
"en": "English", "it": "Italian", "es": "Spanish", "en": "English", "it": "Italian", "es": "Spanish",
@@ -354,44 +351,6 @@ For "today" / "tomorrow" queries, prefer list_tasks_due_today / list_timelines_t
{request_context}\ {request_context}\
""" """
_FLOATING_SYSTEM_PROMPT = """\
You are adiuvAI's floating executive assistant.{user_identity}
You are pinned to a specific entity (task, timeline event, project, or note) and you stay strictly within that scope.
Be a proactive partner: anticipate the next useful action and close with a concrete suggestion or a clarifying question — but stay terse, one short paragraph at most.
# How you work
- Use tools before answering anything factual. Never guess.
- Stay in the floating scope (see Request context). If the user asks something outside scope, answer briefly and suggest opening the home assistant.
- Match the user's tone preference. Default to warm-but-direct.
- When the user asks to remember, forget, or update something, use memory tools.
# Filter discipline
- Never set the `assignee` filter on list_tasks/count_tasks unless the user explicitly names a person ("Marco's tasks") or refers to themselves ("my tasks", "assigned to me", "mine").
- The user's own name in the User profile block is for context only — it is NOT a default filter.
- When in doubt, omit `assignee` and return the global result.
# Output format
Plain text only. Do NOT output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed-id wrappers, and do NOT output <chart> blocks — those are for the home assistant.
# Date filtering
{date_context}
When filtering by date, take dueDateFrom / dueDateTo (ms epoch UTC) verbatim from the DATE CONTEXT boundary table above. Do NOT compute boundaries from now_ms yourself.
For specific dates not listed, compute local-midnight in the user timezone and convert to UTC ms.
# Language
{language_instruction}
# Known people & projects
{relational_memory}
# Behavioral hints
{proactive_hints}
# Request context
{request_context}\
"""
_CONTEXTUAL_SYSTEM_PROMPT = """You are adiuvAI's contextual assistant. The user is working inside the app and has opened a side chat anchored to a specific view ("current view"). Help them act on that view: recap, plan, create entities, answer questions. _CONTEXTUAL_SYSTEM_PROMPT = """You are adiuvAI's contextual assistant. The user is working inside the app and has opened a side chat anchored to a specific view ("current view"). Help them act on that view: recap, plan, create entities, answer questions.
Rules: Rules:
@@ -486,19 +445,6 @@ Stay terse — your principal is a busy executive.
{request_context}\ {request_context}\
""" """
_FLOATING_DOMAIN_CLASSIFIER_PROMPT = (
"You are a strict domain classifier for websocket floating requests. "
"Return ONLY a JSON object with keys: type, id, section. "
"Allowed type values: task, timeline, project, node. "
"Allowed section values: task, timeline, note, or null. "
"Rules: infer from user message intent first; do not blindly trust scope.type. "
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
"If project id is unknown but context.resolved_project_id exists, use it as id. "
"If id is unknown, use null. "
"No markdown, no prose, JSON only."
)
def _as_text(content: Any) -> str: def _as_text(content: Any) -> str:
if content is None: if content is None:
return "" return ""
@@ -727,70 +673,6 @@ def _normalize_tagged_list_lines(text: str, message: str) -> str:
return "\n".join(output_lines) return "\n".join(output_lines)
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
_FLOATING_EMPTY_FALLBACK = "No results found."
def _strip_floating_markup_fragment(text: str) -> str:
if not text:
return text
cleaned = _GENERIC_TAG_RE.sub("", text)
return _BRACKETED_ID_RE.sub("", cleaned)
def _strip_floating_markup(text: str) -> str:
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
if not text:
return text
cleaned = _strip_floating_markup_fragment(text)
# Collapse excessive spaces introduced by tag/id removal while preserving lines.
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
return "\n".join(line for line in lines if line)
def _fallback_from_raw_floating_text(raw_text: str) -> str:
fallback = _strip_floating_markup_fragment(raw_text or "")
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
return fallback or _FLOATING_EMPTY_FALLBACK
class _FloatingStreamSanitizer:
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
def __init__(self) -> None:
self._pending = ""
@staticmethod
def _split_safe_boundary(text: str) -> tuple[str, str]:
boundary = len(text)
last_lt = text.rfind("<")
if last_lt != -1 and ">" not in text[last_lt:]:
boundary = min(boundary, last_lt)
last_lb = text.rfind("[")
if last_lb != -1 and "]" not in text[last_lb:]:
boundary = min(boundary, last_lb)
if boundary == len(text):
return text, ""
return text[:boundary], text[boundary:]
def feed(self, chunk: str) -> str:
combined = f"{self._pending}{chunk}"
safe_text, self._pending = self._split_safe_boundary(combined)
return _strip_floating_markup_fragment(safe_text)
def finalize(self) -> str:
# Drop dangling unfinished wrappers at the very end.
tail = re.sub(r"<[^>\n]*$", "", self._pending)
tail = re.sub(r"\[[^\]\n]*$", "", tail)
self._pending = ""
return _strip_floating_markup_fragment(tail)
def _normalize_memory_label(path_or_label: str) -> str: def _normalize_memory_label(path_or_label: str) -> str:
value = path_or_label.strip() value = path_or_label.strip()
if value.startswith("/memories/"): if value.startswith("/memories/"):
@@ -971,168 +853,6 @@ def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
return [*_all_tools(), *_memory_tools(user_id, trace_id)] return [*_all_tools(), *_memory_tools(user_id, trace_id)]
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
lowered = message.lower()
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
return "timeline"
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
return "task"
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
return "note"
return None
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
type_raw = str(payload.get("type") or "").strip().lower()
domain_type: FloatingDomainType = "task"
if type_raw in {"task", "timeline", "project", "node"}:
domain_type = type_raw
id_value = payload.get("id")
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
if domain_type == "project" and not domain_id:
domain_id = fallback_id
section_raw = payload.get("section")
section: FloatingDomainSection | None = None
if isinstance(section_raw, str):
section_candidate = section_raw.strip().lower()
if section_candidate in {"task", "timeline", "note"}:
section = section_candidate
if domain_type != "project":
section = None
return {
"type": domain_type,
"id": domain_id,
"section": section,
}
def _parse_json_object(text: str) -> dict[str, Any] | None:
raw = text.strip()
if not raw:
return None
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", raw, re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
except json.JSONDecodeError:
return None
return parsed if isinstance(parsed, dict) else None
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
section = _detect_domain_section(message)
scope = context.get("scope") if isinstance(context, dict) else None
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
if isinstance(scope, dict):
scope_type = str(scope.get("type") or "").strip().lower()
scope_id = scope.get("id")
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
if scope_type in {"task", "tasks"}:
return {"type": "task", "id": scope_id_value, "section": None}
if scope_type in {"project", "projects"}:
project_scope_id = scope_id_value or project_id
return {
"type": "project",
"id": project_scope_id,
"section": section,
}
if scope_type in {"note", "notes"}:
return {
"type": "node",
"id": scope_id_value,
"section": None,
}
if scope_type in {"timeline", "timelines"}:
return {"type": "timeline", "id": scope_id_value, "section": None}
lowered = message.lower()
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
return {
"type": "project",
"id": project_id,
"section": section,
}
if section == "timeline":
return {"type": "timeline", "id": None, "section": None}
if section == "note":
return {"type": "node", "id": None, "section": None}
return {"type": "task", "id": None, "section": None}
async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]:
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
classifier_context = {
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
"resolved_project_id": project_id,
}
try:
llm = get_agent_llm("classifier")
classifier_messages = [
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_PROMPT),
HumanMessage(
content=(
f"Message:\n{message}\n\n"
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
)
),
]
lf = get_langfuse()
_, classifier_prompt_obj = get_prompt_or_fallback(
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_PROMPT
)
# Extract user/session from context for Langfuse attribution
_debug = context.get("_debug") if isinstance(context, dict) else None
_lf_user = (_debug or {}).get("user_id") if isinstance(_debug, dict) else None
_lf_session = (_debug or {}).get("session_id") if isinstance(_debug, dict) else None
with langfuse_context(user_id=_lf_user, session_id=_lf_session):
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="floating-classifier",
model=model_for_agent("classifier"),
prompt=classifier_prompt_obj,
input=classifier_messages,
) as gen:
response = await llm.ainvoke(classifier_messages)
gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
else:
response = await llm.ainvoke(classifier_messages)
parsed = _parse_json_object(_as_text(response.content))
if parsed is not None:
domain = _normalize_domain_payload(parsed, project_id)
logger.info(
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
domain.get("type"),
domain.get("id"),
domain.get("section"),
)
return domain
logger.warning("deep_agent: floating_domain classifier returned non-json output")
except Exception as exc:
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
return _infer_floating_domain_rule_based(message, context)
def _history_to_messages(history: list[dict[str, str]] | None) -> list[Any]: def _history_to_messages(history: list[dict[str, str]] | None) -> list[Any]:
if not history: if not history:
return [] return []
@@ -1461,25 +1181,6 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
return _normalize_tagged_list_lines(response, message) return _normalize_tagged_list_lines(response, message)
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
response = await _run_single_agent(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="floating-agent",
conversation_history=context.get("conversation_history"),
)
sanitized = _strip_floating_markup(response)
if not sanitized and response:
sanitized = _fallback_from_raw_floating_text(response)
return sanitized, domain
async def run_home_stream( async def run_home_stream(
user_id: str, user_id: str,
message: str, message: str,
@@ -1526,71 +1227,6 @@ async def run_home_stream(
yield "token", normalized yield "token", normalized
async def run_floating_stream(
user_id: str,
message: str,
context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
yield "floating_domain", domain
brief_mode: bool = bool(context.get("brief_mode"))
briefing_context_text: str = str(context.get("briefing_context") or "").strip()
if brief_mode and briefing_context_text:
# Stage 2: inject briefing as ground truth context.
# Pre-substitute {briefing_context} in the template (handles both Langfuse {{}} and fallback {})
# before compile_prompt sees the remaining standard variables.
template, langfuse_prompt = get_prompt_or_fallback(
"task_brief_followup_system",
_TASK_BRIEF_FOLLOWUP_SYSTEM_PROMPT,
)
system_prompt = compile_prompt(
template, langfuse_prompt,
date_context=_datetime_context_injection(prepared_context).strip(),
language_instruction=_language_instruction(prepared_context).strip(),
user_identity=_user_identity_injection(prepared_context).strip(),
relational_memory=_relational_memory_injection(prepared_context).strip(),
proactive_hints=_proactive_hints_injection(prepared_context).strip(),
request_context=_request_context_block(prepared_context),
briefing_context=briefing_context_text,
)
else:
system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
sanitizer = _FloatingStreamSanitizer()
emitted_sanitized = False
raw_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="floating-agent",
conversation_history=context.get("conversation_history"),
):
event_type, data = event
if event_type != "token":
yield event
continue
raw_chunk = str(data or "")
raw_chunks.append(raw_chunk)
sanitized_chunk = sanitizer.feed(raw_chunk)
if sanitized_chunk:
emitted_sanitized = True
yield "token", sanitized_chunk
tail = sanitizer.finalize()
if tail:
emitted_sanitized = True
yield "token", tail
if not emitted_sanitized and raw_chunks:
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
async def run_contextual_stream( async def run_contextual_stream(
user_id: str, user_id: str,
message: str, message: str,
@@ -1599,8 +1235,8 @@ async def run_contextual_stream(
) -> AsyncGenerator[tuple[str, Any], None]: ) -> AsyncGenerator[tuple[str, Any], None]:
"""Run the contextual agent for a single user turn. """Run the contextual agent for a single user turn.
Mirrors run_floating_stream's plumbing but injects the rendered scope Injects the rendered scope block into the system prompt and exposes
block into the system prompt and exposes the contextual tool set. the contextual tool set.
Note-edit tools (propose_note_edit) are intentionally excluded. Note-edit tools (propose_note_edit) are intentionally excluded.
*context contract*: callers MUST include ``context["_debug"]["session_id"]`` *context contract*: callers MUST include ``context["_debug"]["session_id"]``

View File

@@ -103,7 +103,6 @@ def get_llm(
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = { _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL, "classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL, "home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL, "unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL, "cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL, "brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,

View File

@@ -6,7 +6,7 @@ import re
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any from typing import Any
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
# Matches <canvas kind="...">...</canvas> blocks (single-line or multiline). # Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
_CANVAS_BLOCK_RE = re.compile( _CANVAS_BLOCK_RE = re.compile(
@@ -31,7 +31,7 @@ def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
visible = visible.strip() visible = visible.strip()
return visible, canvas_content, canvas_kind return visible, canvas_content, canvas_kind
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain WsFrame = WsStreamStart | WsStreamText | WsStreamEnd
class StreamFormatter: class StreamFormatter:
@@ -47,14 +47,6 @@ class StreamFormatter:
started = False started = False
async for event_type, data in event_stream: async for event_type, data in event_stream:
if event_type == "floating_domain":
if isinstance(data, dict):
yield WsFloatingDomain(
request_id=self.request_id,
domain=data,
)
continue
if event_type != "token": if event_type != "token":
continue continue

View File

@@ -73,11 +73,9 @@ class WsFrameType(str, Enum):
device_hello = "device_hello" device_hello = "device_hello"
# ── v3 frame types ───────────────────────────────────────────────── # ── v3 frame types ─────────────────────────────────────────────────
home_request = "home_request" home_request = "home_request"
floating_request = "floating_request"
stream_start = "stream_start" stream_start = "stream_start"
stream_text = "stream_text" stream_text = "stream_text"
stream_end = "stream_end" stream_end = "stream_end"
floating_domain = "floating_domain"
data_request = "data_request" data_request = "data_request"
data_response = "data_response" data_response = "data_response"
mutation = "mutation" mutation = "mutation"
@@ -165,13 +163,6 @@ class FormatPrefsModel(BaseModel):
now_iso: str = "" now_iso: str = ""
class WsFloatingScope(BaseModel):
"""Scope for a floating request — narrows the agent to a specific entity."""
type: Literal["task", "project", "note", "timeline"]
id: str | None = None
class WsHomeRequest(BaseModel): class WsHomeRequest(BaseModel):
"""Client → Server: Home chat message.""" """Client → Server: Home chat message."""
@@ -181,15 +172,6 @@ class WsHomeRequest(BaseModel):
format_prefs: FormatPrefsModel | None = None format_prefs: FormatPrefsModel | None = None
class WsFloatingRequest(BaseModel):
"""Client → Server: Floating chat message scoped to an entity."""
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
message: str
scope: WsFloatingScope
format_prefs: FormatPrefsModel | None = None
class WsBriefRequest(BaseModel): class WsBriefRequest(BaseModel):
"""Client → Server: Request a plain-text brief (home or project).""" """Client → Server: Request a plain-text brief (home or project)."""
@@ -225,22 +207,6 @@ class WsStreamEnd(BaseModel):
mutations: list[dict[str, Any]] | None = None mutations: list[dict[str, Any]] | None = None
class WsDomain(BaseModel):
"""Structured floating domain payload for UI routing decisions."""
type: Literal["task", "timeline", "project", "node"]
id: str | None = None
section: Literal["task", "timeline", "note"] | None = None
class WsFloatingDomain(BaseModel):
"""Server → Client: domain determined for a floating request."""
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
request_id: str
domain: WsDomain
# ── Agent Config V2 ─────────────────────────────────────────────────── # ── Agent Config V2 ───────────────────────────────────────────────────

View File

@@ -12,11 +12,8 @@ from langchain_core.messages import AIMessage, ToolMessage
from app.core.deep_agent import ( from app.core.deep_agent import (
_build_system_prompt, _build_system_prompt,
_datetime_context_injection, _datetime_context_injection,
_infer_floating_domain,
_normalize_tagged_list_lines, _normalize_tagged_list_lines,
_request_context_block, _request_context_block,
run_floating,
run_floating_stream,
run_home, run_home,
) )
@@ -75,57 +72,6 @@ async def test_run_home_uses_mocked_tool_result():
assert "Mock Task" in out assert "Mock Task" in out
@pytest.mark.asyncio
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
fake_llm = _FakeLLM()
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
events = []
async for event in run_floating_stream(
"user-1",
"show me timeline updates",
{"scope": {"type": "timeline", "id": "tl-1"}},
):
events.append(event)
assert events[0] == (
"floating_domain",
{"type": "timeline", "id": "tl-1", "section": None},
)
# _run_single_agent_stream uses ainvoke (not astream); the final token is
# the second LLM response which echoes the tool result.
token_events = [e for e in events if e[0] == "token"]
assert token_events, "Expected at least one token event"
combined = "".join(str(e[1]) for e in token_events)
assert "Mock Task" in combined
@pytest.mark.asyncio
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
class _ClassifierOnlyLLM:
async def ainvoke(self, _messages):
return AIMessage(
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
)
with patch("app.core.deep_agent.get_agent_llm", return_value=_ClassifierOnlyLLM()):
domain = await _infer_floating_domain(
"Quali sono i miei task per il progetto X",
{
"scope": {"type": "timeline"},
"resolved_project_id": "213213-312321-312312-421321",
},
)
assert domain == {
"type": "project",
"id": "213213-312321-312312-421321",
"section": "task",
}
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines(): def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
raw = ( raw = (
"Certo!\n\n" "Certo!\n\n"
@@ -162,139 +108,6 @@ def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_
assert "<timeline>[tl-future]</timeline>" not in out assert "<timeline>[tl-future]</timeline>" not in out
@pytest.mark.asyncio
async def test_run_floating_strips_xml_like_tags_from_final_text():
fake_llm = _FakeLLM()
async def _fake_run_single_agent(**_kwargs):
return (
"Hai 1 task:\\n"
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
)
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
):
text, _domain = await run_floating(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
)
assert "<task>" not in text
assert "</task>" not in text
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
@pytest.mark.asyncio
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
fake_llm = _FakeLLM()
async def _fake_stream(**_kwargs):
yield "token", "Hai 1 task:\\n"
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
):
events = []
async for event in run_floating_stream(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
):
events.append(event)
token_events = [str(data) for event_type, data in events if event_type == "token"]
combined = "".join(token_events)
assert "<task>" not in combined
assert "</task>" not in combined
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
@pytest.mark.asyncio
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
class _NoChunkLLM:
def __init__(self) -> None:
self.calls = 0
def bind_tools(self, _tools):
return self
async def ainvoke(self, _messages):
self.calls += 1
if self.calls == 1:
return AIMessage(
content="",
tool_calls=[
{
"id": "call-1",
"name": "list_tasks",
"args": {},
}
],
)
return AIMessage(content="No notes found.")
async def astream(self, _messages):
if False:
yield None
with patch("app.core.deep_agent.get_agent_llm", return_value=_NoChunkLLM()), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
events = []
async for event in run_floating_stream(
"user-1",
"quali sono le note?",
{"scope": {"type": "note"}},
):
events.append(event)
assert events[0][0] == "floating_domain"
assert ("token", "No notes found.") in events
@pytest.mark.asyncio
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
fake_llm = _FakeLLM()
async def _fake_run_single_agent(**_kwargs):
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
):
text, _domain = await run_floating(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
)
assert text == "No results found."
@pytest.mark.asyncio
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
fake_llm = _FakeLLM()
async def _fake_stream(**_kwargs):
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
):
events = []
async for event in run_floating_stream(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
):
events.append(event)
assert ("token", "No results found.") in events
# ── _datetime_context_injection ──────────────────────────────────────────────── # ── _datetime_context_injection ────────────────────────────────────────────────
def _fp(tz: str, now_iso: str) -> dict: def _fp(tz: str, now_iso: str) -> dict:

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import pytest import pytest
from app.core.output_formatter import StreamFormatter from app.core.output_formatter import StreamFormatter
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
async def _stream(*events: tuple[str, object]): async def _stream(*events: tuple[str, object]):
@@ -36,29 +36,6 @@ async def test_stream_formatter_text_stream() -> None:
assert isinstance(frames[-1], WsStreamEnd) assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio
async def test_stream_formatter_floating_domain_first() -> None:
formatter = StreamFormatter(request_id="req-2")
frames = await _collect(
formatter,
_stream(
(
"floating_domain",
{"type": "node", "id": "n-1", "section": None},
),
("token", "Summary"),
),
)
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain.type == "node"
assert frames[0].domain.id == "n-1"
assert isinstance(frames[1], WsStreamStart)
assert isinstance(frames[2], WsStreamText)
assert frames[2].chunk == "Summary"
assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stream_formatter_ignores_unknown_events() -> None: async def test_stream_formatter_ignores_unknown_events() -> None:
formatter = StreamFormatter(request_id="req-3") formatter = StreamFormatter(request_id="req-3")

View File

@@ -4,12 +4,8 @@ import pytest
from pydantic import ValidationError from pydantic import ValidationError
from app.schemas import ( from app.schemas import (
WsDomain,
WsFrameType, WsFrameType,
WsHomeRequest, WsHomeRequest,
WsFloatingDomain,
WsFloatingRequest,
WsFloatingScope,
WsStreamEnd, WsStreamEnd,
WsStreamStart, WsStreamStart,
WsStreamText, WsStreamText,
@@ -22,11 +18,9 @@ from app.schemas import (
def test_v3_frame_types_exist(): def test_v3_frame_types_exist():
v3_types = [ v3_types = [
"home_request", "home_request",
"floating_request",
"stream_start", "stream_start",
"stream_text", "stream_text",
"stream_end", "stream_end",
"floating_domain",
"data_request", "data_request",
"data_response", "data_response",
"mutation", "mutation",
@@ -86,51 +80,6 @@ def test_home_request_requires_message():
WsHomeRequest.model_validate({"type": "home_request"}) WsHomeRequest.model_validate({"type": "home_request"})
# ── WsFloatingRequest ────────────────────────────────────────────────────
def test_floating_request_basic():
frame = WsFloatingRequest(
message="Summarise",
scope=WsFloatingScope(type="task", id="task-123"),
)
assert frame.type == WsFrameType.floating_request
assert frame.scope.type == "task"
assert frame.scope.id == "task-123"
def test_floating_request_scope_without_id():
frame = WsFloatingRequest(
message="Show all",
scope=WsFloatingScope(type="project"),
)
assert frame.scope.id is None
def test_floating_request_serializes():
frame = WsFloatingRequest(
message="Test",
scope=WsFloatingScope(type="note", id="n-1"),
)
data = frame.model_dump()
assert data["type"] == "floating_request"
assert data["scope"]["type"] == "note"
assert data["scope"]["id"] == "n-1"
def test_floating_request_invalid_scope_type():
with pytest.raises(ValidationError):
WsFloatingRequest(
message="X",
scope=WsFloatingScope(type="unknown"), # type: ignore[arg-type]
)
def test_floating_request_requires_scope():
with pytest.raises(ValidationError):
WsFloatingRequest.model_validate({"type": "floating_request", "message": "X"})
# ── WsStreamStart ───────────────────────────────────────────────────── # ── WsStreamStart ─────────────────────────────────────────────────────
@@ -189,51 +138,3 @@ def test_stream_end_deserializes():
assert frame.request_id == "r3" assert frame.request_id == "r3"
# ── WsFloatingDomain ─────────────────────────────────────────────────────
def test_floating_domain_tasks():
frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task"))
assert frame.type == WsFrameType.floating_domain
assert frame.domain.type == "task"
def test_floating_domain_valid_domains():
frame = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"),
)
assert frame.domain.type == "project"
assert frame.domain.id == "213213-312321-312312-421321"
assert frame.domain.section == "task"
def test_floating_domain_object_valid():
frame = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="project", id="p1", section="task"),
)
assert frame.domain.type == "project"
def test_floating_domain_serializes():
d = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="timeline"),
).model_dump()
assert d == {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "timeline", "id": None, "section": None},
}
def test_floating_domain_deserializes():
raw = {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "node", "id": "n-1", "section": None},
}
frame = WsFloatingDomain.model_validate(raw)
assert frame.domain.type == "node"
assert frame.domain.id == "n-1"

View File

@@ -1,6 +1,6 @@
"""Integration tests for the unified WebSocket handler (Step 5). """Integration tests for the unified WebSocket handler (Step 5).
Tests the device WS endpoint with home_request and floating_request frames, Tests the device WS endpoint with home_request frames,
verifying that the correct v3 frame sequence is returned. verifying that the correct v3 frame sequence is returned.
LLM calls are mocked to avoid network dependency. LLM calls are mocked to avoid network dependency.
@@ -34,7 +34,7 @@ def _override_db(db_session):
def _recv_until_end(ws, max_frames: int = 20) -> list[dict]: def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
"""Receive frames until stream_end (or stream_end inside floating flow), or max_frames.""" """Receive frames until stream_end or max_frames."""
frames = [] frames = []
for _ in range(max_frames): for _ in range(max_frames):
raw = ws.receive_text() raw = ws.receive_text()
@@ -49,11 +49,6 @@ async def _mock_home_stream(user_id, message, context):
yield "token", "Hello" yield "token", "Hello"
async def _mock_floating_stream(user_id, message, context):
yield "floating_domain", {"type": "task", "id": None, "section": None}
yield "token", "Here is a summary"
# ── tests ───────────────────────────────────────────────────────────────────── # ── tests ─────────────────────────────────────────────────────────────────────
def test_home_request_produces_stream_frames(client): def test_home_request_produces_stream_frames(client):
@@ -79,33 +74,6 @@ def test_home_request_produces_stream_frames(client):
assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end) assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end)
def test_floating_request_produces_domain_frame(client):
"""floating_request → floating_domain first, then stream_text*, stream_end."""
token = make_jwt("power", user_id=USER_ID)
with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-2", "agent_ids": []
}))
ws.send_text(json.dumps({
"type": "floating_request",
"request_id": "p1",
"message": "Summarize this task",
"scope": {"type": "task", "id": "task-123"},
}))
frames = _recv_until_end(ws)
types = [f["type"] for f in frames]
assert WsFrameType.floating_domain in types
assert WsFrameType.stream_end in types
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
assert domain_frame["domain"]["type"] == "task"
assert domain_frame["request_id"] == "p1"
def test_home_request_request_id_propagated(client): def test_home_request_request_id_propagated(client):
"""request_id in home_request is echoed in all response frames.""" """request_id in home_request is echoed in all response frames."""
token = make_jwt("power", user_id=USER_ID) token = make_jwt("power", user_id=USER_ID)