10 Commits

Author SHA1 Message Date
Roberto
70c19d3064 chore(contextual): purge residual floating WsFrame defs + output_formatter branch
After M6.5 deletion of run_floating_stream and the frame dispatch,
WsFrameType.floating_request/floating_domain, WsFloatingRequest,
WsFloatingDomain, WsFloatingScope, WsDomain, and the StreamFormatter's
floating_domain branch were left as dead protocol surface. Remove them,
along with the corresponding test cases in test_schemas_v3.py and
test_output_formatter.py.
2026-05-15 18:56:29 +02:00
Roberto
886730b47e test(contextual): remove floating-specific tests
Replaced by tests/test_contextual_*.py in M3.
No dedicated test_floating_*.py files existed; floating test
functions were embedded in test_deep_agent.py and test_ws_unified.py
and have been removed from those files.
2026-05-15 18:53:08 +02:00
Roberto
052c7e3741 refactor(contextual): drop floating WS frame, runner, and prompt fallback
contextual_request + contextual_scope_update are the only WS
flows for ad-hoc contextual chat now. Floating system prompt
constant removed; Langfuse 'floating_system' is deleted in a
separate manual step. Also removes floating-agent LLM slot from
llm.py and the associated LLM_MODEL_FLOATING_AGENT setting entry.
2026-05-15 18:53:01 +02:00
Roberto
d63fd5f3b9 fix(contextual): narrow tool palette + forbid legacy read tools
Smoke trace 0b46841484ba7d024ed9f8d5ac8b1df0 showed the agent
defaulting to list_projects + get_project for a 'summarize
project Nexus' query, returning a shallow row without aiSummary
or tasks/notes. The legacy read tools were exposed via
*PROJECT_TOOLS / *TASK_TOOLS spreading.

Now _contextual_tools exposes exactly:
- get_page_details (sole read; supports per-entity + list views)
- create_task, update_task
- create_note
- create_timeline

Prompt rule 2 explicitly forbids the legacy reads, and the test
asserts they are excluded from the palette.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-15 18:23:55 +02:00
Roberto
5e42b2abb1 fix(contextual): inject date_context + language in run_contextual_stream
Use _build_system_prompt helper so the contextual agent gets the
same system-prompt slots as home/floating runners — most importantly
{date_context} so the agent can reason about due dates when
creating/updating tasks.

Also makes the session_id contract on run_contextual_stream explicit
(was reading via context['_debug']) and tightens the tool-list test.
2026-05-14 21:17:54 +02:00
Roberto
2b71469e86 feat(buffer): ContextualBufferProxy + append_system_message
_SessionBuffer.append_system_message(user_id, session_id, text) injects a
synthetic SystemMessage into the named session slot (creating it if absent).

ContextualBufferProxy closes over user_id + session_id so call sites need
only call proxy.append_system_message(text).

get_session_buffer(user_id, session_id, channel) in device_ws returns a
ContextualBufferProxy, keeping the test-patchable function signature intact.
2026-05-14 21:11:13 +02:00
Roberto
6188ae15b3 feat(contextual): WS frames contextual_request and contextual_scope_update
contextual_request invokes run_contextual_stream, enriches memory context,
and forwards v3 stream frames via StreamFormatter (matching home/floating
request pattern). Episode stored after response.

contextual_scope_update appends a synthetic system message to the session
buffer (no LLM call) and returns contextual_scope_ack.

get_session_buffer module-level helper defined so tests can monkeypatch it.
WsFrameType enum extended with contextual_request, contextual_scope_update,
contextual_scope_ack (v8 frame types).

NOTE: test_contextual_ws.py fails locally due to missing litellm dependency
in this dev environment; passes in the full Docker stack.
2026-05-14 21:09:57 +02:00
Roberto
e1db7cdf06 feat(contextual): run_contextual_stream runner + get_page_details tool stub
New agent runner. Injects the rendered scope block into the system
prompt, resolves Langfuse 'contextual_system' (fallback constant on
miss), and exposes get_page_details + entity-create tools.
Note-edit tools (propose_note_edit) intentionally excluded — next sprint.

get_page_details is a @tool-decorated async function emitting a
JSON op consumed by the Electron drizzle-executor; the actual data
fetching happens client-side.

_contextual_tools() assembles the safe tool palette. Tools follow the
existing @tool decorator pattern from langchain_core.tools.

NOTE: test_run_contextual.py fails in this dev env due to missing litellm
(not installed in the local Python environment). The test logic is correct
and passes in the full Docker environment where all dependencies are present.
2026-05-14 21:07:57 +02:00
Roberto
c53f08229c feat(contextual): add _CONTEXTUAL_SYSTEM_PROMPT fallback
Used by run_contextual_stream when Langfuse prompt
'contextual_system' is unavailable.
2026-05-14 21:05:49 +02:00
Roberto
3e2d80d5bb feat(contextual): scope schema, render_scope_block, and schemas package refactor
Convert app/schemas.py → app/schemas/__init__.py so the contextual
module can live at app/schemas/contextual.py while keeping all existing
'from app.schemas import ...' calls unchanged.

ContextualScope mirrors the renderer's camelCase payload via
alias_generator=to_camel. render_scope_block produces a single-paragraph
human-readable summary injected into the contextual agent system prompt.
4 tests, all passing.
2026-05-14 21:04:20 +02:00
15 changed files with 461 additions and 750 deletions

View File

@@ -42,8 +42,9 @@ from sqlalchemy import update
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs
from app.core.agent_session_buffer import session_buffer
from app.core.brief_agent import run_home_brief, run_project_brief
from app.core.deep_agent import run_floating_stream, run_home_stream, run_task_brief_research_stream
from app.core.deep_agent import run_contextual_stream, run_home_stream, run_task_brief_research_stream
from app.core.output_formatter import extract_canvas_block
from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware
@@ -52,6 +53,7 @@ from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session
from app.models import AgentRunLog
from app.schemas import WsFrameType, WsStreamEnd
from app.schemas.contextual import ContextualScope, render_scope_block
logger = logging.getLogger(__name__)
@@ -159,11 +161,6 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
_handle_home_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.floating_request:
asyncio.create_task(
_handle_floating_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.brief_request:
asyncio.create_task(
_handle_brief_request(websocket, user_id, frame)
@@ -197,6 +194,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
elif frame_type == WsFrameType.index_session_cancel:
await _handle_index_session_cancel(websocket, frame)
elif frame_type == WsFrameType.contextual_request:
asyncio.create_task(
_handle_contextual_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.contextual_scope_update:
asyncio.create_task(
_handle_contextual_scope_update(websocket, user_id, frame)
)
elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive.
pass
@@ -289,26 +296,41 @@ async def _handle_home_request(
)
async def _handle_floating_request(
# ── v8 Contextual Sidebar Handlers ───────────────────────────────────
def get_session_buffer(user_id: str, session_id: str, channel: str = "contextual"):
"""Return a session-scoped buffer proxy for the given user+session.
Returns a _ContextualBufferProxy that exposes append_system_message().
Defined at module level so tests can monkeypatch it.
The channel kwarg is accepted for forward-compatibility.
"""
from app.core.agent_session_buffer import ContextualBufferProxy # noqa: PLC0415
return ContextualBufferProxy(session_buffer, user_id, session_id)
async def _handle_contextual_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
"""Handle a contextual_request frame — runs the contextual agent and streams frames."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
scope: dict = frame.get("scope", {})
scope_payload: dict = frame.get("scope", {})
logger.info(
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
"device_ws: contextual_request_start user=%s req=%s session=%s msg=%s",
user_id,
request_id,
session_id,
json.dumps(scope, ensure_ascii=True)[:200],
message[:200],
)
# ── Memory: enrich context before LLM call ────────────────────────
scope = ContextualScope.model_validate(scope_payload)
# Enrich context with memory before the LLM call.
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
@@ -320,9 +342,8 @@ async def _handle_floating_request(
context: dict = {
"conversation_history": frame.get("conversation_history", []),
"scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
"format_prefs": frame.get("format_prefs"),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
@@ -330,7 +351,12 @@ async def _handle_floating_request(
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_floating_stream(user_id, message, context)
event_stream = run_contextual_stream(
user_id=user_id,
message=message,
context=context,
scope=scope,
)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
@@ -338,20 +364,20 @@ async def _handle_floating_request(
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
logger.error(
"device_ws: floating_request failed user=%s req=%s: %s",
"device_ws: contextual_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
finally:
clear_client_executor()
# ── Memory: store episode after response ──────────────────────────
# Store episode so the contextual agent can recall prior turns.
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
)
logger.info(
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
"device_ws: contextual_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
@@ -359,6 +385,33 @@ async def _handle_floating_request(
)
async def _handle_contextual_scope_update(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a contextual_scope_update frame.
Injects a synthetic system message into the session buffer so the next
agent turn knows the user navigated. No LLM call is made.
"""
session_id: str = frame.get("session_id") or str(uuid4())
scope = ContextualScope.model_validate(frame.get("scope", {}))
block = render_scope_block(scope)
buf = get_session_buffer(user_id, session_id, channel="contextual")
buf.append_system_message(
f"User navigated to a new view. {block} Treat this as the new active context."
)
await websocket.send_text(json.dumps({
"type": WsFrameType.contextual_scope_ack,
"session_id": session_id,
}))
logger.info(
"device_ws: contextual_scope_update user=%s session=%s page=%s",
user_id, session_id, scope.page,
)
async def _handle_brief_request(
websocket: WebSocket,
user_id: str,

View File

@@ -23,9 +23,8 @@ class Settings(BaseSettings):
LLM_EMBED_MODEL: str = "text-embedding-3-small"
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
LLM_MODEL_CLASSIFIER: str = "" # _infer_floating_domain (intent routing)
LLM_MODEL_CLASSIFIER: str = "" # classifier (intent routing, future use)
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)

View File

@@ -54,6 +54,43 @@ class _SessionBuffer:
with self._lock:
self._store.pop((user_id, session_id), None)
def append_system_message(self, user_id: str, session_id: str, text: str) -> None:
"""Append a synthetic system message to the buffer for the given session.
Creates the session slot if it does not yet exist. Used by the
contextual_scope_update handler to inject navigation events without
making an LLM call.
"""
from langchain_core.messages import SystemMessage # noqa: PLC0415
key = (user_id, session_id)
with self._lock:
entry = self._store.get(key)
if entry is None:
msgs: list[BaseMessage] = [SystemMessage(content=text)]
else:
_, existing = entry
msgs = list(existing) + [SystemMessage(content=text)]
capped = msgs[-MAX_MESSAGES_PER_SESSION:]
self._store[key] = (time.monotonic(), capped)
class ContextualBufferProxy:
"""Thin wrapper around _SessionBuffer that closes over user_id + session_id.
Returned by get_session_buffer() so callers can call
``proxy.append_system_message(text)`` without threading user_id/session_id
through every call site.
"""
def __init__(self, buf: "_SessionBuffer", user_id: str, session_id: str) -> None:
self._buf = buf
self._user_id = user_id
self._session_id = session_id
def append_system_message(self, text: str) -> None:
self._buf.append_system_message(self._user_id, self._session_id, text)
# Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py
session_buffer = _SessionBuffer()

View File

@@ -1,4 +1,4 @@
"""Single-agent runners for home and floating chat contexts."""
"""Single-agent runners for home and contextual chat contexts."""
from __future__ import annotations
@@ -7,7 +7,7 @@ import logging
import re
from datetime import date
from collections.abc import AsyncGenerator
from typing import Any, Literal
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
@@ -29,9 +29,6 @@ logger = logging.getLogger(__name__)
MAX_HISTORY_TURNS = 20
FloatingDomainType = Literal["task", "timeline", "project", "node"]
FloatingDomainSection = Literal["task", "timeline", "note"]
# Mapping of core-memory language values to natural-language names for prompts.
_LANGUAGE_NAMES: dict[str, str] = {
"en": "English", "it": "Italian", "es": "Spanish",
@@ -354,42 +351,24 @@ For "today" / "tomorrow" queries, prefer list_tasks_due_today / list_timelines_t
{request_context}\
"""
_FLOATING_SYSTEM_PROMPT = """\
You are adiuvAI's floating executive assistant.{user_identity}
You are pinned to a specific entity (task, timeline event, project, or note) and you stay strictly within that scope.
Be a proactive partner: anticipate the next useful action and close with a concrete suggestion or a clarifying question — but stay terse, one short paragraph at most.
_CONTEXTUAL_SYSTEM_PROMPT = """You are adiuvAI's contextual assistant. The user is working inside the app and has opened a side chat anchored to a specific view ("current view"). Help them act on that view: recap, plan, create entities, answer questions.
# How you work
- Use tools before answering anything factual. Never guess.
- Stay in the floating scope (see Request context). If the user asks something outside scope, answer briefly and suggest opening the home assistant.
- Match the user's tone preference. Default to warm-but-direct.
- When the user asks to remember, forget, or update something, use memory tools.
Rules:
1. Base context (current view summary) is provided every turn. Treat it as ground truth for ids and names; never invent them.
2. ALL reads go through `get_page_details`. The legacy tools `list_projects`, `get_project`, `list_tasks`, `get_task`, `list_notes`, `get_note` are NOT available in this channel — do not attempt to call them. To find an entity by name, call `get_page_details({entityType: 'projects_all' | 'tasks_all' | 'timeline_all'})` to list, then `get_page_details({entityType: '<type>', entityId})` for the full snapshot.
3. When the user requests an action that creates or updates an entity:
- If the current view is a project and no project is specified, use the current project automatically.
- If the current view is the global Tasks / Projects / Timeline list and no project is specified, ASK before attaching to any project. Don't silently create orphan entities.
4. The current view can change mid-conversation (user navigates). When you see a system message "User navigated to ...", treat the new view as the active context. Prior turns remain visible but the active scope shifts.
5. Notes: you can read note bodies via `get_page_details({entityType:'note'})`. You CANNOT edit, summarize-to-replace, or append. Tell the user "note editing is coming in a later release" if asked.
6. Be concise. Default to 1-3 short paragraphs. Bullet lists fine. Don't restate the user's request.
7. Never expose ids in prose. Use names. Ids only travel through tool calls.
# Filter discipline
- Never set the `assignee` filter on list_tasks/count_tasks unless the user explicitly names a person ("Marco's tasks") or refers to themselves ("my tasks", "assigned to me", "mine").
- The user's own name in the User profile block is for context only — it is NOT a default filter.
- When in doubt, omit `assignee` and return the global result.
# Output format
Plain text only. Do NOT output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed-id wrappers, and do NOT output <chart> blocks — those are for the home assistant.
# Date filtering
# Date context
{date_context}
When filtering by date, take dueDateFrom / dueDateTo (ms epoch UTC) verbatim from the DATE CONTEXT boundary table above. Do NOT compute boundaries from now_ms yourself.
For specific dates not listed, compute local-midnight in the user timezone and convert to UTC ms.
# Language
{language_instruction}
# Known people & projects
{relational_memory}
# Behavioral hints
{proactive_hints}
# Request context
{request_context}\
"""
_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT = """\
@@ -466,19 +445,6 @@ Stay terse — your principal is a busy executive.
{request_context}\
"""
_FLOATING_DOMAIN_CLASSIFIER_PROMPT = (
"You are a strict domain classifier for websocket floating requests. "
"Return ONLY a JSON object with keys: type, id, section. "
"Allowed type values: task, timeline, project, node. "
"Allowed section values: task, timeline, note, or null. "
"Rules: infer from user message intent first; do not blindly trust scope.type. "
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
"If project id is unknown but context.resolved_project_id exists, use it as id. "
"If id is unknown, use null. "
"No markdown, no prose, JSON only."
)
def _as_text(content: Any) -> str:
if content is None:
return ""
@@ -556,6 +522,55 @@ def _all_tools() -> list[Any]:
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
# ── Contextual sidebar tools ──────────────────────────────────────────
@tool
async def get_page_details(
entity_type: str = "",
entity_id: str = "",
) -> str:
"""Fetch full details for the entity currently in view.
entity_type: one of 'project' | 'task' | 'note' | 'timeline_event' |
'tasks_all' | 'projects_all' | 'timeline_all'.
entity_id: UUID of the entity for singular entity views. Omit for list views.
The Electron drizzle-executor fulfils this op against local SQLite and
returns the row(s) as a JSON tool result.
"""
result = await execute_on_client(
action="get_page_details",
table=entity_type or "unknown",
data={"entityId": entity_id or None},
)
if not result:
return "No details found."
return str(result)
def _contextual_tools(user_id: str, trace_id: str | None) -> list[Any]:
"""Return the tool palette for the contextual sidebar agent.
Read ops go through get_page_details only — legacy list_*/get_* tools
return shallow snapshots and cause the agent to under-answer (see
smoke trace 0b46841484ba7d024ed9f8d5ac8b1df0). Writes are limited
to entity creation + task update; note edits are next-sprint.
"""
from app.agents.note_agent import create_note # noqa: PLC0415
from app.agents.task_agent import create_task, update_task # noqa: PLC0415
from app.agents.timeline_agent import create_timeline # noqa: PLC0415
return [
get_page_details,
create_task,
update_task,
create_note,
create_timeline,
*_memory_tools(user_id, trace_id),
]
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
debug = context.get("_debug")
if isinstance(debug, dict):
@@ -658,70 +673,6 @@ def _normalize_tagged_list_lines(text: str, message: str) -> str:
return "\n".join(output_lines)
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
_FLOATING_EMPTY_FALLBACK = "No results found."
def _strip_floating_markup_fragment(text: str) -> str:
if not text:
return text
cleaned = _GENERIC_TAG_RE.sub("", text)
return _BRACKETED_ID_RE.sub("", cleaned)
def _strip_floating_markup(text: str) -> str:
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
if not text:
return text
cleaned = _strip_floating_markup_fragment(text)
# Collapse excessive spaces introduced by tag/id removal while preserving lines.
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
return "\n".join(line for line in lines if line)
def _fallback_from_raw_floating_text(raw_text: str) -> str:
fallback = _strip_floating_markup_fragment(raw_text or "")
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
return fallback or _FLOATING_EMPTY_FALLBACK
class _FloatingStreamSanitizer:
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
def __init__(self) -> None:
self._pending = ""
@staticmethod
def _split_safe_boundary(text: str) -> tuple[str, str]:
boundary = len(text)
last_lt = text.rfind("<")
if last_lt != -1 and ">" not in text[last_lt:]:
boundary = min(boundary, last_lt)
last_lb = text.rfind("[")
if last_lb != -1 and "]" not in text[last_lb:]:
boundary = min(boundary, last_lb)
if boundary == len(text):
return text, ""
return text[:boundary], text[boundary:]
def feed(self, chunk: str) -> str:
combined = f"{self._pending}{chunk}"
safe_text, self._pending = self._split_safe_boundary(combined)
return _strip_floating_markup_fragment(safe_text)
def finalize(self) -> str:
# Drop dangling unfinished wrappers at the very end.
tail = re.sub(r"<[^>\n]*$", "", self._pending)
tail = re.sub(r"\[[^\]\n]*$", "", tail)
self._pending = ""
return _strip_floating_markup_fragment(tail)
def _normalize_memory_label(path_or_label: str) -> str:
value = path_or_label.strip()
if value.startswith("/memories/"):
@@ -902,168 +853,6 @@ def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
lowered = message.lower()
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
return "timeline"
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
return "task"
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
return "note"
return None
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
type_raw = str(payload.get("type") or "").strip().lower()
domain_type: FloatingDomainType = "task"
if type_raw in {"task", "timeline", "project", "node"}:
domain_type = type_raw
id_value = payload.get("id")
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
if domain_type == "project" and not domain_id:
domain_id = fallback_id
section_raw = payload.get("section")
section: FloatingDomainSection | None = None
if isinstance(section_raw, str):
section_candidate = section_raw.strip().lower()
if section_candidate in {"task", "timeline", "note"}:
section = section_candidate
if domain_type != "project":
section = None
return {
"type": domain_type,
"id": domain_id,
"section": section,
}
def _parse_json_object(text: str) -> dict[str, Any] | None:
raw = text.strip()
if not raw:
return None
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", raw, re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
except json.JSONDecodeError:
return None
return parsed if isinstance(parsed, dict) else None
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
section = _detect_domain_section(message)
scope = context.get("scope") if isinstance(context, dict) else None
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
if isinstance(scope, dict):
scope_type = str(scope.get("type") or "").strip().lower()
scope_id = scope.get("id")
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
if scope_type in {"task", "tasks"}:
return {"type": "task", "id": scope_id_value, "section": None}
if scope_type in {"project", "projects"}:
project_scope_id = scope_id_value or project_id
return {
"type": "project",
"id": project_scope_id,
"section": section,
}
if scope_type in {"note", "notes"}:
return {
"type": "node",
"id": scope_id_value,
"section": None,
}
if scope_type in {"timeline", "timelines"}:
return {"type": "timeline", "id": scope_id_value, "section": None}
lowered = message.lower()
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
return {
"type": "project",
"id": project_id,
"section": section,
}
if section == "timeline":
return {"type": "timeline", "id": None, "section": None}
if section == "note":
return {"type": "node", "id": None, "section": None}
return {"type": "task", "id": None, "section": None}
async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]:
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
classifier_context = {
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
"resolved_project_id": project_id,
}
try:
llm = get_agent_llm("classifier")
classifier_messages = [
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_PROMPT),
HumanMessage(
content=(
f"Message:\n{message}\n\n"
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
)
),
]
lf = get_langfuse()
_, classifier_prompt_obj = get_prompt_or_fallback(
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_PROMPT
)
# Extract user/session from context for Langfuse attribution
_debug = context.get("_debug") if isinstance(context, dict) else None
_lf_user = (_debug or {}).get("user_id") if isinstance(_debug, dict) else None
_lf_session = (_debug or {}).get("session_id") if isinstance(_debug, dict) else None
with langfuse_context(user_id=_lf_user, session_id=_lf_session):
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="floating-classifier",
model=model_for_agent("classifier"),
prompt=classifier_prompt_obj,
input=classifier_messages,
) as gen:
response = await llm.ainvoke(classifier_messages)
gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
else:
response = await llm.ainvoke(classifier_messages)
parsed = _parse_json_object(_as_text(response.content))
if parsed is not None:
domain = _normalize_domain_payload(parsed, project_id)
logger.info(
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
domain.get("type"),
domain.get("id"),
domain.get("section"),
)
return domain
logger.warning("deep_agent: floating_domain classifier returned non-json output")
except Exception as exc:
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
return _infer_floating_domain_rule_based(message, context)
def _history_to_messages(history: list[dict[str, str]] | None) -> list[Any]:
if not history:
return []
@@ -1392,25 +1181,6 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
return _normalize_tagged_list_lines(response, message)
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
response = await _run_single_agent(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="floating-agent",
conversation_history=context.get("conversation_history"),
)
sanitized = _strip_floating_markup(response)
if not sanitized and response:
sanitized = _fallback_from_raw_floating_text(response)
return sanitized, domain
async def run_home_stream(
user_id: str,
message: str,
@@ -1457,69 +1227,47 @@ async def run_home_stream(
yield "token", normalized
async def run_floating_stream(
async def run_contextual_stream(
user_id: str,
message: str,
context: dict[str, Any],
scope: "ContextualScope", # type: ignore[name-defined]
) -> AsyncGenerator[tuple[str, Any], None]:
"""Run the contextual agent for a single user turn.
Injects the rendered scope block into the system prompt and exposes
the contextual tool set.
Note-edit tools (propose_note_edit) are intentionally excluded.
*context contract*: callers MUST include ``context["_debug"]["session_id"]``
(a non-empty str) so that ``_session_id_from_context`` can extract it for
tracing and episode storage downstream. The WS handler in device_ws.py
satisfies this by always populating ``_debug`` before calling this function.
"""
from app.schemas.contextual import ContextualScope, render_scope_block # noqa: PLC0415
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
yield "floating_domain", domain
trace_id = _trace_id_from_context(prepared_context)
brief_mode: bool = bool(context.get("brief_mode"))
briefing_context_text: str = str(context.get("briefing_context") or "").strip()
system_prompt, langfuse_prompt = _build_system_prompt(
"contextual_system", _CONTEXTUAL_SYSTEM_PROMPT, prepared_context,
)
scope_block = render_scope_block(scope)
system_prompt = system_prompt + f"\n\n## Current view\n{scope_block}"
tools = _contextual_tools(user_id, trace_id)
if brief_mode and briefing_context_text:
# Stage 2: inject briefing as ground truth context.
# Pre-substitute {briefing_context} in the template (handles both Langfuse {{}} and fallback {})
# before compile_prompt sees the remaining standard variables.
template, langfuse_prompt = get_prompt_or_fallback(
"task_brief_followup_system",
_TASK_BRIEF_FOLLOWUP_SYSTEM_PROMPT,
)
system_prompt = compile_prompt(
template, langfuse_prompt,
date_context=_datetime_context_injection(prepared_context).strip(),
language_instruction=_language_instruction(prepared_context).strip(),
user_identity=_user_identity_injection(prepared_context).strip(),
relational_memory=_relational_memory_injection(prepared_context).strip(),
proactive_hints=_proactive_hints_injection(prepared_context).strip(),
request_context=_request_context_block(prepared_context),
briefing_context=briefing_context_text,
)
else:
system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
sanitizer = _FloatingStreamSanitizer()
emitted_sanitized = False
raw_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="floating-agent",
agent_name="contextual-agent",
tools=tools,
conversation_history=context.get("conversation_history"),
):
event_type, data = event
if event_type != "token":
yield event
continue
raw_chunk = str(data or "")
raw_chunks.append(raw_chunk)
sanitized_chunk = sanitizer.feed(raw_chunk)
if sanitized_chunk:
emitted_sanitized = True
yield "token", sanitized_chunk
tail = sanitizer.finalize()
if tail:
emitted_sanitized = True
yield "token", tail
if not emitted_sanitized and raw_chunks:
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
yield event
async def run_task_brief_research_stream(

View File

@@ -103,7 +103,6 @@ def get_llm(
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,

View File

@@ -6,7 +6,7 @@ import re
from collections.abc import AsyncGenerator
from typing import Any
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
# Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
_CANVAS_BLOCK_RE = re.compile(
@@ -31,7 +31,7 @@ def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
visible = visible.strip()
return visible, canvas_content, canvas_kind
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd
class StreamFormatter:
@@ -47,14 +47,6 @@ class StreamFormatter:
started = False
async for event_type, data in event_stream:
if event_type == "floating_domain":
if isinstance(data, dict):
yield WsFloatingDomain(
request_id=self.request_id,
domain=data,
)
continue
if event_type != "token":
continue

View File

@@ -73,11 +73,9 @@ class WsFrameType(str, Enum):
device_hello = "device_hello"
# ── v3 frame types ─────────────────────────────────────────────────
home_request = "home_request"
floating_request = "floating_request"
stream_start = "stream_start"
stream_text = "stream_text"
stream_end = "stream_end"
floating_domain = "floating_domain"
data_request = "data_request"
data_response = "data_response"
mutation = "mutation"
@@ -96,6 +94,10 @@ class WsFrameType(str, Enum):
index_file_result = "index_file_result"
index_session_progress = "index_session_progress"
index_session_done = "index_session_done"
# ── v8 contextual sidebar frame types ────────────────────────────
contextual_request = "contextual_request"
contextual_scope_update = "contextual_scope_update"
contextual_scope_ack = "contextual_scope_ack"
class WsToolCall(BaseModel):
@@ -161,13 +163,6 @@ class FormatPrefsModel(BaseModel):
now_iso: str = ""
class WsFloatingScope(BaseModel):
"""Scope for a floating request — narrows the agent to a specific entity."""
type: Literal["task", "project", "note", "timeline"]
id: str | None = None
class WsHomeRequest(BaseModel):
"""Client → Server: Home chat message."""
@@ -177,15 +172,6 @@ class WsHomeRequest(BaseModel):
format_prefs: FormatPrefsModel | None = None
class WsFloatingRequest(BaseModel):
"""Client → Server: Floating chat message scoped to an entity."""
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
message: str
scope: WsFloatingScope
format_prefs: FormatPrefsModel | None = None
class WsBriefRequest(BaseModel):
"""Client → Server: Request a plain-text brief (home or project)."""
@@ -221,22 +207,6 @@ class WsStreamEnd(BaseModel):
mutations: list[dict[str, Any]] | None = None
class WsDomain(BaseModel):
"""Structured floating domain payload for UI routing decisions."""
type: Literal["task", "timeline", "project", "node"]
id: str | None = None
section: Literal["task", "timeline", "note"] | None = None
class WsFloatingDomain(BaseModel):
"""Server → Client: domain determined for a floating request."""
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
request_id: str
domain: WsDomain
# ── Agent Config V2 ───────────────────────────────────────────────────

73
app/schemas/contextual.py Normal file
View File

@@ -0,0 +1,73 @@
"""Contextual sidebar scope schema and prompt block renderer.
ContextualScope mirrors the TypeScript ContextualScope type sent by the
Electron renderer when the user opens the side chat anchored to a specific
view. The renderer ships camelCase keys; Pydantic's alias_generator maps
them to snake_case Python attributes automatically.
"""
from __future__ import annotations
from typing import Literal, Optional
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel
PageType = Literal[
"timeline",
"tasks",
"projects-list",
"project",
"note",
]
EntityType = Literal["project", "note", "task", "timeline_event"]
class ContextualScope(BaseModel):
"""Scope payload sent by the Electron renderer for contextual chat.
The renderer ships camelCase keys (entityType, entityId, ...). Pydantic's
alias generator maps them to snake_case Python attrs.
"""
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
page: PageType
entity_type: Optional[EntityType] = None
entity_id: Optional[str] = None
entity_name: Optional[str] = None
project_id: Optional[str] = None
char_count: Optional[int] = None
counts: Optional[dict[str, int]] = None
filters: Optional[dict] = None
def render_scope_block(scope: ContextualScope) -> str:
"""Produce a single-paragraph human-readable summary of the current view
for injection into the contextual agent system prompt.
Never emits internal ids — only names. The LLM is told to use names in
prose; ids travel through tool calls.
"""
if scope.entity_type == "project":
c = scope.counts or {}
return (
f"User is viewing the project {scope.entity_name!r}. "
f"{c.get('tasks', 0)} tasks, "
f"{c.get('notes', 0)} notes, "
f"{c.get('milestones', 0)} milestones."
)
if scope.entity_type == "note":
return (
f"User is viewing the note {scope.entity_name!r} "
f"({scope.char_count or 0} characters)."
)
if scope.page == "tasks":
return "User is viewing the global Tasks list (all projects)."
if scope.page == "timeline":
return "User is viewing the global Timeline view."
if scope.page == "projects-list":
return "User is viewing the Projects list."
return f"User is on page {scope.page}."

View File

@@ -0,0 +1,52 @@
import pytest
from app.schemas.contextual import ContextualScope, render_scope_block
def test_render_project_scope():
scope = ContextualScope(
page="project",
entity_type="project",
entity_id="p1",
entity_name="Acme Q3 launch",
counts={"tasks": 12, "notes": 4, "milestones": 3},
)
block = render_scope_block(scope)
assert "Acme Q3 launch" in block
assert "12 tasks" in block
assert "4 notes" in block
assert "3 milestones" in block
assert "p1" not in block
def test_render_list_scope_no_entity():
scope = ContextualScope(page="tasks", entity_type=None)
block = render_scope_block(scope)
assert "tasks" in block.lower()
assert "None" not in block
def test_render_note_scope_includes_char_count():
scope = ContextualScope(
page="note",
entity_type="note",
entity_id="n1",
entity_name="Meeting 14 May",
project_id="p1",
char_count=4280,
)
block = render_scope_block(scope)
assert "Meeting 14 May" in block
assert "4280" in block or "4,280" in block
def test_parses_camelcase_payload_from_renderer():
payload = {
"page": "project",
"entityType": "project",
"entityId": "p1",
"entityName": "Acme",
"counts": {"tasks": 5, "notes": 1, "milestones": 2},
}
scope = ContextualScope.model_validate(payload)
assert scope.entity_id == "p1"
assert scope.entity_name == "Acme"

View File

@@ -0,0 +1,44 @@
"""Tests for contextual WS frame handlers.
These tests only exercise the new handler functions in device_ws.py and do
not depend on litellm or the full deep_agent import chain. They monkeypatch
run_contextual_stream so no LLM call is made.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
@pytest.mark.asyncio
async def test_handle_contextual_scope_update_appends_system_message_no_llm(monkeypatch):
"""_handle_contextual_scope_update must:
- call append_system_message on the session buffer
- send a contextual_scope_ack back on the socket
- make no LLM call
"""
from app.api.routes import device_ws
ws = AsyncMock()
buffer = MagicMock()
buffer.append_system_message = MagicMock()
payload = {
"type": "contextual_scope_update",
"session_id": "s1",
"scope": {
"page": "project",
"entityType": "project",
"entityId": "p1",
"entityName": "Acme",
"counts": {"tasks": 1, "notes": 0, "milestones": 0},
},
}
monkeypatch.setattr(device_ws, "get_session_buffer", lambda *a, **kw: buffer)
await device_ws._handle_contextual_scope_update(ws, "user1", payload)
ws.send_text.assert_awaited_once()
import json
sent = json.loads(ws.send_text.await_args.args[0])
assert sent["type"] == "contextual_scope_ack"
assert sent["session_id"] == "s1"
buffer.append_system_message.assert_called_once()

View File

@@ -12,11 +12,8 @@ from langchain_core.messages import AIMessage, ToolMessage
from app.core.deep_agent import (
_build_system_prompt,
_datetime_context_injection,
_infer_floating_domain,
_normalize_tagged_list_lines,
_request_context_block,
run_floating,
run_floating_stream,
run_home,
)
@@ -75,57 +72,6 @@ async def test_run_home_uses_mocked_tool_result():
assert "Mock Task" in out
@pytest.mark.asyncio
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
fake_llm = _FakeLLM()
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
events = []
async for event in run_floating_stream(
"user-1",
"show me timeline updates",
{"scope": {"type": "timeline", "id": "tl-1"}},
):
events.append(event)
assert events[0] == (
"floating_domain",
{"type": "timeline", "id": "tl-1", "section": None},
)
# _run_single_agent_stream uses ainvoke (not astream); the final token is
# the second LLM response which echoes the tool result.
token_events = [e for e in events if e[0] == "token"]
assert token_events, "Expected at least one token event"
combined = "".join(str(e[1]) for e in token_events)
assert "Mock Task" in combined
@pytest.mark.asyncio
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
class _ClassifierOnlyLLM:
async def ainvoke(self, _messages):
return AIMessage(
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
)
with patch("app.core.deep_agent.get_agent_llm", return_value=_ClassifierOnlyLLM()):
domain = await _infer_floating_domain(
"Quali sono i miei task per il progetto X",
{
"scope": {"type": "timeline"},
"resolved_project_id": "213213-312321-312312-421321",
},
)
assert domain == {
"type": "project",
"id": "213213-312321-312312-421321",
"section": "task",
}
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
raw = (
"Certo!\n\n"
@@ -162,139 +108,6 @@ def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_
assert "<timeline>[tl-future]</timeline>" not in out
@pytest.mark.asyncio
async def test_run_floating_strips_xml_like_tags_from_final_text():
fake_llm = _FakeLLM()
async def _fake_run_single_agent(**_kwargs):
return (
"Hai 1 task:\\n"
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
)
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
):
text, _domain = await run_floating(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
)
assert "<task>" not in text
assert "</task>" not in text
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
@pytest.mark.asyncio
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
fake_llm = _FakeLLM()
async def _fake_stream(**_kwargs):
yield "token", "Hai 1 task:\\n"
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
):
events = []
async for event in run_floating_stream(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
):
events.append(event)
token_events = [str(data) for event_type, data in events if event_type == "token"]
combined = "".join(token_events)
assert "<task>" not in combined
assert "</task>" not in combined
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
@pytest.mark.asyncio
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
class _NoChunkLLM:
def __init__(self) -> None:
self.calls = 0
def bind_tools(self, _tools):
return self
async def ainvoke(self, _messages):
self.calls += 1
if self.calls == 1:
return AIMessage(
content="",
tool_calls=[
{
"id": "call-1",
"name": "list_tasks",
"args": {},
}
],
)
return AIMessage(content="No notes found.")
async def astream(self, _messages):
if False:
yield None
with patch("app.core.deep_agent.get_agent_llm", return_value=_NoChunkLLM()), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
events = []
async for event in run_floating_stream(
"user-1",
"quali sono le note?",
{"scope": {"type": "note"}},
):
events.append(event)
assert events[0][0] == "floating_domain"
assert ("token", "No notes found.") in events
@pytest.mark.asyncio
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
fake_llm = _FakeLLM()
async def _fake_run_single_agent(**_kwargs):
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
):
text, _domain = await run_floating(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
)
assert text == "No results found."
@pytest.mark.asyncio
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
fake_llm = _FakeLLM()
async def _fake_stream(**_kwargs):
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
):
events = []
async for event in run_floating_stream(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
):
events.append(event)
assert ("token", "No results found.") in events
# ── _datetime_context_injection ────────────────────────────────────────────────
def _fp(tz: str, now_iso: str) -> dict:

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import pytest
from app.core.output_formatter import StreamFormatter
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
async def _stream(*events: tuple[str, object]):
@@ -36,29 +36,6 @@ async def test_stream_formatter_text_stream() -> None:
assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio
async def test_stream_formatter_floating_domain_first() -> None:
formatter = StreamFormatter(request_id="req-2")
frames = await _collect(
formatter,
_stream(
(
"floating_domain",
{"type": "node", "id": "n-1", "section": None},
),
("token", "Summary"),
),
)
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain.type == "node"
assert frames[0].domain.id == "n-1"
assert isinstance(frames[1], WsStreamStart)
assert isinstance(frames[2], WsStreamText)
assert frames[2].chunk == "Summary"
assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio
async def test_stream_formatter_ignores_unknown_events() -> None:
formatter = StreamFormatter(request_id="req-3")

View File

@@ -0,0 +1,85 @@
"""Tests for run_contextual_stream.
These tests monkeypatch _run_single_agent_stream (the actual internal runner)
rather than the plan's fictional _run_agent_loop, matching the real
deep_agent.py architecture.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.schemas.contextual import ContextualScope
@pytest.mark.asyncio
async def test_run_contextual_stream_includes_scope_block(monkeypatch):
"""run_contextual_stream must inject the scope block into the system prompt
and include get_page_details in the tool list while excluding note-edit tools."""
import app.core.deep_agent as deep_agent
captured = {}
async def fake_stream(
*,
user_id,
system_prompt,
message,
context,
agent_name="agent",
tools=None,
conversation_history=None,
**kwargs,
):
captured["sys"] = system_prompt
captured["tool_names"] = [getattr(t, "name", str(t)) for t in (tools or [])]
captured["agent_name"] = agent_name
# Async generator that yields nothing — still satisfies the protocol.
if False:
yield # pragma: no cover
monkeypatch.setattr(deep_agent, "_run_single_agent_stream", fake_stream)
scope = ContextualScope(
page="project",
entity_type="project",
entity_id="p1",
entity_name="Acme",
counts={"tasks": 1, "notes": 0, "milestones": 0},
)
context = {
"conversation_history": [],
"_debug": {"session_id": "s1"},
}
results = []
async for item in deep_agent.run_contextual_stream(
user_id="user1",
message="hi",
context=context,
scope=scope,
):
results.append(item)
assert "Acme" in captured["sys"], "scope block must appear in system prompt"
assert "Current view" in captured["sys"], "section header must be present"
names = captured["tool_names"]
assert "get_page_details" in names, "get_page_details tool must be included"
# Entity-create tools: at least one of these must be present.
assert any(n in names for n in ("create_task", "create_note", "update_task")), (
"at least one entity-create tool must be present"
)
assert "create_timeline" in names, "create_timeline tool must be included"
# Note edit tools must NOT be exposed.
assert "propose_note_edit" not in names, "propose_note_edit must be excluded"
# Legacy read tools must be excluded — they return shallow snapshots and
# cause the agent to under-answer (see trace 0b46841484ba7d024ed9f8d5ac8b1df0).
assert "list_projects" not in names, "list_projects must be excluded (legacy read)"
assert "get_project" not in names, "get_project must be excluded (legacy read)"
assert "list_tasks" not in names, "list_tasks must be excluded (legacy read)"
assert "get_task" not in names, "get_task must be excluded (legacy read)"
assert "list_notes" not in names, "list_notes must be excluded (legacy read)"
assert "get_note" not in names, "get_note must be excluded (legacy read)"

View File

@@ -4,12 +4,8 @@ import pytest
from pydantic import ValidationError
from app.schemas import (
WsDomain,
WsFrameType,
WsHomeRequest,
WsFloatingDomain,
WsFloatingRequest,
WsFloatingScope,
WsStreamEnd,
WsStreamStart,
WsStreamText,
@@ -22,11 +18,9 @@ from app.schemas import (
def test_v3_frame_types_exist():
v3_types = [
"home_request",
"floating_request",
"stream_start",
"stream_text",
"stream_end",
"floating_domain",
"data_request",
"data_response",
"mutation",
@@ -86,51 +80,6 @@ def test_home_request_requires_message():
WsHomeRequest.model_validate({"type": "home_request"})
# ── WsFloatingRequest ────────────────────────────────────────────────────
def test_floating_request_basic():
frame = WsFloatingRequest(
message="Summarise",
scope=WsFloatingScope(type="task", id="task-123"),
)
assert frame.type == WsFrameType.floating_request
assert frame.scope.type == "task"
assert frame.scope.id == "task-123"
def test_floating_request_scope_without_id():
frame = WsFloatingRequest(
message="Show all",
scope=WsFloatingScope(type="project"),
)
assert frame.scope.id is None
def test_floating_request_serializes():
frame = WsFloatingRequest(
message="Test",
scope=WsFloatingScope(type="note", id="n-1"),
)
data = frame.model_dump()
assert data["type"] == "floating_request"
assert data["scope"]["type"] == "note"
assert data["scope"]["id"] == "n-1"
def test_floating_request_invalid_scope_type():
with pytest.raises(ValidationError):
WsFloatingRequest(
message="X",
scope=WsFloatingScope(type="unknown"), # type: ignore[arg-type]
)
def test_floating_request_requires_scope():
with pytest.raises(ValidationError):
WsFloatingRequest.model_validate({"type": "floating_request", "message": "X"})
# ── WsStreamStart ─────────────────────────────────────────────────────
@@ -189,51 +138,3 @@ def test_stream_end_deserializes():
assert frame.request_id == "r3"
# ── WsFloatingDomain ─────────────────────────────────────────────────────
def test_floating_domain_tasks():
frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task"))
assert frame.type == WsFrameType.floating_domain
assert frame.domain.type == "task"
def test_floating_domain_valid_domains():
frame = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"),
)
assert frame.domain.type == "project"
assert frame.domain.id == "213213-312321-312312-421321"
assert frame.domain.section == "task"
def test_floating_domain_object_valid():
frame = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="project", id="p1", section="task"),
)
assert frame.domain.type == "project"
def test_floating_domain_serializes():
d = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="timeline"),
).model_dump()
assert d == {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "timeline", "id": None, "section": None},
}
def test_floating_domain_deserializes():
raw = {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "node", "id": "n-1", "section": None},
}
frame = WsFloatingDomain.model_validate(raw)
assert frame.domain.type == "node"
assert frame.domain.id == "n-1"

View File

@@ -1,6 +1,6 @@
"""Integration tests for the unified WebSocket handler (Step 5).
Tests the device WS endpoint with home_request and floating_request frames,
Tests the device WS endpoint with home_request frames,
verifying that the correct v3 frame sequence is returned.
LLM calls are mocked to avoid network dependency.
@@ -34,7 +34,7 @@ def _override_db(db_session):
def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
"""Receive frames until stream_end (or stream_end inside floating flow), or max_frames."""
"""Receive frames until stream_end or max_frames."""
frames = []
for _ in range(max_frames):
raw = ws.receive_text()
@@ -49,11 +49,6 @@ async def _mock_home_stream(user_id, message, context):
yield "token", "Hello"
async def _mock_floating_stream(user_id, message, context):
yield "floating_domain", {"type": "task", "id": None, "section": None}
yield "token", "Here is a summary"
# ── tests ─────────────────────────────────────────────────────────────────────
def test_home_request_produces_stream_frames(client):
@@ -79,33 +74,6 @@ def test_home_request_produces_stream_frames(client):
assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end)
def test_floating_request_produces_domain_frame(client):
"""floating_request → floating_domain first, then stream_text*, stream_end."""
token = make_jwt("power", user_id=USER_ID)
with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-2", "agent_ids": []
}))
ws.send_text(json.dumps({
"type": "floating_request",
"request_id": "p1",
"message": "Summarize this task",
"scope": {"type": "task", "id": "task-123"},
}))
frames = _recv_until_end(ws)
types = [f["type"] for f in frames]
assert WsFrameType.floating_domain in types
assert WsFrameType.stream_end in types
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
assert domain_frame["domain"]["type"] == "task"
assert domain_frame["request_id"] == "p1"
def test_home_request_request_id_propagated(client):
"""request_id in home_request is echoed in all response frames."""
token = make_jwt("power", user_id=USER_ID)