refactor: replace orchestrator with LangGraph deep-agent supervisors

- Add app/core/deep_agent.py with Home and Floating supervisor graphs
  using LangGraph create_react_agent (hierarchical pattern)
- Strip ChatAgent classes from all 4 agent files, keep @tool functions
- Rewrite output_formatter.py for event-based (token/tool_end/mutations) stream
- Update device_ws.py to use run_home_stream/run_floating_stream
- Rewrite chat.py REST route to use run_home
- Add update_core_memory tool to both supervisors
- Add langgraph>=0.3.0 to requirements.txt
- Remove orchestrator.py, execution_plan.py, agent_registry.py, plans.py
- Remove PlanAction, PlanStep, ExecutionPlan, execution_mode from schemas
- Update all affected tests to match new API
- Remove 6 deprecated test files for deleted modules
- Clean up stale docstrings referencing removed orchestrator
This commit is contained in:
2026-03-11 17:50:22 +01:00
parent 2de67213f8
commit cfc9d7a942
31 changed files with 723 additions and 3498 deletions

View File

@@ -1,4 +1,4 @@
"""Import all agent modules to trigger @registry.register decorators.""" """Agent tool modules — imported by deep_agent.py to build sub-agent graphs."""
from app.agents import timeline_agent, note_agent, project_agent, task_agent from app.agents import timeline_agent, note_agent, project_agent, task_agent

View File

@@ -1,31 +1,14 @@
"""Note agent — Markdown note management (list, get, create, update, delete).""" """Note agent — tool definitions for Markdown note CRUD."""
from __future__ import annotations from __future__ import annotations
import json
from typing import Any from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry from app.core.llm import embed
from app.core.llm import embed, get_llm
from app.core.ws_context import execute_on_client from app.core.ws_context import execute_on_client
_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n"
"and delete Markdown notes in their workspace.\n\n"
"Rules:\n"
" - content is always Markdown; preserve formatting when updating\n"
" - project_id is optional; link a note to a project when mentioned\n"
" - When updating, call get_note first if you need to read existing content\n"
" before appending or replacing sections\n"
" - list_notes without project_id returns all notes; scope with project_id\n"
" when the user is working within a specific project\n"
" - Do not fabricate note content — reflect what the user provides or what\n"
" is already in the note (retrieved via get_note)."
)
@tool @tool
async def list_notes(project_id: str = "") -> str: async def list_notes(project_id: str = "") -> str:
@@ -122,23 +105,4 @@ async def delete_note(note_id: str) -> str:
return f"Note {note_id} deleted." return f"Note {note_id} deleted."
@registry.register
class NoteAgent(ChatAgent):
def get_name(self) -> str:
return "note_agent"
def get_description(self) -> str:
return "Manages notes: list, get, create, update, delete"
def get_tools(self) -> list[Any]:
return [list_notes, get_note, create_note, update_note, delete_note]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

View File

@@ -1,33 +1,13 @@
"""Project agent — full lifecycle management (list, get, create, update, archive, delete).""" """Project agent — tool definitions for project lifecycle CRUD."""
from __future__ import annotations from __future__ import annotations
import json
from typing import Any from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
from app.core.ws_context import execute_on_client from app.core.ws_context import execute_on_client
_SYSTEM_PROMPT = (
"You are a project management assistant. You help users create, find,\n"
"update, and archive projects in their workspace.\n\n"
"Rules:\n"
" - status must be one of: active, archived\n"
" - client_id is optional; link to a client only when explicitly mentioned\n"
" - ai_summary is populated only when the user asks for a project summary;\n"
" derive it from context data — do not fabricate content\n"
" - Use list_projects for scoped queries; list_all_projects only when the\n"
" user wants a complete cross-client view including archived projects\n"
" - get_project requires a project UUID; resolve the ID first by calling\n"
" list_projects if you only have a project name\n"
" - Prefer archiving (update_project status=archived) over deletion;\n"
" only call delete_project when the user explicitly confirms deletion."
)
@tool @tool
async def list_projects( async def list_projects(
@@ -137,30 +117,4 @@ async def delete_project(project_id: str) -> str:
return f"Project {project_id} permanently deleted." return f"Project {project_id} permanently deleted."
@registry.register
class ProjectAgent(ChatAgent):
def get_name(self) -> str:
return "project_agent"
def get_description(self) -> str:
return "Manages projects: list, get, create, update, archive, delete"
def get_tools(self) -> list[Any]:
return [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

View File

@@ -1,35 +1,14 @@
"""Task agent — full CRUD for tasks and task comments.""" """Task agent — tool definitions for task and task comment CRUD."""
from __future__ import annotations from __future__ import annotations
import json
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
from app.core.ws_context import execute_on_client from app.core.ws_context import execute_on_client
_SYSTEM_PROMPT = (
"You are a task management assistant for a project workspace.\n"
"You create, update, list, and track tasks and their comments.\n\n"
"Rules:\n"
" - status must be one of: todo, in_progress, done\n"
" - priority must be one of: high, medium, low\n"
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
" - project_id is optional; link to a project when the user mentions one\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
" did not explicitly request; 0 otherwise\n"
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
" - Use list_tasks_due_today for 'what's due today' queries\n"
" - For update_task, use -1 for integer fields you do not want to change\n"
" - Always confirm the action in plain, user-friendly language."
)
# ── Task tools ──────────────────────────────────────────────────────── # ── Task tools ────────────────────────────────────────────────────────
@@ -220,35 +199,4 @@ async def delete_task_comment(comment_id: str) -> str:
return f"Comment {comment_id} deleted." return f"Comment {comment_id} deleted."
# ── Agent ─────────────────────────────────────────────────────────────
@registry.register
class TaskAgent(ChatAgent):
def get_name(self) -> str:
return "task_agent"
def get_description(self) -> str:
return "Manages tasks and comments: list, create, update, delete, due-today, comments"
def get_tools(self) -> list[Any]:
return [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

View File

@@ -1,30 +1,13 @@
"""Timeline agent — project milestone management (list, create, update, delete).""" """Timeline agent — tool definitions for project milestone CRUD."""
from __future__ import annotations from __future__ import annotations
import json
from typing import Any from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
from app.core.ws_context import execute_on_client from app.core.ws_context import execute_on_client
_SYSTEM_PROMPT = (
"You are a project timeline assistant. Timelines are milestone dates that\n"
"track progress on a project — they are not calendar events.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - is_approved: 0 until the user explicitly confirms; then 1\n"
" - For update_timeline, use -1 for integer fields you do not want to change\n"
" - Listing without a project_id returns all timelines across projects\n"
" - Always echo the title and formatted date in your confirmation."
)
@tool @tool
async def list_timelines(project_id: str = "") -> str: async def list_timelines(project_id: str = "") -> str:
@@ -106,23 +89,4 @@ async def delete_timeline(timeline_id: str) -> str:
return f"Timeline {timeline_id} deleted." return f"Timeline {timeline_id} deleted."
@registry.register
class TimelineAgent(ChatAgent):
def get_name(self) -> str:
return "timeline_agent"
def get_description(self) -> str:
return "Manages project timelines (milestones): list, create, update, delete"
def get_tools(self) -> list[Any]:
return [list_timelines, create_timeline, update_timeline, delete_timeline]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

View File

@@ -9,8 +9,10 @@ from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from app.api.deps import get_current_user from app.api.deps import get_current_user
from app.core.orchestrator import orchestrate from app.core.deep_agent import run_home
from app.schemas import ChatRequest, UserProfile from app.core.memory_middleware import MemoryMiddleware
from app.db import async_session
from app.schemas import ChatRequest, ChatResponse, UserProfile
router = APIRouter(prefix="/chat", tags=["chat"]) router = APIRouter(prefix="/chat", tags=["chat"])
@@ -20,10 +22,21 @@ async def chat(
body: ChatRequest, body: ChatRequest,
current_user: UserProfile = Depends(get_current_user), current_user: UserProfile = Depends(get_current_user),
) -> JSONResponse: ) -> JSONResponse:
"""Route a chat message through the orchestrator. """Route a chat message through the Home deep agent (non-streaming)."""
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(current_user.id, body.message)
Returns ``ChatResponse`` for ``execution_mode='direct'``, context = {
or ``ExecutionPlan`` for ``execution_mode='plan'``. **body.context.model_dump(),
""" **memory_context,
result = await orchestrate(body) }
response_text = await run_home(
user_id=current_user.id,
message=body.message,
context=context,
db_session_factory=async_session,
)
result = ChatResponse(response=response_text)
return JSONResponse(content=result.model_dump()) return JSONResponse(content=result.model_dump())

View File

@@ -43,7 +43,7 @@ from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs from app.core.agent_runner import trigger_pending_runs
from app.core.device_manager import device_manager from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware from app.core.memory_middleware import MemoryMiddleware
from app.core.orchestrator import orchestrate_v3_stream from app.core.deep_agent import run_home_stream, run_floating_stream
from app.core.output_formatter import HomeFormatter, FloatingFormatter from app.core.output_formatter import HomeFormatter, FloatingFormatter
from app.core.ws_context import clear_client_executor, set_client_executor from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session from app.db import async_session
@@ -204,9 +204,17 @@ async def _make_ws_executor(websocket: WebSocket, user_id: str):
"""Return a callback that sends tool_call frames and awaits tool_result.""" """Return a callback that sends tool_call frames and awaits tool_result."""
async def _executor(payload: dict) -> dict: async def _executor(payload: dict) -> dict:
payload["type"] = WsFrameType.tool_call payload["type"] = WsFrameType.tool_call
call_id = payload["id"]
logger.info("ws_executor: sending tool_call id=%s action=%s", call_id, payload.get("action"))
await websocket.send_text(json.dumps(payload)) await websocket.send_text(json.dumps(payload))
future = device_manager.create_pending_call(user_id, payload["id"]) future = device_manager.create_pending_call(user_id, call_id)
return await future result = await future
logger.info("ws_executor: tool_result id=%s result_type=%s result_keys=%s",
call_id, type(result).__name__,
list(result.keys()) if isinstance(result, dict) else "N/A")
if result is None:
logger.error("ws_executor: future resolved to None for call_id=%s user=%s", call_id, user_id)
return result
return _executor return _executor
@@ -233,21 +241,13 @@ async def _handle_home_request(
executor = await _make_ws_executor(websocket, user_id) executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor) set_client_executor(executor)
response_chunks: list[str] = [] response_chunks: list[str] = []
agent_holder: list = []
try: try:
token_stream = orchestrate_v3_stream( event_stream = run_home_stream(
user_id, message, context, agent_holder=agent_holder user_id, message, context, db_session_factory=async_session
) )
formatter = HomeFormatter(request_id=request_id, tool_results=[]) formatter = HomeFormatter(request_id=request_id)
async for ws_frame in formatter.format(token_stream): async for ws_frame in formatter.format(event_stream):
# Inject mutations from agent tool_results into stream_end
if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr]
ws_frame.mutations = [ # type: ignore[union-attr]
{"action": r["action"], "table": r["table"], "data": r["data"]}
for r in getattr(agent_holder[0], "tool_results", [])
]
await websocket.send_text(ws_frame.model_dump_json()) await websocket.send_text(ws_frame.model_dump_json())
# Collect text chunks to build the full response for episode storage
if ws_frame.type == "stream_text": # type: ignore[union-attr] if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr] response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc: except Exception as exc:
@@ -287,18 +287,13 @@ async def _handle_floating_request(
executor = await _make_ws_executor(websocket, user_id) executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor) set_client_executor(executor)
response_chunks: list[str] = [] response_chunks: list[str] = []
agent_holder: list = []
try: try:
token_stream = orchestrate_v3_stream( event_stream = run_floating_stream(
user_id, message, context, agent_holder=agent_holder user_id, message, context, scope=scope,
db_session_factory=async_session,
) )
formatter = FloatingFormatter(request_id=request_id) formatter = FloatingFormatter(request_id=request_id)
async for ws_frame in formatter.format(token_stream): async for ws_frame in formatter.format(event_stream):
if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr]
ws_frame.mutations = [ # type: ignore[union-attr]
{"action": r["action"], "table": r["table"], "data": r["data"]}
for r in getattr(agent_holder[0], "tool_results", [])
]
await websocket.send_text(ws_frame.model_dump_json()) await websocket.send_text(ws_frame.model_dump_json())
if ws_frame.type == "stream_text": # type: ignore[union-attr] if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr] response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]

View File

@@ -1,37 +0,0 @@
"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status
from app.api.deps import get_current_user
from app.core.execution_plan import plan_cache
from app.schemas import ExecutionPlan, UserProfile
router = APIRouter(prefix="/plans", tags=["plans"])
@router.get("/playbook", response_model=list[ExecutionPlan])
async def list_playbooks(
current_user: UserProfile = Depends(get_current_user),
) -> list[ExecutionPlan]:
"""Return all cached execution plan playbooks for the authenticated user.
TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
"""
return plan_cache.get_all_playbooks()
@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
async def get_playbook(
plan_id: str,
current_user: UserProfile = Depends(get_current_user),
) -> ExecutionPlan:
"""Return a specific execution plan playbook by ID."""
plan = plan_cache.get_plan(plan_id)
if plan is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Plan not found: {plan_id}",
)
return plan

View File

@@ -1,217 +0,0 @@
"""Agent Registry — base classes and singleton registry for chat agents."""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any
class BaseAgent(ABC):
"""Common base for all agents."""
def __init__(
self,
user_id: str = "",
shared_memory: dict[str, Any] | None = None,
vector_store_context: list[str] | None = None,
) -> None:
self.user_id = user_id
self.shared_memory: dict[str, Any] = shared_memory or {}
self.vector_store_context: list[str] = vector_store_context or []
@abstractmethod
def get_name(self) -> str: ...
@abstractmethod
def get_description(self) -> str: ...
@property
def skills(self) -> list[str]:
"""Override in subclasses to advertise capabilities."""
return []
class ChatAgent(BaseAgent):
"""Base class for LLM-powered chat agents."""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
# Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results.
self.tool_results: list[dict] = []
@abstractmethod
async def handle(self, query: str, context: dict[str, Any]) -> str:
"""Process a user query and return a text response."""
...
async def handle_stream(
self, query: str, context: dict[str, Any]
) -> AsyncGenerator[str, None]:
"""Streaming variant of handle().
Default: calls handle() and yields the full response as one chunk.
Override in subclasses for true token-level streaming via _tool_loop_stream.
"""
yield await self.handle(query, context)
@abstractmethod
def get_tools(self) -> list[Any]:
"""Return LangChain tool definitions available to this agent."""
...
async def _tool_loop(
self,
llm: Any,
messages: list[Any],
tools: list[Any],
max_iter: int = 5,
) -> str:
"""Shared tool-calling loop.
Binds *tools* to *llm*, invokes iteratively until the model stops
requesting tool calls or *max_iter* is reached, and returns the
final text response. Captures raw execute_on_client results in
``self.tool_results``.
"""
from langchain_core.messages import AIMessage, ToolMessage
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
collector: list[dict] = []
set_tool_result_collector(collector)
try:
llm_with_tools = llm.bind_tools(tools) if tools else llm
for _ in range(max_iter):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return str(response.content)
# Execute each requested tool call
tool_map = {t.name: t for t in tools}
for call in response.tool_calls:
tool_fn = tool_map.get(call["name"])
if tool_fn is None:
result = f"Unknown tool: {call['name']}"
else:
result = await tool_fn.ainvoke(call["args"])
messages.append(
ToolMessage(content=str(result), tool_call_id=call["id"])
)
# Exhausted iterations — ask model for a final answer without tools
response = await llm.ainvoke(messages)
return str(response.content)
finally:
clear_tool_result_collector()
self.tool_results = collector
async def _tool_loop_stream(
self,
llm: Any,
messages: list[Any],
tools: list[Any],
max_iter: int = 5,
) -> AsyncGenerator[str, None]:
"""Streaming variant of ``_tool_loop``.
Behaves identically for tool-calling iterations (uses ainvoke to parse
tool calls). For the final response — when the model produces no further
tool calls — switches to ``llm.astream()`` and yields text tokens.
Captures raw execute_on_client results in ``self.tool_results``.
"""
from langchain_core.messages import AIMessage, ToolMessage
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
collector: list[dict] = []
set_tool_result_collector(collector)
try:
llm_with_tools = llm.bind_tools(tools) if tools else llm
for _ in range(max_iter):
response: AIMessage = await llm_with_tools.ainvoke(messages)
if not response.tool_calls:
# Stream the final answer — don't keep the ainvoke result.
async for chunk in llm.astream(messages):
if chunk.content:
yield str(chunk.content)
return
messages.append(response)
# Execute each requested tool call
tool_map = {t.name: t for t in tools}
for call in response.tool_calls:
tool_fn = tool_map.get(call["name"])
if tool_fn is None:
result = f"Unknown tool: {call['name']}"
else:
result = await tool_fn.ainvoke(call["args"])
messages.append(
ToolMessage(content=str(result), tool_call_id=call["id"])
)
# Exhausted iterations — stream a final answer without tools
async for chunk in llm.astream(messages):
if chunk.content:
yield str(chunk.content)
finally:
clear_tool_result_collector()
self.tool_results = collector
class AgentRegistry:
"""Singleton registry for ChatAgent subclasses."""
_instance: AgentRegistry | None = None
def __init__(self) -> None:
self._agents: dict[str, type[ChatAgent]] = {}
def __new__(cls) -> AgentRegistry:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._agents = {}
return cls._instance
# ── public API ───────────────────────────────────────────────────
def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
"""Class decorator — registers an agent by its name."""
instance = agent_class()
name = instance.get_name()
self._agents[name] = agent_class
return agent_class
def get(self, name: str) -> ChatAgent:
"""Return a fresh instance of the named agent."""
cls = self._agents.get(name)
if cls is None:
raise KeyError(f"Agent not found: {name}")
return cls()
def list_agents(self) -> list[dict[str, str]]:
"""Return ``[{name, description}]`` for the orchestrator prompt."""
result: list[dict[str, str]] = []
for cls in self._agents.values():
inst = cls()
result.append(
{"name": inst.get_name(), "description": inst.get_description()}
)
return result
async def call_agent(
self, name: str, query: str, context: dict[str, Any]
) -> str:
"""Instantiate the named agent and call its ``handle`` method."""
agent = self.get(name)
return await agent.handle(query, context)
# Module-level singleton
registry = AgentRegistry()

View File

@@ -1,4 +1,4 @@
"""Agent run orchestrator. """Agent run manager.
Drives two agent types: Drives two agent types:

429
app/core/deep_agent.py Normal file
View File

@@ -0,0 +1,429 @@
"""Deep Agent — LangGraph hierarchical supervisors for home and floating modes.
Two supervisor graphs (both ``create_react_agent``):
* **HomeSupervisor** — gathers data from multiple domains, presents
structured overview with tool-result blocks.
* **FloatingSupervisor** — focused, scoped assistant for a single entity/domain.
Each supervisor delegates to four sub-agent tools, each a compiled
``create_react_agent`` wrapping the domain CRUD tools (task, project, note,
timeline). The sub-agents talk to Electron via ``execute_on_client``.
Streaming uses ``astream(stream_mode=["messages", "updates"])`` so that
callers can sniff:
* ``("messages", (token, metadata))`` — text tokens for streaming
* ``("updates", ...)`` — tool call results for mutations
An ``update_core_memory`` tool is available to both supervisors for
persisting user preferences mid-conversation (MemGPT-style).
"""
from __future__ import annotations
import json
import logging
from typing import Any, AsyncGenerator
from langchain_core.messages import AIMessageChunk, HumanMessage
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
from app.core.llm import get_llm
from app.core.ws_context import (
clear_tool_result_collector,
set_tool_result_collector,
)
logger = logging.getLogger(__name__)
# ── Sub-agent tool imports ────────────────────────────────────────────
from app.agents.task_agent import ( # noqa: E402
add_task_comment,
create_task,
delete_task,
delete_task_comment,
list_task_comments,
list_tasks,
list_tasks_due_today,
update_task,
)
from app.agents.note_agent import ( # noqa: E402
create_note,
delete_note,
get_note,
list_notes,
update_note,
)
from app.agents.project_agent import ( # noqa: E402
create_project,
delete_project,
get_project,
list_all_projects,
list_projects,
update_project,
)
from app.agents.timeline_agent import ( # noqa: E402
create_timeline,
delete_timeline,
list_timelines,
update_timeline,
)
# ── Sub-agent definitions ─────────────────────────────────────────────
_TASK_TOOLS = [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]
_NOTE_TOOLS = [list_notes, get_note, create_note, update_note, delete_note]
_PROJECT_TOOLS = [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]
_TIMELINE_TOOLS = [list_timelines, create_timeline, update_timeline, delete_timeline]
def _build_subagent_tool(
name: str,
description: str,
system_prompt: str,
tools: list,
):
"""Build a compiled sub-agent graph and wrap it as a LangChain tool."""
subgraph = create_react_agent(
model=get_llm(),
tools=tools,
prompt=system_prompt,
name=name,
)
@tool(name=name, description=description)
async def _run(query: str) -> str:
result = await subgraph.ainvoke(
{"messages": [HumanMessage(content=query)]}
)
messages = result["messages"]
# Return the last AI message content
for msg in reversed(messages):
if hasattr(msg, "content") and msg.content and not getattr(msg, "tool_calls", None):
return str(msg.content)
return "No response from sub-agent."
return _run
def _make_subagent_tools() -> list:
"""Create the four sub-agent tools for the supervisor."""
return [
_build_subagent_tool(
name="task_agent",
description=(
"Manages tasks and comments: list, create, update, delete, "
"due-today, comments. Delegate task-related queries here."
),
system_prompt=(
"You are a task management assistant. You create, update, list, "
"and track tasks and their comments.\n\n"
"Rules:\n"
" - status must be one of: todo, in_progress, done\n"
" - priority must be one of: high, medium, low\n"
" - due_date is a Unix timestamp in milliseconds\n"
" - assignees is a JSON-encoded array of strings\n"
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
" - For update_task, use -1 for integer fields you do not want to change\n"
" - Always confirm the action in plain, user-friendly language."
),
tools=_TASK_TOOLS,
),
_build_subagent_tool(
name="note_agent",
description=(
"Manages notes: list, get, create, update, delete. "
"Delegate note-related queries here."
),
system_prompt=(
"You are a note-taking assistant. You help users create, retrieve, "
"update, and delete Markdown notes in their workspace.\n\n"
"Rules:\n"
" - content is always Markdown; preserve formatting when updating\n"
" - When updating, call get_note first if you need to read existing "
"content before appending or replacing sections\n"
" - Do not fabricate note content."
),
tools=_NOTE_TOOLS,
),
_build_subagent_tool(
name="project_agent",
description=(
"Manages projects: list, get, create, update, archive, delete. "
"Delegate project-related queries here."
),
system_prompt=(
"You are a project management assistant. You help users create, "
"find, update, and archive projects.\n\n"
"Rules:\n"
" - status must be one of: active, archived\n"
" - Prefer archiving over deletion\n"
" - ai_summary is populated only when the user asks for a summary."
),
tools=_PROJECT_TOOLS,
),
_build_subagent_tool(
name="timeline_agent",
description=(
"Manages project timelines (milestones): list, create, update, "
"delete. Delegate timeline/milestone queries here."
),
system_prompt=(
"You are a project timeline assistant. Timelines are milestone "
"dates that track progress on a project.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create\n"
" - date is a Unix timestamp in milliseconds\n"
" - For update_timeline, use -1 for integer fields you do not "
"want to change."
),
tools=_TIMELINE_TOOLS,
),
]
# ── Update core memory tool ──────────────────────────────────────────
def _make_update_core_memory_tool(user_id: str, db_session_factory):
"""Create a tool that persists a key/value preference in core memory."""
@tool
async def update_core_memory(key: str, value: str) -> str:
"""Save a user preference or fact to long-term core memory.
key: short label for the memory (e.g. 'preferred_language', 'timezone')
value: the value to remember
Use this when the user states a preference or fact worth remembering.
"""
from app.core.memory_middleware import MemoryMiddleware
async with db_session_factory() as db:
memory = MemoryMiddleware(db)
await memory.update_core(user_id, key, value)
return f"Remembered: {key} = {value}"
return update_core_memory
# ── System prompts ────────────────────────────────────────────────────
_HOME_SYSTEM = (
"You are Adiuva, a smart workspace assistant on the Home dashboard.\n"
"Your job is to help the user by gathering data from their workspace and "
"presenting a comprehensive overview.\n\n"
"You have sub-agent tools (task_agent, note_agent, project_agent, "
"timeline_agent) that can query and mutate workspace data. Delegate to "
"the appropriate sub-agent(s) based on the user's request. You can call "
"multiple sub-agents if needed.\n\n"
"You also have an update_core_memory tool — use it when the user states "
"a preference or important fact worth remembering long-term.\n\n"
"After gathering data, synthesize a clear, helpful response for the user.\n\n"
"Memory context:\n{memory_context}"
)
_FLOATING_SYSTEM = (
"You are Adiuva, a focused workspace assistant in the floating panel.\n"
"The user is currently working in the '{scope_type}' section"
"{scope_detail}.\n\n"
"You have sub-agent tools (task_agent, note_agent, project_agent, "
"timeline_agent) that can query and mutate workspace data. Focus your "
"help on the user's current scope, but you can use other sub-agents "
"if the request requires it.\n\n"
"You also have an update_core_memory tool — use it when the user states "
"a preference or important fact worth remembering long-term.\n\n"
"Provide direct, conversational responses.\n\n"
"Memory context:\n{memory_context}"
)
def _format_memory_context(memory: dict[str, Any]) -> str:
"""Format the memory dict into a readable string for the system prompt."""
if not memory:
return "(no memory available)"
parts = []
if memory.get("core_memory"):
parts.append("Preferences: " + json.dumps(memory["core_memory"]))
if memory.get("associative_memory"):
parts.append("Related memories: " + "; ".join(memory["associative_memory"][:3]))
if memory.get("episodic_memory"):
parts.append("Recent sessions: " + "; ".join(memory["episodic_memory"][:3]))
if memory.get("proactive_hints"):
parts.append("Patterns: " + "; ".join(memory["proactive_hints"][:3]))
return "\n".join(parts) if parts else "(no memory available)"
# ── Graph builders ────────────────────────────────────────────────────
def build_home_graph(
user_id: str,
memory_context: dict[str, Any],
db_session_factory,
):
"""Build the Home supervisor graph."""
subagent_tools = _make_subagent_tools()
memory_tool = _make_update_core_memory_tool(user_id, db_session_factory)
all_tools = subagent_tools + [memory_tool]
prompt = _HOME_SYSTEM.format(
memory_context=_format_memory_context(memory_context),
)
return create_react_agent(
model=get_llm(),
tools=all_tools,
prompt=prompt,
name="home_supervisor",
)
def build_floating_graph(
user_id: str,
memory_context: dict[str, Any],
scope: dict[str, Any],
db_session_factory,
):
"""Build the Floating supervisor graph."""
subagent_tools = _make_subagent_tools()
memory_tool = _make_update_core_memory_tool(user_id, db_session_factory)
all_tools = subagent_tools + [memory_tool]
scope_type = scope.get("type", "general")
scope_id = scope.get("id")
scope_detail = f" (id: {scope_id})" if scope_id else ""
prompt = _FLOATING_SYSTEM.format(
scope_type=scope_type,
scope_detail=scope_detail,
memory_context=_format_memory_context(memory_context),
)
return create_react_agent(
model=get_llm(),
tools=all_tools,
prompt=prompt,
name="floating_supervisor",
)
# ── Stream event type ────────────────────────────────────────────────
# Events yielded by run_*_stream:
# ("token", str) — text token for streaming
# ("tool_start", dict) — {"name": "task_agent", "args": {...}}
# ("tool_end", dict) — {"name": "task_agent", "result": "..."}
# ── Stream runners ────────────────────────────────────────────────────
async def _run_graph_stream(
graph,
message: str,
) -> AsyncGenerator[tuple[str, Any], None]:
"""Run a supervisor graph with streaming, yielding event tuples.
Uses ``stream_mode=["messages", "updates"]`` to get both token-level
streaming and update events for tool calls.
"""
inputs = {"messages": [HumanMessage(content=message)]}
collector: list[dict] = []
set_tool_result_collector(collector)
try:
async for stream_mode, chunk in graph.astream(
inputs,
stream_mode=["messages", "updates"],
):
if stream_mode == "messages":
msg, metadata = chunk
# Only yield tokens from the supervisor's final response
# (not from sub-agent internal LLM calls)
if (
isinstance(msg, AIMessageChunk)
and msg.content
and not msg.tool_calls
and metadata.get("langgraph_node") == "agent"
):
yield ("token", str(msg.content))
elif stream_mode == "updates":
# Updates is a dict of {node_name: state_update}
if not isinstance(chunk, dict):
continue
for node_name, state_update in chunk.items():
if node_name != "tools":
continue
# Tool node executed — extract tool call results
tool_messages = state_update.get("messages", [])
for tool_msg in tool_messages:
if hasattr(tool_msg, "name") and hasattr(tool_msg, "content"):
yield (
"tool_end",
{"name": tool_msg.name, "result": str(tool_msg.content)},
)
finally:
clear_tool_result_collector()
# Yield the collected mutations so callers can attach them to stream_end
yield ("mutations", collector)
async def run_home_stream(
user_id: str,
message: str,
context: dict[str, Any],
db_session_factory,
) -> AsyncGenerator[tuple[str, Any], None]:
"""Run the Home supervisor and yield streaming events."""
graph = build_home_graph(user_id, context, db_session_factory)
async for event in _run_graph_stream(graph, message):
yield event
async def run_floating_stream(
user_id: str,
message: str,
context: dict[str, Any],
scope: dict[str, Any],
db_session_factory,
) -> AsyncGenerator[tuple[str, Any], None]:
"""Run the Floating supervisor and yield streaming events."""
graph = build_floating_graph(user_id, context, scope, db_session_factory)
async for event in _run_graph_stream(graph, message):
yield event
async def run_home(
user_id: str,
message: str,
context: dict[str, Any],
db_session_factory,
) -> str:
"""Run the Home supervisor (non-streaming) and return full response text."""
graph = build_home_graph(user_id, context, db_session_factory)
result = await graph.ainvoke(
{"messages": [HumanMessage(content=message)]}
)
messages = result["messages"]
for msg in reversed(messages):
if hasattr(msg, "content") and msg.content and not getattr(msg, "tool_calls", None):
return str(msg.content)
return ""

View File

@@ -1,222 +0,0 @@
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
from __future__ import annotations
from collections import OrderedDict
from typing import Any
from app.schemas import ExecutionPlan, PlanStep
# ── Prompt Template Registry ──────────────────────────────────────────
class PromptTemplateRegistry:
"""Server-side store mapping template IDs to prompt text.
Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``).
The actual prompt text is resolved here on the server, keeping prompt IP
out of API responses.
"""
def __init__(self) -> None:
self._templates: dict[str, str] = {}
def register(self, template_id: str, prompt_text: str) -> None:
self._templates[template_id] = prompt_text
def get(self, template_id: str) -> str:
"""Resolve a template ID to its prompt text.
Raises ``KeyError`` if the template is not registered.
"""
text = self._templates.get(template_id)
if text is None:
raise KeyError(f"Template not found: {template_id!r}")
return text
def has(self, template_id: str) -> bool:
return template_id in self._templates
def list_ids(self) -> list[str]:
"""Return all registered template IDs (never the text)."""
return list(self._templates.keys())
# ── Execution Plan Builder ────────────────────────────────────────────
class ExecutionPlanBuilder:
"""Fluent builder for ``ExecutionPlan`` objects.
Example::
plan = (
ExecutionPlanBuilder("task_agent")
.add_llm_step("tpl_task_agent_default", {"message": user_msg})
.add_data_step("create_record", data_from_step=0)
.build()
)
"""
def __init__(self, agent: str) -> None:
self._agent = agent
self._steps: list[PlanStep] = []
# ── step adders ──────────────────────────────────────────────────
def add_step(
self, action: str, params: dict[str, Any] | None = None
) -> ExecutionPlanBuilder:
"""Append a generic action step with optional parameters."""
self._steps.append(PlanStep(action=action, variables=params))
return self
def add_llm_step(
self, template_id: str, variables: dict[str, Any] | None = None
) -> ExecutionPlanBuilder:
"""Append an LLM step referencing a server-side template by ID."""
self._steps.append(
PlanStep(action="llm", prompt_template=template_id, variables=variables)
)
return self
def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
"""Append a step whose input comes from the output of an earlier step."""
self._steps.append(PlanStep(action=action, data_from_step=data_from_step))
return self
# ── build ────────────────────────────────────────────────────────
def build(self) -> ExecutionPlan:
"""Validate step references and return the ``ExecutionPlan``.
Raises ``ValueError`` if any ``data_from_step`` references a
non-existent or future step index.
"""
for i, step in enumerate(self._steps):
if step.data_from_step is not None:
if not (0 <= step.data_from_step < i):
raise ValueError(
f"Step {i}: data_from_step={step.data_from_step} must "
f"reference a preceding step index in range 0..{i - 1}"
)
return ExecutionPlan(agent=self._agent, steps=list(self._steps))
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
class PlanCache:
"""In-memory LRU cache for ``ExecutionPlan`` objects.
Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
The cache also serves as a runtime memoisation layer so that repeated
identical intent classifications can skip re-building the plan.
"""
def __init__(self, maxsize: int = 1000) -> None:
self._maxsize = maxsize
self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()
def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
"""Store *plan* under *key*, evicting the LRU entry if at capacity."""
if key in self._cache:
del self._cache[key] # remove so re-insertion places it at the end
elif len(self._cache) >= self._maxsize:
self._cache.popitem(last=False) # evict least-recently-used
self._cache[key] = plan
def get_plan(self, key: str) -> ExecutionPlan | None:
"""Return the cached plan for *key*, or ``None`` if not present.
Accessing a plan marks it as most-recently used.
"""
if key not in self._cache:
return None
self._cache.move_to_end(key)
return self._cache[key]
def get_all_playbooks(self) -> list[ExecutionPlan]:
"""Return all cached plans (most-recently used last)."""
return list(self._cache.values())
# ── Module-level singletons ───────────────────────────────────────────
template_registry = PromptTemplateRegistry()
plan_cache = PlanCache()
def _register_builtin_templates() -> None:
"""Register the built-in server-side prompt templates.
These strings never leave the server. Clients only receive the IDs.
"""
_tpls: dict[str, str] = {
"tpl_task_agent_default": (
"You are a task management assistant. Help the user create, update, "
"list, and track tasks. Use correct status values (todo, in_progress, "
"done) and priority values (high, medium, low) from the workspace model."
),
"tpl_timeline_agent_default": (
"You are a project timeline assistant. Help the user create and manage "
"milestone timelines on their projects. Every timeline requires a "
"project_id and a date expressed as a Unix timestamp in milliseconds."
),
"tpl_project_agent_default": (
"You are a project management assistant. Help the user create, find, "
"update, and archive projects. Projects have a name, an optional client, "
"and a status of either active or archived."
),
"tpl_note_agent_default": (
"You are a note-taking assistant. Help the user create, retrieve, update, "
"and delete Markdown notes. Notes can optionally be linked to a project."
),
"tpl_task_extract_from_project": (
"Extract all actionable tasks from the provided project context. "
"Return a structured list of tasks, each with a title, inferred priority "
"(high, medium, or low), suggested status (todo), and a due_date in "
"milliseconds where a deadline can be inferred."
),
"tpl_note_weekly_summary": (
"Generate a weekly project summary note from the provided workspace data. "
"Include: tasks completed this week, tasks due soon, active projects, "
"and upcoming timelines. Format the output as clean Markdown."
),
}
for tid, text in _tpls.items():
template_registry.register(tid, text)
def _load_playbooks() -> None:
"""Pre-build and cache the built-in playbooks."""
playbooks: list[tuple[str, ExecutionPlan]] = [
(
"create_tasks_from_project",
ExecutionPlanBuilder("project_agent")
.add_llm_step(
"tpl_task_extract_from_project",
{"source": "project_context"},
)
.add_data_step("create_record", data_from_step=0)
.build(),
),
(
"generate_weekly_note",
ExecutionPlanBuilder("note_agent")
.add_llm_step(
"tpl_note_weekly_summary",
{"period": "last_7_days"},
)
.add_data_step("create_record", data_from_step=0)
.build(),
),
]
for key, plan in playbooks:
plan_cache.cache_plan(key, plan)
# Initialise on module load
_register_builtin_templates()
_load_playbooks()

View File

@@ -1,6 +1,6 @@
"""LLM factory — centralised model instantiation via LiteLLM. """LLM factory — centralised model instantiation via LiteLLM.
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()`` Every agent and the deep-agent supervisors call ``get_llm()`` or ``get_router_llm()``
instead of directly constructing a provider-specific class. The model string instead of directly constructing a provider-specific class. The model string
follows the `LiteLLM model naming convention follows the `LiteLLM model naming convention
<https://docs.litellm.ai/docs/providers>`_: <https://docs.litellm.ai/docs/providers>`_:

View File

@@ -43,7 +43,7 @@ _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
class MemoryMiddleware: class MemoryMiddleware:
"""Enrich orchestrator context with memory and persist interactions after.""" """Enrich agent context with memory and persist interactions after."""
def __init__(self, db: AsyncSession) -> None: def __init__(self, db: AsyncSession) -> None:
self._db = db self._db = db
@@ -51,7 +51,7 @@ class MemoryMiddleware:
# ── Public API ──────────────────────────────────────────────────────────── # ── Public API ────────────────────────────────────────────────────────────
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]: async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
"""Build memory context dict to inject into the orchestrator before LLM call. """Build memory context dict to inject into the agent before LLM call.
Returns a dict with keys: Returns a dict with keys:
core_memory — {key: plaintext_value, ...} core_memory — {key: plaintext_value, ...}

View File

@@ -1,210 +0,0 @@
"""Orchestrator — LLM-based intent router and agent pipeline."""
from __future__ import annotations
import json
from typing import Any, AsyncGenerator
from langchain_core.messages import HumanMessage, SystemMessage
from app.core.agent_registry import AgentRegistry, ChatAgent
from app.core.llm import get_router_llm
from app.core.agent_registry import registry as _default_registry
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
_FALLBACK_AGENT = "task_agent"
_CLASSIFY_SYSTEM = (
"You are an intent classifier. Given the user message and context, decide "
"which agent to route to.\n"
"Available agents: {agents}\n"
"Respond with just the agent name, nothing else."
)
_SYNTHESIZE_HUMAN = (
"Combine the following agent results into one coherent response.\n\n"
"Agent results:\n{results}\n\n"
"Original message: {message}"
)
def _make_llm():
return get_router_llm()
async def classify_intent(
message: str,
context: dict[str, Any],
reg: AgentRegistry,
) -> str:
"""Use gpt-4o-mini to classify intent and return the matching agent name.
Falls back to ``task_agent`` when the registry is empty or the model
returns a name that is not registered.
"""
agents = reg.list_agents()
if not agents:
return _FALLBACK_AGENT
system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
# Truncate context to keep the classification prompt short
human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}"
llm = _make_llm()
response = await llm.ainvoke(
[SystemMessage(content=system), HumanMessage(content=human)]
)
agent_name = str(response.content).strip().lower()
known = {a["name"] for a in agents}
return agent_name if agent_name in known else _FALLBACK_AGENT
async def route_single(
agent_name: str,
message: str,
context: dict[str, Any],
reg: AgentRegistry,
) -> ChatResponse:
"""Route to a single agent and wrap the result in a ``ChatResponse``."""
response_text = await reg.call_agent(agent_name, message, context)
return ChatResponse(response=response_text)
async def route_pipeline(
agent_names: list[str],
message: str,
context: dict[str, Any],
reg: AgentRegistry,
) -> ChatResponse:
"""Execute agents sequentially; each agent receives previous results in context.
A final LLM synthesis call merges all results into one coherent response.
"""
previous_results: list[str] = []
for agent_name in agent_names:
ctx = {**context, "previous_results": list(previous_results)}
result = await reg.call_agent(agent_name, message, ctx)
previous_results.append(result)
results_str = "\n\n".join(
f"[{name}]: {res}" for name, res in zip(agent_names, previous_results)
)
human = _SYNTHESIZE_HUMAN.format(results=results_str, message=message)
llm = _make_llm()
synthesis = await llm.ainvoke([HumanMessage(content=human)])
return ChatResponse(response=str(synthesis.content))
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
"""Build an ``ExecutionPlan`` for the resolved agent.
Uses ``ExecutionPlanBuilder`` with the server-side template registry.
If a default template exists for the agent, an LLM step is emitted;
otherwise a plain ``handle`` action step is used.
"""
from app.core.execution_plan import ExecutionPlanBuilder, template_registry
template_id = f"tpl_{agent_name}_default"
builder = ExecutionPlanBuilder(agent_name)
if template_registry.has(template_id):
builder.add_llm_step(template_id, {"message": message})
else:
builder.add_step("handle", {"message": message})
return builder.build()
async def orchestrate(
request: ChatRequest,
reg: AgentRegistry | None = None,
) -> ChatResponse | ExecutionPlan:
"""Main orchestration entry point.
* Classifies the user's intent to select an agent.
* ``execution_mode == 'direct'``: routes to the agent and returns a
``ChatResponse``.
* ``execution_mode == 'plan'``: returns an ``ExecutionPlan`` with the
resolved agent and a template-ID-only step (prompt IP stays server-side).
"""
if reg is None:
reg = _default_registry
context = request.context.model_dump()
agent_name = await classify_intent(request.message, context, reg)
if request.execution_mode == "direct":
return await route_single(agent_name, request.message, context, reg)
# plan mode — return plan, do not execute
return _build_plan(agent_name, request.message)
async def orchestrate_v3(
user_id: str,
message: str,
context: dict[str, Any],
reg: AgentRegistry | None = None,
) -> tuple[str, ChatAgent]:
"""v3 orchestration — returns (agent_name, agent_instance); caller drives execution.
Classifies intent and instantiates the matching agent. The caller is responsible
for invoking handle(), handle_stream(), or _tool_loop_stream() as needed.
"""
if reg is None:
reg = _default_registry
agent_name = await classify_intent(message, context, reg)
return agent_name, reg.get(agent_name)
async def orchestrate_v3_stream(
user_id: str,
message: str,
context: dict[str, Any],
reg: AgentRegistry | None = None,
agent_holder: list | None = None,
) -> AsyncGenerator[tuple[str, str], None]:
"""v3 streaming orchestration — yields (agent_name, token) pairs.
The first yield always carries the agent_name with an empty token so that
callers (e.g. FloatingFormatter) can detect the routing domain before any text
tokens arrive.
If *agent_holder* is provided (a list), the agent instance is appended so
callers can access ``agent.tool_results`` after the stream completes.
"""
if reg is None:
reg = _default_registry
agent_name = await classify_intent(message, context, reg)
agent = reg.get(agent_name)
if agent_holder is not None:
agent_holder.append(agent)
yield agent_name, "" # domain signal — no token yet
async for token in agent.handle_stream(message, context):
yield agent_name, token
async def orchestrate_stream(
request: ChatRequest,
reg: AgentRegistry | None = None,
) -> AsyncGenerator[str, None]:
"""Streaming orchestration — yields plain text chunks only.
The WebSocket handler in ``app/api/routes/chat.py`` is responsible for
wrapping each chunk in a ``text_chunk`` frame and sending the final
``final`` frame once the generator is exhausted.
Agents do not yet support token-level streaming; the full response is
fetched first (which may involve multiple WS round-trips for tool calls),
then emitted in fixed-size chunks.
"""
if reg is None:
reg = _default_registry
context = request.context.model_dump()
agent_name = await classify_intent(request.message, context, reg)
response_text = await reg.call_agent(agent_name, request.message, context)
chunk_size = 50
for i in range(0, len(response_text), chunk_size):
yield response_text[i : i + chunk_size]

View File

@@ -1,12 +1,23 @@
"""Output Formatter — transforms orchestrator token streams into WS frame sequences. """Output Formatter — transforms deep-agent event streams into WS frame sequences.
HomeFormatter: produces stream_start, stream_text / stream_block, stream_end Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
FloatingFormatter: produces floating_domain, stream_text, stream_end * ``("token", str)`` — supervisor text token
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
HomeFormatter:
* Sniffs ``tool_end`` events → emits ``WsStreamBlock`` (entity_ref with raw data)
* Streams text tokens → emits ``WsStreamText``
* Attaches mutations → injects into ``WsStreamEnd``
FloatingFormatter:
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
* Streams text tokens → emits ``WsStreamText``
* Attaches mutations → injects into ``WsStreamEnd``
""" """
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any from typing import Any
@@ -21,10 +32,7 @@ from app.schemas import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Valid chart types (matching shadcn/ui Recharts wrappers in Electron) # Map sub-agent tool name → floating domain / entity type
_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}
# Map agent name → floating domain
_AGENT_DOMAIN: dict[str, str] = { _AGENT_DOMAIN: dict[str, str] = {
"task_agent": "tasks", "task_agent": "tasks",
"timeline_agent": "timelines", "timeline_agent": "timelines",
@@ -36,180 +44,74 @@ WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatin
class HomeFormatter: class HomeFormatter:
"""Parses a token stream from orchestrate_v3_stream and yields WS frames. """Consumes a deep-agent event stream and yields WS frames for the Home view.
The LLM is expected to output a newline-delimited sequence of JSON objects, ``tool_end`` events from sub-agents are emitted as ``WsStreamBlock``
each with a ``type`` field: (entity_ref) so the client can render structured data. Text tokens are
- ``text`` → yields WsStreamText immediately (word-by-word) forwarded as ``WsStreamText``. Mutations are attached to ``WsStreamEnd``.
- ``chart`` → buffers full JSON, validates, yields WsStreamBlock
- ``entity_ref`` → resolves from tool_results, yields WsStreamBlock
- ``table`` → buffers full JSON, validates, yields WsStreamBlock
- ``timeline`` → buffers full JSON, validates, yields WsStreamBlock
Invalid or unknown blocks are logged and skipped — stream never crashes.
"""
def __init__(self, request_id: str, tool_results: list[dict]) -> None:
self.request_id = request_id
self.tool_results = tool_results
async def format(
self,
token_stream: AsyncGenerator[tuple[str, str], None],
) -> AsyncGenerator[WsFrame, None]:
yield WsStreamStart(request_id=self.request_id)
buffer = ""
async for _agent_name, token in token_stream:
if not token:
continue
buffer += token
# Flush any complete JSON objects from the buffer
async for frame in self._flush_complete_objects(buffer):
buffer = "" # reset after flush
yield frame
break # only one flush per iteration; rest accumulates
# Flush any remaining content
if buffer.strip():
async for frame in self._flush_complete_objects(buffer, final=True):
yield frame
yield WsStreamEnd(request_id=self.request_id)
async def _flush_complete_objects(
self, text: str, final: bool = False
) -> AsyncGenerator[WsFrame, None]:
"""Try to parse and yield all complete JSON objects from *text*.
Yields nothing if text is incomplete JSON (unless *final* is True,
in which case remaining text is emitted as plain stream_text).
"""
remaining = text.strip()
while remaining:
# Fast path: plain text (not JSON)
if not remaining.startswith("{"):
# Yield as plain text chunk
newline_idx = remaining.find("\n")
if newline_idx == -1:
if final:
yield WsStreamText(request_id=self.request_id, chunk=remaining)
remaining = ""
else:
return # accumulate more
else:
line = remaining[:newline_idx].strip()
remaining = remaining[newline_idx + 1:].strip()
if line:
yield WsStreamText(request_id=self.request_id, chunk=line)
continue
# Try to decode a JSON object
try:
obj, end_idx = _try_parse_json(remaining)
except ValueError:
if final:
# Emit as raw text if we can't parse
yield WsStreamText(request_id=self.request_id, chunk=remaining)
remaining = ""
return
if obj is None:
if final:
yield WsStreamText(request_id=self.request_id, chunk=remaining)
remaining = ""
return # incomplete — need more tokens
remaining = remaining[end_idx:].strip()
block_type = obj.get("type")
frame = self._dispatch_block(obj, block_type)
if frame is not None:
yield frame
def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
if block_type == "text":
content = obj.get("content", "")
if content:
return WsStreamText(request_id=self.request_id, chunk=str(content))
return None
if block_type == "chart":
chart_type = obj.get("chartType")
if chart_type not in _VALID_CHART_TYPES:
logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
return None
if not isinstance(obj.get("data"), list):
logger.warning("HomeFormatter: chart missing data array — skipping")
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="chart",
data=obj,
)
if block_type == "entity_ref":
entity = obj.get("entity")
resolved = self._resolve_entity(entity)
if resolved is None:
logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity)
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="entity_ref",
data={"entity": entity, "items": resolved},
)
if block_type == "table":
if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
logger.warning("HomeFormatter: table missing headers/rows — skipping")
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="table",
data=obj,
)
if block_type == "timeline":
if not isinstance(obj.get("timelines"), list):
logger.warning("HomeFormatter: timeline missing timelines — skipping")
return None
return WsStreamBlock(
request_id=self.request_id,
block_type="timeline",
data=obj,
)
logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
return None
def _resolve_entity(self, entity: str | None) -> list[dict] | None:
"""Find matching items in tool_results by entity type."""
if not entity:
return None
matches = [r for r in self.tool_results if r.get("entity") == entity]
return matches if matches else None
class FloatingFormatter:
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
Emits floating_domain immediately (from agent_name), then streams all tokens
as plain stream_text — no block parsing for floating context.
""" """
def __init__(self, request_id: str) -> None: def __init__(self, request_id: str) -> None:
self.request_id = request_id self.request_id = request_id
self._mutations: list[dict] = []
async def format( async def format(
self, self,
token_stream: AsyncGenerator[tuple[str, str], None], event_stream: AsyncGenerator[tuple[str, Any], None],
) -> AsyncGenerator[WsFrame, None]:
yield WsStreamStart(request_id=self.request_id)
async for event_type, data in event_stream:
if event_type == "token":
if data:
yield WsStreamText(request_id=self.request_id, chunk=data)
elif event_type == "tool_end":
# Sub-agent finished — emit its result as an entity_ref block
name = data.get("name", "")
entity = _AGENT_DOMAIN.get(name)
if entity:
yield WsStreamBlock(
request_id=self.request_id,
block_type="entity_ref",
data={"entity": entity, "result": data.get("result", "")},
)
elif event_type == "mutations":
self._mutations = data or []
yield WsStreamEnd(
request_id=self.request_id,
mutations=[
{"action": m["action"], "table": m["table"], "data": m["data"]}
for m in self._mutations
],
)
class FloatingFormatter:
"""Consumes a deep-agent event stream and yields WS frames for the Floating view.
Sniffs the first ``tool_end`` event name to derive the domain (e.g.
``task_agent`` → ``"tasks"``), then streams text tokens as plain
``WsStreamText``. No block parsing for floating context.
"""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._mutations: list[dict] = []
async def format(
self,
event_stream: AsyncGenerator[tuple[str, Any], None],
) -> AsyncGenerator[WsFrame, None]: ) -> AsyncGenerator[WsFrame, None]:
domain_sent = False domain_sent = False
async for agent_name, token in token_stream: async for event_type, data in event_stream:
if not domain_sent: if event_type == "tool_end" and not domain_sent:
domain = _AGENT_DOMAIN.get(agent_name, "tasks") # Sniff domain from the first sub-agent that completes
name = data.get("name", "")
domain = _AGENT_DOMAIN.get(name, "tasks")
yield WsFloatingDomain( yield WsFloatingDomain(
request_id=self.request_id, request_id=self.request_id,
domain=domain, # type: ignore[arg-type] domain=domain, # type: ignore[arg-type]
@@ -217,28 +119,33 @@ class FloatingFormatter:
yield WsStreamStart(request_id=self.request_id) yield WsStreamStart(request_id=self.request_id)
domain_sent = True domain_sent = True
if token: elif event_type == "token":
yield WsStreamText(request_id=self.request_id, chunk=token) if not domain_sent:
# First token arrived before any tool_end — default domain
yield WsFloatingDomain(
request_id=self.request_id,
domain="tasks", # type: ignore[arg-type]
)
yield WsStreamStart(request_id=self.request_id)
domain_sent = True
if data:
yield WsStreamText(request_id=self.request_id, chunk=data)
yield WsStreamEnd(request_id=self.request_id) elif event_type == "mutations":
self._mutations = data or []
# If no events triggered domain_sent (edge case), still emit structure
if not domain_sent:
yield WsFloatingDomain(
request_id=self.request_id,
domain="tasks", # type: ignore[arg-type]
)
yield WsStreamStart(request_id=self.request_id)
# ── helpers ─────────────────────────────────────────────────────────────────── yield WsStreamEnd(
request_id=self.request_id,
def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]: mutations=[
"""Attempt to parse the first complete JSON object from *text*. {"action": m["action"], "table": m["table"], "data": m["data"]}
for m in self._mutations
Returns ``(parsed_dict, end_index)`` on success, ``(None, 0)`` when the ],
object is incomplete, and raises ``ValueError`` when text is not JSON. )
"""
decoder = json.JSONDecoder()
try:
obj, end_idx = decoder.raw_decode(text)
if not isinstance(obj, dict):
raise ValueError("Expected JSON object")
return obj, end_idx
except json.JSONDecodeError as exc:
# Incomplete JSON — need more tokens
if "Unterminated" in str(exc) or exc.pos == len(text):
return None, 0
raise ValueError(str(exc)) from exc

View File

@@ -7,18 +7,21 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
from __future__ import annotations from __future__ import annotations
import logging
from contextvars import ContextVar from contextvars import ContextVar
from typing import Any, Callable, Coroutine from typing import Any, Callable, Coroutine
from uuid import uuid4 from uuid import uuid4
logger = logging.getLogger(__name__)
# Holds the execute callback for the current WS session. # Holds the execute callback for the current WS session.
# Set by the chat WS handler before the orchestrator runs; cleared after. # Set by the chat WS handler before the deep agent runs; cleared after.
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar( _client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
"_client_executor" "_client_executor"
) )
# Optional collector that captures raw execute_on_client results. # Optional collector that captures raw execute_on_client results.
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results. # Set by the deep agent tool loop to capture CRUD mutations.
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar( _tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
"_tool_result_collector", default=None "_tool_result_collector", default=None
) )
@@ -81,7 +84,12 @@ async def execute_on_client(
if limit is not None: if limit is not None:
payload["limit"] = limit payload["limit"] = limit
logger.info("execute_on_client: sending payload action=%s table=%s id=%s", action, table, payload["id"])
result = await callback(payload) result = await callback(payload)
if result is None:
logger.error("execute_on_client: callback returned None for action=%s table=%s id=%s", action, table, payload["id"])
else:
logger.info("execute_on_client: got result type=%s keys=%s", type(result).__name__, list(result.keys()) if isinstance(result, dict) else "N/A")
collector = _tool_result_collector.get(None) collector = _tool_result_collector.get(None)
if collector is not None: if collector is not None:
collector.append({ collector.append({

View File

@@ -18,10 +18,7 @@ from app.config.settings import settings
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Startup: initialise DB connection pool and agent registry # Startup: initialise DB connection pool
from app.core.agent_registry import registry # noqa: F401 — triggers module load
import app.agents # noqa: F401 — triggers @registry.register decorators
yield yield
# Shutdown: dispose SQLAlchemy connection pool # Shutdown: dispose SQLAlchemy connection pool
@@ -51,11 +48,10 @@ def create_app() -> FastAPI:
app.add_middleware(SanitizerMiddleware) app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware) app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plans, plugins, storage, vectors from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
app.include_router(auth.router, prefix="/api/v1") app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1") app.include_router(chat.router, prefix="/api/v1")
app.include_router(plans.router, prefix="/api/v1")
app.include_router(storage.router, prefix="/api/v1") app.include_router(storage.router, prefix="/api/v1")
app.include_router(vectors.router, prefix="/api/v1") app.include_router(vectors.router, prefix="/api/v1")
app.include_router(backup.router, prefix="/api/v1") app.include_router(backup.router, prefix="/api/v1")

View File

@@ -41,41 +41,13 @@ class ChatContext(BaseModel):
conversation_history: list[dict[str, Any]] = Field(default_factory=list) conversation_history: list[dict[str, Any]] = Field(default_factory=list)
class PlanAction(BaseModel):
type: Literal[
"create_record",
"update_record",
"delete_record",
"index_document",
"send_notification",
]
table: str | None = None
data: dict[str, Any] | None = None
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
message: str message: str
context: ChatContext = Field(default_factory=ChatContext) context: ChatContext = Field(default_factory=ChatContext)
execution_mode: Literal["direct", "plan"] = "direct"
class ChatResponse(BaseModel): class ChatResponse(BaseModel):
response: str response: str
actions: list[PlanAction] = Field(default_factory=list)
# ── Execution Plans ──────────────────────────────────────────────────
class PlanStep(BaseModel):
action: str
prompt_template: str | None = None
variables: dict[str, Any] | None = None
data_from_step: int | None = None
class ExecutionPlan(BaseModel):
agent: str
steps: list[PlanStep] = Field(default_factory=list)
# ── Backup ─────────────────────────────────────────────────────────── # ── Backup ───────────────────────────────────────────────────────────

View File

@@ -4,6 +4,7 @@ gunicorn>=22.0.0
langchain>=0.3.0 langchain>=0.3.0
langchain-openai>=0.3.0 langchain-openai>=0.3.0
langchain-litellm>=0.1.0 langchain-litellm>=0.1.0
langgraph>=0.3.0
litellm>=1.50.0 litellm>=1.50.0
pydantic>=2.10.0 pydantic>=2.10.0
pydantic-settings>=2.7.0 pydantic-settings>=2.7.0

View File

@@ -1,214 +0,0 @@
"""Unit tests for the agent registry, base classes, and tool loop."""
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.core.agent_registry import AgentRegistry, ChatAgent
# ── Helpers ──────────────────────────────────────────────────────────
class _StubAgent(ChatAgent):
"""Minimal concrete agent for testing."""
def get_name(self) -> str:
return "stub"
def get_description(self) -> str:
return "A stub agent for tests"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return f"echo: {query}"
class _AnotherAgent(ChatAgent):
def get_name(self) -> str:
return "another"
def get_description(self) -> str:
return "Another stub"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return "another"
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _fresh_registry():
"""Reset the singleton between tests."""
AgentRegistry._instance = None
yield
AgentRegistry._instance = None
@pytest.fixture()
def reg() -> AgentRegistry:
return AgentRegistry()
# ── Tests ────────────────────────────────────────────────────────────
class TestRegisterAndGet:
def test_register_decorator(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
agent = reg.get("stub")
assert isinstance(agent, _StubAgent)
def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
with pytest.raises(KeyError, match="not found"):
reg.get("nonexistent")
def test_register_multiple(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
reg.register(_AnotherAgent)
assert reg.get("stub").get_name() == "stub"
assert reg.get("another").get_name() == "another"
class TestListAgents:
def test_empty(self, reg: AgentRegistry) -> None:
assert reg.list_agents() == []
def test_list_after_register(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
agents = reg.list_agents()
assert len(agents) == 1
assert agents[0] == {"name": "stub", "description": "A stub agent for tests"}
def test_list_multiple(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
reg.register(_AnotherAgent)
names = {a["name"] for a in reg.list_agents()}
assert names == {"stub", "another"}
class TestCallAgent:
@pytest.mark.asyncio
async def test_call_agent(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
result = await reg.call_agent("stub", "hello", {})
assert result == "echo: hello"
@pytest.mark.asyncio
async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
with pytest.raises(KeyError):
await reg.call_agent("nope", "hi", {})
class TestSingleton:
def test_singleton_identity(self) -> None:
a = AgentRegistry()
b = AgentRegistry()
assert a is b
class TestToolLoop:
@pytest.mark.asyncio
async def test_no_tool_calls(self) -> None:
"""When the LLM responds without tool calls, return content directly."""
agent = _StubAgent()
ai_msg = MagicMock()
ai_msg.content = "final answer"
ai_msg.tool_calls = []
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm)
llm.ainvoke = AsyncMock(return_value=ai_msg)
result = await agent._tool_loop(llm, [], [])
assert result == "final answer"
@pytest.mark.asyncio
async def test_tool_call_then_answer(self) -> None:
"""LLM requests one tool call, gets result, then answers."""
agent = _StubAgent()
# First response: tool call
tool_call_msg = MagicMock()
tool_call_msg.content = ""
tool_call_msg.tool_calls = [
{"id": "call_1", "name": "my_tool", "args": {"x": 1}}
]
# Second response: final answer
final_msg = MagicMock()
final_msg.content = "done"
final_msg.tool_calls = []
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm)
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
# Mock tool
tool = AsyncMock()
tool.name = "my_tool"
tool.ainvoke = AsyncMock(return_value="tool_result")
result = await agent._tool_loop(llm, [], [tool])
assert result == "done"
tool.ainvoke.assert_called_once_with({"x": 1})
@pytest.mark.asyncio
async def test_unknown_tool_handled(self) -> None:
"""Unknown tool names produce an error message instead of crashing."""
agent = _StubAgent()
tool_call_msg = MagicMock()
tool_call_msg.content = ""
tool_call_msg.tool_calls = [
{"id": "call_1", "name": "missing", "args": {}}
]
final_msg = MagicMock()
final_msg.content = "recovered"
final_msg.tool_calls = []
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm)
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
result = await agent._tool_loop(llm, [], [])
assert result == "recovered"
@pytest.mark.asyncio
async def test_max_iter_reached(self) -> None:
"""When max iterations are exhausted, a final no-tools call is made."""
agent = _StubAgent()
# Every response requests a tool call
loop_msg = MagicMock()
loop_msg.content = ""
loop_msg.tool_calls = [
{"id": "call_x", "name": "t", "args": {}}
]
final_msg = MagicMock()
final_msg.content = "gave up"
final_msg.tool_calls = []
tool = AsyncMock()
tool.name = "t"
tool.ainvoke = AsyncMock(return_value="ok")
llm_with_tools = AsyncMock()
llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg)
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm.ainvoke = AsyncMock(return_value=final_msg)
result = await agent._tool_loop(llm, [], [tool], max_iter=2)
assert result == "gave up"
assert llm_with_tools.ainvoke.call_count == 2

View File

@@ -1,416 +0,0 @@
"""Tests for ChatAgent streaming and tool result capture (Step 2)."""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from app.core.agent_registry import ChatAgent, registry
# ── Minimal concrete agent for testing ───────────────────────────────
class _EchoAgent(ChatAgent):
def get_name(self) -> str:
return "_echo"
def get_description(self) -> str:
return "Echo agent for tests"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return query
# ── Helpers ───────────────────────────────────────────────────────────
def _make_ai_message(content: str = "", tool_calls: list | None = None) -> AIMessage:
msg = AIMessage(content=content)
if tool_calls:
msg.tool_calls = tool_calls
else:
msg.tool_calls = []
return msg
def _make_tool(name: str, return_value: Any) -> MagicMock:
t = MagicMock()
t.name = name
t.ainvoke = AsyncMock(return_value=return_value)
return t
def _make_stream_chunks(tokens: list[str]) -> list[MagicMock]:
chunks = []
for tok in tokens:
c = MagicMock()
c.content = tok
chunks.append(c)
return chunks
async def _collect_stream(agent: ChatAgent, llm: Any, messages: list, tools: list) -> list[str]:
tokens: list[str] = []
async for tok in agent._tool_loop_stream(llm, messages, tools):
tokens.append(tok)
return tokens
# ── tool_results initialised ─────────────────────────────────────────
def test_tool_results_init():
agent = _EchoAgent()
assert agent.tool_results == []
# ── _tool_loop: no tool calls ────────────────────────────────────────
@pytest.mark.asyncio
async def test_tool_loop_no_tools():
agent = _EchoAgent()
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Hello!"))
result = await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
assert result == "Hello!"
assert agent.tool_results == []
# ── _tool_loop: with one tool call + result capture ──────────────────
@pytest.mark.asyncio
async def test_tool_loop_captures_tool_results():
agent = _EchoAgent()
# Mock execute_on_client to return structured data via the tool
raw_result = {"rows": [{"id": "t-1", "title": "Fix bug", "status": "todo"}]}
async def fake_executor(payload: dict) -> dict:
return raw_result
# AIMessage with a tool call, then a final answer
tool_call_msg = _make_ai_message(
tool_calls=[{"name": "list_tasks", "args": {}, "id": "call-1", "type": "tool_call"}]
)
final_msg = _make_ai_message("Here are your tasks.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
llm.ainvoke = AsyncMock(return_value=final_msg)
mock_tool = _make_tool("list_tasks", "- Fix bug (todo)")
from app.core.ws_context import set_client_executor, clear_client_executor
set_client_executor(fake_executor)
try:
# Patch the tool to actually call execute_on_client
async def tool_side_effect(args: dict) -> str:
from app.core.ws_context import execute_on_client
res = await execute_on_client(action="select", table="tasks")
rows = res.get("rows", [])
return "\n".join(r["title"] for r in rows)
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
result = await agent._tool_loop(
llm, [HumanMessage(content="list my tasks")], [mock_tool]
)
finally:
clear_client_executor()
assert result == "Here are your tasks."
assert len(agent.tool_results) == 1
assert agent.tool_results[0] == raw_result
# ── _tool_loop: tool_results reset on each call ──────────────────────
@pytest.mark.asyncio
async def test_tool_loop_resets_tool_results():
agent = _EchoAgent()
agent.tool_results = [{"stale": True}] # pre-populated from a previous call
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Done."))
await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
assert agent.tool_results == []
# ── _tool_loop: unknown tool name ────────────────────────────────────
@pytest.mark.asyncio
async def test_tool_loop_unknown_tool():
agent = _EchoAgent()
# No known tools — model still calls a non-existent one; loop handles gracefully
tool_call_msg = _make_ai_message(
tool_calls=[{"name": "nonexistent", "args": {}, "id": "c1", "type": "tool_call"}]
)
final_msg = _make_ai_message("Handled.")
mock_tool = _make_tool("known", "ok") # a different tool, not "nonexistent"
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool])
assert result == "Handled."
# ── _tool_loop: max_iter exhaustion ──────────────────────────────────
@pytest.mark.asyncio
async def test_tool_loop_max_iter():
agent = _EchoAgent()
always_tool = _make_ai_message(
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
)
fallback = _make_ai_message("Fallback.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
# Returns tool_call_msg on every iteration
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
llm.ainvoke = AsyncMock(return_value=fallback)
mock_tool = _make_tool("t", "ok")
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool], max_iter=2)
assert result == "Fallback."
assert llm_with_tools.ainvoke.call_count == 2
# ── _tool_loop_stream: no tool calls — yields tokens ─────────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_no_tools_yields_tokens():
agent = _EchoAgent()
# No tools → llm used directly; ainvoke returns no tool calls → stream is used
no_tool_msg = _make_ai_message("irrelevant")
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
async def fake_astream(msgs):
for tok in ["Hello", " ", "world"]:
c = MagicMock()
c.content = tok
yield c
llm.astream = fake_astream
tokens = await _collect_stream(agent, llm, [HumanMessage(content="hi")], [])
assert tokens == ["Hello", " ", "world"]
assert agent.tool_results == []
# ── _tool_loop_stream: one tool call then streaming final ─────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_with_tool_call():
agent = _EchoAgent()
raw_result = {"row": {"id": "t-2", "title": "Deploy", "status": "in_progress"}}
async def fake_executor(payload: dict) -> dict:
return raw_result
tool_call_msg = _make_ai_message(
tool_calls=[{"name": "get_task", "args": {"id": "t-2"}, "id": "c1", "type": "tool_call"}]
)
# After tools run, ainvoke returns no more tool calls
no_more_tools_msg = _make_ai_message("Task found.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
async def fake_astream(msgs):
for tok in ["Task", " ", "found."]:
c = MagicMock()
c.content = tok
yield c
llm.astream = fake_astream
async def tool_side_effect(args: dict) -> str:
from app.core.ws_context import execute_on_client
res = await execute_on_client(action="select", table="tasks", filters={"id": args.get("id")})
return res.get("row", {}).get("title", "")
mock_tool = _make_tool("get_task", "Deploy")
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
from app.core.ws_context import set_client_executor, clear_client_executor
set_client_executor(fake_executor)
try:
tokens = await _collect_stream(
agent, llm, [HumanMessage(content="get task t-2")], [mock_tool]
)
finally:
clear_client_executor()
assert tokens == ["Task", " ", "found."]
assert len(agent.tool_results) == 1
assert agent.tool_results[0] == raw_result
# ── _tool_loop_stream: tool_results reset on each call ───────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_resets_tool_results():
agent = _EchoAgent()
agent.tool_results = [{"old": True}]
no_tool_msg = _make_ai_message("")
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
async def fake_astream(msgs):
c = MagicMock()
c.content = "ok"
yield c
llm.astream = fake_astream
await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
assert agent.tool_results == []
# ── _tool_loop_stream: empty chunk content is skipped ────────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_skips_empty_chunks():
agent = _EchoAgent()
no_tool_msg = _make_ai_message("")
llm = AsyncMock()
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
async def fake_astream(msgs):
for tok in ["", "hello", "", " world", ""]:
c = MagicMock()
c.content = tok
yield c
llm.astream = fake_astream
tokens = await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
assert tokens == ["hello", " world"]
# ── _tool_loop_stream: max_iter exhaustion falls back to stream ───────
@pytest.mark.asyncio
async def test_tool_loop_stream_max_iter():
agent = _EchoAgent()
always_tool = _make_ai_message(
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
)
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
async def fake_astream(msgs):
c = MagicMock()
c.content = "fallback"
yield c
llm.astream = fake_astream
mock_tool = _make_tool("t", "ok")
tokens = await _collect_stream(
agent, llm, [HumanMessage(content="x")], [mock_tool],
)
assert tokens == ["fallback"]
assert llm_with_tools.ainvoke.call_count == 5 # exhausted default max_iter
# ── _tool_loop_stream: multiple tool results captured ────────────────
@pytest.mark.asyncio
async def test_tool_loop_stream_multiple_tool_results():
agent = _EchoAgent()
call_results = [
{"rows": [{"id": "t-1"}]},
{"rows": [{"id": "t-2"}]},
]
call_iter = iter(call_results)
async def fake_executor(payload: dict) -> dict:
return next(call_iter)
# Two tool calls in one iteration
tool_call_msg = _make_ai_message(
tool_calls=[
{"name": "tool_a", "args": {}, "id": "c1", "type": "tool_call"},
{"name": "tool_b", "args": {}, "id": "c2", "type": "tool_call"},
]
)
no_more_tools_msg = _make_ai_message("Done.")
llm = MagicMock()
llm_with_tools = MagicMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
async def fake_astream(msgs):
c = MagicMock()
c.content = "Done."
yield c
llm.astream = fake_astream
async def tool_side_effect(args: dict) -> str:
from app.core.ws_context import execute_on_client
res = await execute_on_client(action="select", table="tasks")
return str(res)
tool_a = _make_tool("tool_a", "")
tool_a.ainvoke = AsyncMock(side_effect=tool_side_effect)
tool_b = _make_tool("tool_b", "")
tool_b.ainvoke = AsyncMock(side_effect=tool_side_effect)
from app.core.ws_context import set_client_executor, clear_client_executor
set_client_executor(fake_executor)
try:
tokens = await _collect_stream(
agent, llm, [HumanMessage(content="x")], [tool_a, tool_b]
)
finally:
clear_client_executor()
assert tokens == ["Done."]
assert len(agent.tool_results) == 2
assert agent.tool_results[0] == {"rows": [{"id": "t-1"}]}
assert agent.tool_results[1] == {"rows": [{"id": "t-2"}]}

View File

@@ -1,761 +0,0 @@
"""Unit tests for the four domain-specific chat agents with mocked LLM."""
from __future__ import annotations
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import app.agents # noqa: F401 — triggers @registry.register decorators
from app.agents.timeline_agent import TimelineAgent
from app.agents.note_agent import NoteAgent
from app.agents.project_agent import ProjectAgent
from app.agents.task_agent import TaskAgent
from app.core.agent_registry import registry
from app.core.ws_context import clear_client_executor, set_client_executor
# ── WS executor mock ──────────────────────────────────────────────────
#
# Tools call execute_on_client() which reads a ContextVar set by the WS
# handler. In unit tests there is no WS session, so we install a fake
# executor that returns plausible data for each action type.
_FAKE_ROW: dict[str, Any] = {
"id": "fake-id",
"title": "Fake Title",
"name": "Fake Name",
"status": "todo",
"priority": "medium",
"content": "Fake content",
"date": 1700000000000,
"taskId": "fake-task-id",
"author": "Alice",
"projectId": None,
}
async def _fake_executor(payload: dict) -> dict:
action = payload.get("action", "")
if action == "select":
return {"rows": []}
if action == "insert":
data = payload.get("data", {})
return {"row": {**_FAKE_ROW, **data}}
if action == "update":
data = payload.get("data", {})
row = {**_FAKE_ROW, "id": data.get("id", "fake-id"), **data.get("updates", {})}
return {"row": row}
if action == "delete":
return {"deleted": True}
if action == "get":
data = payload.get("data", {})
return {"row": {**_FAKE_ROW, "id": data.get("id", "fake-id")}}
if action == "vector_upsert":
return {"ok": True}
return {}
@pytest.fixture(autouse=True)
def ws_executor():
"""Install a fake WS executor for every test so tools can run without a real WS."""
set_client_executor(_fake_executor)
yield
clear_client_executor()
# ── Helpers ──────────────────────────────────────────────────────────
def _mock_llm(response_text: str) -> MagicMock:
"""Return a mock LLM that responds with *response_text* (no tool calls)."""
msg = MagicMock()
msg.content = response_text
msg.tool_calls = []
llm = MagicMock()
bound = MagicMock()
bound.ainvoke = AsyncMock(return_value=msg)
llm.bind_tools = MagicMock(return_value=bound)
llm.ainvoke = AsyncMock(return_value=msg)
return llm
def _mock_llm_with_tool_call(
tool_name: str, tool_args: dict[str, Any], final_text: str
) -> MagicMock:
"""Mock LLM that fires one tool call then returns *final_text*."""
tool_msg = MagicMock()
tool_msg.content = ""
tool_msg.tool_calls = [{"id": "call_1", "name": tool_name, "args": tool_args}]
final_msg = MagicMock()
final_msg.content = final_text
final_msg.tool_calls = []
bound = MagicMock()
bound.ainvoke = AsyncMock(side_effect=[tool_msg, final_msg])
llm = MagicMock()
llm.bind_tools = MagicMock(return_value=bound)
llm.ainvoke = AsyncMock(return_value=final_msg)
return llm
# ── Registration ──────────────────────────────────────────────────────
class TestAgentRegistration:
def test_all_agents_registered(self) -> None:
names = {a["name"] for a in registry.list_agents()}
assert {
"task_agent", "timeline_agent", "project_agent", "note_agent"
}.issubset(names)
def test_registry_returns_correct_types(self) -> None:
assert isinstance(registry.get("task_agent"), TaskAgent)
assert isinstance(registry.get("timeline_agent"), TimelineAgent)
assert isinstance(registry.get("project_agent"), ProjectAgent)
assert isinstance(registry.get("note_agent"), NoteAgent)
def test_descriptions_present(self) -> None:
for agent_info in registry.list_agents():
assert agent_info["description"], f"Empty description: {agent_info['name']}"
# ── TaskAgent ─────────────────────────────────────────────────────────
class TestTaskAgent:
def test_name(self) -> None:
assert TaskAgent().get_name() == "task_agent"
def test_description(self) -> None:
assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments"
def test_get_tools_count(self) -> None:
assert len(TaskAgent().get_tools()) == 8
def test_tool_names(self) -> None:
names = {t.name for t in TaskAgent().get_tools()}
assert names == {
"list_tasks",
"create_task",
"update_task",
"delete_task",
"list_tasks_due_today",
"list_task_comments",
"add_task_comment",
"delete_task_comment",
}
@pytest.mark.asyncio
async def test_handle_returns_string(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Task created.")
result = await TaskAgent().handle("create a task", {})
assert isinstance(result, str)
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Here are your tasks.")
result = await TaskAgent().handle("list my tasks", {})
assert result == "Here are your tasks."
@pytest.mark.asyncio
async def test_handle_with_create_task_tool_call(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_task",
{"title": "Buy groceries", "priority": "low"},
"Task 'Buy groceries' created.",
)
result = await TaskAgent().handle("add a grocery task", {})
assert result == "Task 'Buy groceries' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await TaskAgent().handle("help", {})
assert isinstance(result, str)
@pytest.mark.asyncio
async def test_handle_accepts_rich_context(self) -> None:
context = {
"user_profile": {"id": "u1", "tier": "pro"},
"recent_tasks": [{"id": "t1", "title": "Old task"}],
}
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Tasks listed.")
result = await TaskAgent().handle("show tasks", context)
assert isinstance(result, str)
class TestTaskAgentTools:
@pytest.mark.asyncio
async def test_list_tasks_defaults(self) -> None:
from app.agents.task_agent import list_tasks
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_tasks.ainvoke({})
m.assert_called_once_with(
action="select", table="tasks",
filters={"projectId": None, "status": None, "search": None, "orderBy": None},
)
assert result == "No tasks found matching the given filters."
@pytest.mark.asyncio
async def test_list_tasks_with_status_filter(self) -> None:
from app.agents.task_agent import list_tasks
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
await list_tasks.ainvoke({"status": "done"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["filters"]["status"] == "done"
@pytest.mark.asyncio
async def test_create_task_defaults(self) -> None:
from app.agents.task_agent import create_task
fake_row = {"id": "t1", "title": "Test task", "status": "todo", "priority": "medium"}
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await create_task.ainvoke({"title": "Test task"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "insert"
assert call_kwargs["table"] == "tasks"
assert call_kwargs["data"]["title"] == "Test task"
assert call_kwargs["data"]["status"] == "todo"
assert call_kwargs["data"]["priority"] == "medium"
assert "Test task" in result
@pytest.mark.asyncio
async def test_create_task_with_all_fields(self) -> None:
from app.agents.task_agent import create_task
fake_row = {"id": "t1", "title": "Deploy", "status": "in_progress", "priority": "high"}
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
await create_task.ainvoke({
"title": "Deploy", "priority": "high", "status": "in_progress",
"project_id": "p1", "is_ai_suggested": 1,
})
call_kwargs = m.call_args.kwargs
assert call_kwargs["data"]["priority"] == "high"
assert call_kwargs["data"]["status"] == "in_progress"
assert call_kwargs["data"]["projectId"] == "p1"
assert call_kwargs["data"]["isAiSuggested"] == 1
@pytest.mark.asyncio
async def test_update_task_with_status(self) -> None:
from app.agents.task_agent import update_task
fake_row = {"id": "t1", "title": "Buy groceries", "status": "done"}
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "update"
assert call_kwargs["data"]["id"] == "t1"
assert call_kwargs["data"]["updates"]["status"] == "done"
assert "t1" in result
@pytest.mark.asyncio
async def test_update_task_empty_updates(self) -> None:
from app.agents.task_agent import update_task
fake_row = {"id": "t1", "title": "Task", "status": "todo"}
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
await update_task.ainvoke({"task_id": "t1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_task(self) -> None:
from app.agents.task_agent import delete_task
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_task.ainvoke({"task_id": "t1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "delete"
assert call_kwargs["table"] == "tasks"
assert call_kwargs["data"]["id"] == "t1"
assert "t1" in result
@pytest.mark.asyncio
async def test_list_tasks_due_today(self) -> None:
from app.agents.task_agent import list_tasks_due_today
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_tasks_due_today.ainvoke({})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "select"
assert call_kwargs["table"] == "tasks"
assert "dueDateFrom" in call_kwargs["filters"]
assert result == "No tasks are due today."
@pytest.mark.asyncio
async def test_list_task_comments(self) -> None:
from app.agents.task_agent import list_task_comments
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_task_comments.ainvoke({"task_id": "t1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "select"
assert call_kwargs["table"] == "taskComments"
assert call_kwargs["filters"]["taskId"] == "t1"
assert "t1" in result
@pytest.mark.asyncio
async def test_add_task_comment(self) -> None:
from app.agents.task_agent import add_task_comment
fake_row = {"id": "c1", "taskId": "t1", "author": "Alice", "content": "Looks good!"}
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await add_task_comment.ainvoke({
"task_id": "t1", "author": "Alice", "content": "Looks good!",
})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "insert"
assert call_kwargs["table"] == "taskComments"
assert call_kwargs["data"]["taskId"] == "t1"
assert call_kwargs["data"]["author"] == "Alice"
assert call_kwargs["data"]["content"] == "Looks good!"
assert "Alice" in result
@pytest.mark.asyncio
async def test_delete_task_comment(self) -> None:
from app.agents.task_agent import delete_task_comment
with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_task_comment.ainvoke({"comment_id": "c1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "delete"
assert call_kwargs["table"] == "taskComments"
assert call_kwargs["data"]["id"] == "c1"
assert "c1" in result
# ── TimelineAgent ───────────────────────────────────────────────────
class TestTimelineAgent:
def test_name(self) -> None:
assert TimelineAgent().get_name() == "timeline_agent"
def test_description(self) -> None:
assert TimelineAgent().get_description() == "Manages project timelines (milestones): list, create, update, delete"
def test_get_tools_count(self) -> None:
assert len(TimelineAgent().get_tools()) == 4
def test_tool_names(self) -> None:
names = {t.name for t in TimelineAgent().get_tools()}
assert names == {"list_timelines", "create_timeline", "update_timeline", "delete_timeline"}
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.timeline_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("No timelines found.")
result = await TimelineAgent().handle("list timelines", {})
assert result == "No timelines found."
@pytest.mark.asyncio
async def test_handle_with_create_tool_call(self) -> None:
with patch("app.agents.timeline_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_timeline",
{"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
"Timeline 'MVP Launch' created.",
)
result = await TimelineAgent().handle("add MVP timeline", {})
assert result == "Timeline 'MVP Launch' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.timeline_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await TimelineAgent().handle("show milestones", {})
assert isinstance(result, str)
class TestTimelineAgentTools:
@pytest.mark.asyncio
async def test_list_timelines_no_project(self) -> None:
from app.agents.timeline_agent import list_timelines
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_timelines.ainvoke({})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "select"
assert call_kwargs["table"] == "timelines"
assert call_kwargs["filters"]["projectId"] is None
assert result == "No timelines found."
@pytest.mark.asyncio
async def test_list_timelines_with_project(self) -> None:
from app.agents.timeline_agent import list_timelines
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
await list_timelines.ainvoke({"project_id": "p1"})
assert m.call_args.kwargs["filters"]["projectId"] == "p1"
@pytest.mark.asyncio
async def test_create_timeline(self) -> None:
from app.agents.timeline_agent import create_timeline
fake_row = {"id": "cp1", "title": "Beta release", "date": 1700000000000}
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await create_timeline.ainvoke({
"project_id": "p1", "title": "Beta release", "date": 1700000000000,
})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "insert"
assert call_kwargs["table"] == "timelines"
assert call_kwargs["data"]["projectId"] == "p1"
assert call_kwargs["data"]["title"] == "Beta release"
assert call_kwargs["data"]["date"] == 1700000000000
assert "Beta release" in result
@pytest.mark.asyncio
async def test_create_timeline_ai_suggested(self) -> None:
from app.agents.timeline_agent import create_timeline
fake_row = {"id": "cp1", "title": "Review", "date": 1700000000000}
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
await create_timeline.ainvoke({
"project_id": "p1", "title": "Review", "date": 1700000000000, "is_ai_suggested": 1,
})
call_kwargs = m.call_args.kwargs
assert call_kwargs["data"]["isAiSuggested"] == 1
assert call_kwargs["data"]["isApproved"] == 0
@pytest.mark.asyncio
async def test_update_timeline_approve(self) -> None:
from app.agents.timeline_agent import update_timeline
fake_row = {"id": "c1", "title": "MVP", "isApproved": 1}
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await update_timeline.ainvoke({"timeline_id": "c1", "is_approved": 1})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "update"
assert call_kwargs["data"]["id"] == "c1"
assert call_kwargs["data"]["updates"]["isApproved"] == 1
assert "c1" in result
@pytest.mark.asyncio
async def test_update_timeline_empty_updates(self) -> None:
from app.agents.timeline_agent import update_timeline
fake_row = {"id": "c1", "title": "MVP"}
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
await update_timeline.ainvoke({"timeline_id": "c1"})
assert m.call_args.kwargs["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_timeline(self) -> None:
from app.agents.timeline_agent import delete_timeline
with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_timeline.ainvoke({"timeline_id": "c1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "delete"
assert call_kwargs["table"] == "timelines"
assert call_kwargs["data"]["id"] == "c1"
assert "c1" in result
# ── ProjectAgent ──────────────────────────────────────────────────────
class TestProjectAgent:
def test_name(self) -> None:
assert ProjectAgent().get_name() == "project_agent"
def test_description(self) -> None:
assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete"
def test_get_tools_count(self) -> None:
assert len(ProjectAgent().get_tools()) == 6
def test_tool_names(self) -> None:
names = {t.name for t in ProjectAgent().get_tools()}
assert names == {
"list_projects",
"list_all_projects",
"get_project",
"create_project",
"update_project",
"delete_project",
}
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.project_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Project Alpha is active.")
result = await ProjectAgent().handle("show my projects", {})
assert result == "Project Alpha is active."
@pytest.mark.asyncio
async def test_handle_with_create_project_tool_call(self) -> None:
with patch("app.agents.project_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_project",
{"name": "Pippo"},
"Project 'Pippo' created.",
)
result = await ProjectAgent().handle("create project Pippo", {})
assert result == "Project 'Pippo' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.project_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await ProjectAgent().handle("archive old project", {})
assert isinstance(result, str)
class TestProjectAgentTools:
@pytest.mark.asyncio
async def test_list_projects_defaults(self) -> None:
from app.agents.project_agent import list_projects
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_projects.ainvoke({})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "select"
assert call_kwargs["table"] == "projects"
assert call_kwargs["filters"]["includeArchived"] is False
assert result == "No projects found."
@pytest.mark.asyncio
async def test_list_projects_include_archived(self) -> None:
from app.agents.project_agent import list_projects
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
await list_projects.ainvoke({"include_archived": 1})
assert m.call_args.kwargs["filters"]["includeArchived"] is True
@pytest.mark.asyncio
async def test_list_all_projects(self) -> None:
from app.agents.project_agent import list_all_projects
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_all_projects.ainvoke({})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "select"
assert call_kwargs["table"] == "projects"
assert result == "No projects found."
@pytest.mark.asyncio
async def test_get_project(self) -> None:
from app.agents.project_agent import get_project
fake_row = {"id": "p1", "name": "Alpha", "status": "active", "clientId": None}
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await get_project.ainvoke({"project_id": "p1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "get"
assert call_kwargs["table"] == "projects"
assert call_kwargs["data"]["id"] == "p1"
assert "Alpha" in result
@pytest.mark.asyncio
async def test_create_project_name_only(self) -> None:
from app.agents.project_agent import create_project
fake_row = {"id": "p1", "name": "Alpha"}
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await create_project.ainvoke({"name": "Alpha"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "insert"
assert call_kwargs["data"]["name"] == "Alpha"
assert call_kwargs["data"]["clientId"] is None
assert "Alpha" in result
@pytest.mark.asyncio
async def test_create_project_with_client(self) -> None:
from app.agents.project_agent import create_project
fake_row = {"id": "p1", "name": "Beta"}
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
assert m.call_args.kwargs["data"]["clientId"] == "cl1"
@pytest.mark.asyncio
async def test_update_project_archive(self) -> None:
from app.agents.project_agent import update_project
fake_row = {"id": "p1", "name": "Alpha", "status": "archived"}
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "update"
assert call_kwargs["data"]["id"] == "p1"
assert call_kwargs["data"]["updates"]["status"] == "archived"
assert "p1" in result
@pytest.mark.asyncio
async def test_update_project_empty_updates(self) -> None:
from app.agents.project_agent import update_project
fake_row = {"id": "p1", "name": "Alpha", "status": "active"}
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
await update_project.ainvoke({"project_id": "p1"})
assert m.call_args.kwargs["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_project(self) -> None:
from app.agents.project_agent import delete_project
with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_project.ainvoke({"project_id": "p1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "delete"
assert call_kwargs["data"]["id"] == "p1"
assert "p1" in result
# ── NoteAgent ─────────────────────────────────────────────────────────
class TestNoteAgent:
def test_name(self) -> None:
assert NoteAgent().get_name() == "note_agent"
def test_description(self) -> None:
assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete"
def test_get_tools_count(self) -> None:
assert len(NoteAgent().get_tools()) == 5
def test_tool_names(self) -> None:
names = {t.name for t in NoteAgent().get_tools()}
assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"}
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.note_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Note created.")
result = await NoteAgent().handle("create a note", {})
assert result == "Note created."
@pytest.mark.asyncio
async def test_handle_with_create_note_tool_call(self) -> None:
with patch("app.agents.note_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_note",
{"title": "Daily log", "content": "# Today\nAll good."},
"Note 'Daily log' created.",
)
result = await NoteAgent().handle("log today's progress", {})
assert result == "Note 'Daily log' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.note_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await NoteAgent().handle("show notes", {})
assert isinstance(result, str)
class TestNoteAgentTools:
@pytest.mark.asyncio
async def test_list_notes_no_project(self) -> None:
from app.agents.note_agent import list_notes
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
result = await list_notes.ainvoke({})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "select"
assert call_kwargs["table"] == "notes"
assert call_kwargs["filters"]["projectId"] is None
assert result == "No notes found."
@pytest.mark.asyncio
async def test_list_notes_with_project(self) -> None:
from app.agents.note_agent import list_notes
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"rows": []}
await list_notes.ainvoke({"project_id": "p1"})
assert m.call_args.kwargs["filters"]["projectId"] == "p1"
@pytest.mark.asyncio
async def test_get_note(self) -> None:
from app.agents.note_agent import get_note
fake_row = {"id": "n1", "title": "Daily log", "content": "# Today\nAll good."}
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
result = await get_note.ainvoke({"note_id": "n1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "get"
assert call_kwargs["table"] == "notes"
assert call_kwargs["data"]["id"] == "n1"
assert "Daily log" in result
@pytest.mark.asyncio
async def test_create_note_minimal(self) -> None:
from app.agents.note_agent import create_note
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
m.return_value = {"row": fake_row}
me.return_value = [0.0] * 1536
result = await create_note.ainvoke({"title": "Daily log", "content": "# Today\nAll good."})
# First call: insert; second call: vector_upsert
first_call = m.call_args_list[0].kwargs
assert first_call["action"] == "insert"
assert first_call["table"] == "notes"
assert first_call["data"]["title"] == "Daily log"
assert first_call["data"]["content"] == "# Today\nAll good."
assert first_call["data"]["projectId"] is None
assert "Daily log" in result
@pytest.mark.asyncio
async def test_create_note_with_project(self) -> None:
from app.agents.note_agent import create_note
fake_row = {"id": "n1", "title": "Sprint notes", "projectId": "p1"}
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
m.return_value = {"row": fake_row}
me.return_value = [0.0] * 1536
await create_note.ainvoke({"title": "Sprint notes", "content": "## Sprint 1", "project_id": "p1"})
first_call = m.call_args_list[0].kwargs
assert first_call["data"]["projectId"] == "p1"
@pytest.mark.asyncio
async def test_update_note_content_only(self) -> None:
from app.agents.note_agent import update_note
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
m.return_value = {"row": fake_row}
me.return_value = [0.0] * 1536
result = await update_note.ainvoke({"note_id": "n1", "content": "# Updated content"})
first_call = m.call_args_list[0].kwargs
assert first_call["action"] == "update"
assert first_call["data"]["id"] == "n1"
assert first_call["data"]["updates"]["content"] == "# Updated content"
assert "title" not in first_call["data"]["updates"]
assert "n1" in result
@pytest.mark.asyncio
async def test_update_note_empty_updates(self) -> None:
from app.agents.note_agent import update_note
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"row": fake_row}
await update_note.ainvoke({"note_id": "n1"})
assert m.call_args.kwargs["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_note(self) -> None:
from app.agents.note_agent import delete_note
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
m.return_value = {"deleted": True}
result = await delete_note.ainvoke({"note_id": "n1"})
call_kwargs = m.call_args.kwargs
assert call_kwargs["action"] == "delete"
assert call_kwargs["table"] == "notes"
assert call_kwargs["data"]["id"] == "n1"
assert "n1" in result

View File

@@ -1,286 +0,0 @@
"""Tests for execution_plan: PromptTemplateRegistry, ExecutionPlanBuilder, PlanCache."""
from __future__ import annotations
import pytest
from app.core.execution_plan import (
ExecutionPlanBuilder,
PlanCache,
PromptTemplateRegistry,
plan_cache,
template_registry,
)
from app.schemas import ExecutionPlan
# ── PromptTemplateRegistry ────────────────────────────────────────────
class TestPromptTemplateRegistry:
def test_register_and_get(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_foo", "You are a foo agent.")
assert reg.get("tpl_foo") == "You are a foo agent."
def test_get_unknown_raises_key_error(self) -> None:
reg = PromptTemplateRegistry()
with pytest.raises(KeyError, match="tpl_missing"):
reg.get("tpl_missing")
def test_has_returns_true_for_registered(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_x", "prompt text")
assert reg.has("tpl_x") is True
def test_has_returns_false_for_unregistered(self) -> None:
reg = PromptTemplateRegistry()
assert reg.has("tpl_missing") is False
def test_list_ids_returns_all_registered_ids(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_a", "a")
reg.register("tpl_b", "b")
assert set(reg.list_ids()) == {"tpl_a", "tpl_b"}
def test_list_ids_does_not_return_prompt_text(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_secret", "top secret prompt")
ids = reg.list_ids()
assert "top secret prompt" not in ids
def test_overwrite_existing_template(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_x", "v1")
reg.register("tpl_x", "v2")
assert reg.get("tpl_x") == "v2"
def test_empty_registry_has_no_ids(self) -> None:
reg = PromptTemplateRegistry()
assert reg.list_ids() == []
# ── ExecutionPlanBuilder ──────────────────────────────────────────────
class TestExecutionPlanBuilder:
def test_builds_empty_plan(self) -> None:
plan = ExecutionPlanBuilder("task_agent").build()
assert plan.agent == "task_agent"
assert plan.steps == []
def test_add_step_basic(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("create_task", {"priority": "high"})
.build()
)
assert len(plan.steps) == 1
assert plan.steps[0].action == "create_task"
assert plan.steps[0].variables == {"priority": "high"}
assert plan.steps[0].prompt_template is None
assert plan.steps[0].data_from_step is None
def test_add_step_no_params(self) -> None:
plan = ExecutionPlanBuilder("task_agent").add_step("fetch").build()
assert plan.steps[0].variables is None
def test_add_llm_step(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_llm_step("tpl_task_default", {"message": "hi"})
.build()
)
assert plan.steps[0].action == "llm"
assert plan.steps[0].prompt_template == "tpl_task_default"
assert plan.steps[0].variables == {"message": "hi"}
def test_add_llm_step_no_variables(self) -> None:
plan = ExecutionPlanBuilder("task_agent").add_llm_step("tpl_x").build()
assert plan.steps[0].variables is None
def test_add_data_step(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("fetch_data")
.add_data_step("transform", data_from_step=0)
.build()
)
assert plan.steps[1].action == "transform"
assert plan.steps[1].data_from_step == 0
def test_fluent_chaining_returns_builder(self) -> None:
builder = ExecutionPlanBuilder("analytics_agent")
result = builder.add_step("a")
assert result is builder
def test_fluent_chain_multiple_steps(self) -> None:
plan = (
ExecutionPlanBuilder("analytics_agent")
.add_llm_step("tpl_analytics_default")
.add_step("format_output")
.add_data_step("store", data_from_step=0)
.build()
)
assert len(plan.steps) == 3
def test_build_validates_data_from_step_out_of_range(self) -> None:
with pytest.raises(ValueError, match="data_from_step"):
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=5).build()
def test_build_validates_data_from_step_self_reference(self) -> None:
"""data_from_step=0 on the first step (index 0) is invalid."""
with pytest.raises(ValueError, match="data_from_step"):
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=0).build()
def test_build_validates_data_from_step_negative(self) -> None:
with pytest.raises(ValueError, match="data_from_step"):
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=-1).build()
def test_valid_data_from_step_at_index_two(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("step0")
.add_step("step1")
.add_data_step("step2", data_from_step=1)
.build()
)
assert plan.steps[2].data_from_step == 1
def test_data_from_step_zero_valid_at_index_one(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("step0")
.add_data_step("step1", data_from_step=0)
.build()
)
assert plan.steps[1].data_from_step == 0
def test_build_returns_new_plan_each_call(self) -> None:
builder = ExecutionPlanBuilder("task_agent").add_step("do_thing")
plan1 = builder.build()
plan2 = builder.build()
assert plan1 is not plan2
assert plan1.steps == plan2.steps
def test_plan_is_execution_plan_instance(self) -> None:
plan = ExecutionPlanBuilder("task_agent").build()
assert isinstance(plan, ExecutionPlan)
# ── PlanCache ─────────────────────────────────────────────────────────
class TestPlanCache:
def _plan(self, agent: str = "a") -> ExecutionPlan:
return ExecutionPlanBuilder(agent).build()
def test_cache_and_get(self) -> None:
cache = PlanCache()
plan = self._plan()
cache.cache_plan("key1", plan)
assert cache.get_plan("key1") is plan
def test_get_missing_returns_none(self) -> None:
cache = PlanCache()
assert cache.get_plan("nonexistent") is None
def test_get_all_playbooks_empty(self) -> None:
cache = PlanCache()
assert cache.get_all_playbooks() == []
def test_get_all_playbooks_returns_all_stored(self) -> None:
cache = PlanCache()
p1, p2 = self._plan("a"), self._plan("b")
cache.cache_plan("k1", p1)
cache.cache_plan("k2", p2)
playbooks = cache.get_all_playbooks()
assert len(playbooks) == 2
assert p1 in playbooks
assert p2 in playbooks
def test_lru_evicts_oldest_entry(self) -> None:
cache = PlanCache(maxsize=2)
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
cache.cache_plan("k1", p1)
cache.cache_plan("k2", p2)
cache.cache_plan("k3", p3) # k1 should be evicted
assert cache.get_plan("k1") is None
assert cache.get_plan("k2") is p2
assert cache.get_plan("k3") is p3
def test_lru_access_updates_recency(self) -> None:
cache = PlanCache(maxsize=2)
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
cache.cache_plan("k1", p1)
cache.cache_plan("k2", p2)
cache.get_plan("k1") # k1 is now most-recently used
cache.cache_plan("k3", p3) # k2 should be evicted (LRU)
assert cache.get_plan("k1") is p1
assert cache.get_plan("k2") is None
assert cache.get_plan("k3") is p3
def test_overwrite_existing_key(self) -> None:
cache = PlanCache()
p1, p2 = self._plan("a"), self._plan("b")
cache.cache_plan("same_key", p1)
cache.cache_plan("same_key", p2)
assert cache.get_plan("same_key") is p2
assert len(cache.get_all_playbooks()) == 1
def test_overwrite_does_not_consume_capacity(self) -> None:
cache = PlanCache(maxsize=2)
p1, p2 = self._plan("a"), self._plan("b")
cache.cache_plan("k1", p1)
cache.cache_plan("k1", p2) # overwrite, not a new slot
cache.cache_plan("k2", p1) # should fit without eviction
assert cache.get_plan("k1") is p2
assert cache.get_plan("k2") is p1
# ── Module-level singletons ───────────────────────────────────────────
class TestModuleSingletons:
def test_template_registry_has_all_agent_defaults(self) -> None:
for agent in ("task_agent", "timeline_agent", "project_agent", "note_agent"):
assert template_registry.has(f"tpl_{agent}_default"), (
f"Missing template: tpl_{agent}_default"
)
def test_template_registry_has_operation_templates(self) -> None:
assert template_registry.has("tpl_task_extract_from_project")
assert template_registry.has("tpl_note_weekly_summary")
def test_template_registry_get_returns_non_empty_string(self) -> None:
text = template_registry.get("tpl_task_agent_default")
assert isinstance(text, str)
assert len(text) > 0
def test_plan_cache_has_prebuilt_playbooks(self) -> None:
assert len(plan_cache.get_all_playbooks()) >= 2
def test_playbook_create_tasks_from_project(self) -> None:
plan = plan_cache.get_plan("create_tasks_from_project")
assert plan is not None
assert plan.agent == "project_agent"
assert len(plan.steps) == 2
assert plan.steps[0].prompt_template == "tpl_task_extract_from_project"
assert plan.steps[1].data_from_step == 0
def test_playbook_generate_weekly_note(self) -> None:
plan = plan_cache.get_plan("generate_weekly_note")
assert plan is not None
assert plan.agent == "note_agent"
assert len(plan.steps) == 2
assert plan.steps[0].prompt_template == "tpl_note_weekly_summary"
assert plan.steps[1].data_from_step == 0
def test_playbook_steps_have_no_raw_prompt_text(self) -> None:
"""Plans must not embed prompt text — only template IDs."""
for plan in plan_cache.get_all_playbooks():
for step in plan.steps:
if step.prompt_template is not None:
assert step.prompt_template.startswith("tpl_"), (
f"prompt_template looks like raw text: {step.prompt_template!r}"
)

View File

@@ -250,15 +250,15 @@ def test_home_request_calls_memory_middleware(client):
token = make_jwt("power", user_id=USER_ID) token = make_jwt("power", user_id=USER_ID)
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
async def _mock_stream(user_id, message, context, reg=None): async def _mock_stream(user_id, message, context, db_session_factory=None):
# Verify memory context was injected # Verify memory context was injected
assert context.get("core_memory") == {"tz": "UTC"} assert context.get("core_memory") == {"tz": "UTC"}
yield "task_agent", "" yield ("token", "Done")
yield "task_agent", '{"type": "text", "content": "Done"}' yield ("mutations", [])
with ( with (
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware), patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_stream), patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_stream),
): ):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({ ws.send_text(json.dumps({

View File

@@ -20,7 +20,6 @@ from jose import jwt
from app.config.settings import settings from app.config.settings import settings
from app.db import get_session from app.db import get_session
from app.main import app from app.main import app
from app.schemas import ChatResponse
from tests.conftest import TEST_USER_IDS from tests.conftest import TEST_USER_IDS
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -50,7 +49,6 @@ _CHAT_BODY = {
"recent_tasks": [], "recent_tasks": [],
"conversation_history": [], "conversation_history": [],
}, },
"execution_mode": "direct",
} }
@@ -240,7 +238,7 @@ class TestRateLimitMiddleware:
class TestSanitizerMiddleware: class TestSanitizerMiddleware:
"""Mock ``orchestrate`` to inject controlled strings into chat responses.""" """Mock ``run_home`` to inject controlled strings into chat responses."""
_CHAT_PATH = "/api/v1/chat" _CHAT_PATH = "/api/v1/chat"
@@ -248,11 +246,10 @@ class TestSanitizerMiddleware:
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro") return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
def _post_chat(self, client: TestClient, response_text: str) -> dict: def _post_chat(self, client: TestClient, response_text: str) -> dict:
mock_response = ChatResponse(response=response_text, actions=[])
with patch( with patch(
"app.api.routes.chat.orchestrate", "app.api.routes.chat.run_home",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=mock_response, return_value=response_text,
): ):
resp = client.post( resp = client.post(
self._CHAT_PATH, self._CHAT_PATH,

View File

@@ -1,347 +0,0 @@
"""Integration tests for the orchestrator module."""
from __future__ import annotations
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.core.agent_registry import AgentRegistry, ChatAgent
from app.core.orchestrator import (
classify_intent,
orchestrate,
orchestrate_stream,
route_pipeline,
route_single,
)
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
# ── Stub agents ──────────────────────────────────────────────────────
class _TaskAgent(ChatAgent):
def get_name(self) -> str:
return "task_agent"
def get_description(self) -> str:
return "Manages tasks: create, update, list, suggest"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return f"task: {query}"
class _CalendarAgent(ChatAgent):
def get_name(self) -> str:
return "calendar_agent"
def get_description(self) -> str:
return "Calendar management: events, conflicts, scheduling"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return f"calendar: {query}"
# ── Helpers ──────────────────────────────────────────────────────────
def _mock_llm(response_text: str) -> MagicMock:
"""Return a mock LLM that always produces *response_text*."""
msg = MagicMock()
msg.content = response_text
llm = MagicMock()
llm.ainvoke = AsyncMock(return_value=msg)
return llm
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _fresh_registry():
"""Reset the AgentRegistry singleton between tests."""
AgentRegistry._instance = None
yield
AgentRegistry._instance = None
@pytest.fixture()
def reg() -> AgentRegistry:
r = AgentRegistry()
r.register(_TaskAgent)
r.register(_CalendarAgent)
return r
# ── classify_intent ───────────────────────────────────────────────────
class TestClassifyIntent:
@pytest.mark.asyncio
async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
result = await classify_intent("add a task", {}, reg)
assert result == "task_agent"
@pytest.mark.asyncio
async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("calendar_agent")
result = await classify_intent("schedule a meeting", {}, reg)
assert result == "calendar_agent"
@pytest.mark.asyncio
async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("nonexistent_agent")
result = await classify_intent("do something", {}, reg)
assert result == "task_agent"
@pytest.mark.asyncio
async def test_empty_registry_returns_fallback_without_llm_call(self) -> None:
empty_reg = AgentRegistry()
# No LLM should be instantiated — early return path
with patch("app.core.orchestrator._make_llm") as mock_cls:
result = await classify_intent("anything", {}, empty_reg)
mock_cls.assert_not_called()
assert result == "task_agent"
@pytest.mark.asyncio
async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm(" task_agent \n")
result = await classify_intent("create task", {}, reg)
assert result == "task_agent"
# ── route_single ─────────────────────────────────────────────────────
class TestRouteSingle:
@pytest.mark.asyncio
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
result = await route_single("task_agent", "create a task", {}, reg)
assert isinstance(result, ChatResponse)
@pytest.mark.asyncio
async def test_response_contains_agent_output(self, reg: AgentRegistry) -> None:
result = await route_single("task_agent", "create a task", {}, reg)
assert result.response == "task: create a task"
@pytest.mark.asyncio
async def test_unknown_agent_raises_key_error(self, reg: AgentRegistry) -> None:
with pytest.raises(KeyError):
await route_single("nonexistent", "hello", {}, reg)
@pytest.mark.asyncio
async def test_actions_default_empty(self, reg: AgentRegistry) -> None:
result = await route_single("task_agent", "hi", {}, reg)
assert result.actions == []
# ── route_pipeline ────────────────────────────────────────────────────
class TestRoutePipeline:
@pytest.mark.asyncio
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("synthesized result")
result = await route_pipeline(
["task_agent", "calendar_agent"], "plan my week", {}, reg
)
assert isinstance(result, ChatResponse)
@pytest.mark.asyncio
async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("synthesized result")
result = await route_pipeline(
["task_agent", "calendar_agent"], "plan my week", {}, reg
)
assert result.response == "synthesized result"
@pytest.mark.asyncio
async def test_passes_previous_results_to_subsequent_agents(
self, reg: AgentRegistry
) -> None:
"""Each agent after the first should receive prior outputs in context."""
received_contexts: list[dict[str, Any]] = []
class _CapturingAgent(ChatAgent):
def get_name(self) -> str:
return "capture"
def get_description(self) -> str:
return "captures context for testing"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
received_contexts.append(dict(context))
return "captured"
reg.register(_CapturingAgent)
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("done")
await route_pipeline(["task_agent", "capture"], "hi", {}, reg)
# The second agent (capture) must have received previous results
assert len(received_contexts) == 1
assert "previous_results" in received_contexts[0]
assert received_contexts[0]["previous_results"] == ["task: hi"]
@pytest.mark.asyncio
async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("single result")
result = await route_pipeline(["task_agent"], "one agent", {}, reg)
assert result.response == "single result"
# ── orchestrate ───────────────────────────────────────────────────────
class TestOrchestrate:
@pytest.mark.asyncio
async def test_direct_mode_returns_chat_response(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
result = await orchestrate(request, reg)
assert isinstance(result, ChatResponse)
@pytest.mark.asyncio
async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
result = await orchestrate(request, reg)
assert isinstance(result, ChatResponse)
assert result.response == "task: add a task"
@pytest.mark.asyncio
async def test_plan_mode_returns_execution_plan(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="plan my tasks", execution_mode="plan")
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
@pytest.mark.asyncio
async def test_plan_mode_agent_matches_classified(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("calendar_agent")
request = ChatRequest(
message="schedule something", execution_mode="plan"
)
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
assert result.agent == "calendar_agent"
@pytest.mark.asyncio
async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="plan tasks", execution_mode="plan")
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
assert len(result.steps) >= 1
@pytest.mark.asyncio
async def test_plan_mode_template_id_contains_agent_name(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="plan tasks", execution_mode="plan")
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
assert result.steps[0].prompt_template is not None
assert "task_agent" in result.steps[0].prompt_template
@pytest.mark.asyncio
async def test_default_execution_mode_is_direct(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
# execution_mode defaults to "direct"
request = ChatRequest(message="help me")
result = await orchestrate(request, reg)
assert isinstance(result, ChatResponse)
# ── orchestrate_stream ────────────────────────────────────────────────
class TestOrchestrateStream:
@pytest.mark.asyncio
async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
assert len(chunks) >= 1
@pytest.mark.asyncio
async def test_all_chunks_are_plain_text(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
# orchestrate_stream yields plain text chunks only — no JSON final frame
for chunk in chunks:
assert isinstance(chunk, str)
@pytest.mark.asyncio
async def test_concatenated_chunks_equal_full_response(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="create a task", execution_mode="direct")
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
full_text = "".join(chunks)
assert full_text == "task: create a task"
@pytest.mark.asyncio
async def test_text_chunks_before_final_frame(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(
message="x" * 200, execution_mode="direct"
) # long enough to produce multiple chunks
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
# All but the last chunk should be plain text (not valid final JSON)
non_final = chunks[:-1]
for chunk in non_final:
try:
parsed = json.loads(chunk)
assert parsed.get("done") is not True
except json.JSONDecodeError:
pass # plain text chunk — expected

View File

@@ -1,236 +0,0 @@
"""Tests for v3 orchestrator functions (Step 3)."""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from typing import Any
from app.core.agent_registry import ChatAgent, AgentRegistry
from app.core.orchestrator import orchestrate_v3, orchestrate_v3_stream
# ── Minimal agent for testing ─────────────────────────────────────────
class _FixedAgent(ChatAgent):
def __init__(self, name: str = "_fixed", tokens: list[str] | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._name = name
self._tokens = tokens or ["Hello", " world"]
def get_name(self) -> str:
return self._name
def get_description(self) -> str:
return "Fixed agent for tests"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return "".join(self._tokens)
async def handle_stream(self, query: str, context: dict[str, Any]):
for tok in self._tokens:
yield tok
# ── Mock registry factory ─────────────────────────────────────────────
def _make_registry(agent_name: str, agent: ChatAgent) -> MagicMock:
reg = MagicMock(spec=AgentRegistry)
reg.list_agents.return_value = [{"name": agent_name, "description": "test"}]
reg.get.return_value = agent
return reg
# ── orchestrate_v3 ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_orchestrate_v3_returns_agent_name_and_instance():
agent = _FixedAgent("task_agent")
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
name, inst = await orchestrate_v3(
user_id="u-1", message="fix a bug", context={}, reg=reg
)
assert name == "task_agent"
assert inst is agent
@pytest.mark.asyncio
async def test_orchestrate_v3_classify_called_with_message_and_context():
agent = _FixedAgent("note_agent")
reg = _make_registry("note_agent", agent)
ctx = {"some": "context"}
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")) as mock_classify:
await orchestrate_v3(user_id="u-1", message="take a note", context=ctx, reg=reg)
mock_classify.assert_awaited_once()
call_args = mock_classify.call_args
assert call_args[0][0] == "take a note"
assert call_args[0][1] == ctx
@pytest.mark.asyncio
async def test_orchestrate_v3_uses_default_registry_when_none():
agent = _FixedAgent("task_agent")
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
patch("app.core.orchestrator._default_registry") as mock_reg:
mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
mock_reg.get.return_value = agent
name, inst = await orchestrate_v3(user_id="u-1", message="hi", context={})
assert name == "task_agent"
assert inst is agent
@pytest.mark.asyncio
async def test_orchestrate_v3_get_called_with_agent_name():
agent = _FixedAgent("timeline_agent")
reg = _make_registry("timeline_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="timeline_agent")):
await orchestrate_v3(user_id="u-2", message="schedule", context={}, reg=reg)
reg.get.assert_called_once_with("timeline_agent")
# ── orchestrate_v3_stream ─────────────────────────────────────────────
async def _collect(gen) -> list[tuple[str, str]]:
results: list[tuple[str, str]] = []
async for item in gen:
results.append(item)
return results
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_first_yield_is_domain_signal():
agent = _FixedAgent("task_agent", tokens=["token1"])
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
results = await _collect(gen)
# First item must be (agent_name, "") — domain signal
assert results[0] == ("task_agent", "")
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_yields_agent_name_with_tokens():
agent = _FixedAgent("task_agent", tokens=["Hello", " ", "world"])
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
results = await _collect(gen)
# All items are (agent_name, token) pairs
assert all(name == "task_agent" for name, _ in results)
tokens = [tok for _, tok in results]
assert tokens[0] == "" # domain signal
assert tokens[1:] == ["Hello", " ", "world"]
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_different_agent():
agent = _FixedAgent("note_agent", tokens=["note"])
reg = _make_registry("note_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")):
gen = orchestrate_v3_stream(user_id="u-2", message="take note", context={}, reg=reg)
results = await _collect(gen)
assert results[0] == ("note_agent", "")
assert ("note_agent", "note") in results
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_uses_default_registry_when_none():
agent = _FixedAgent("task_agent", tokens=["x"])
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
patch("app.core.orchestrator._default_registry") as mock_reg:
mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
mock_reg.get.return_value = agent
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={})
results = await _collect(gen)
assert results[0][0] == "task_agent"
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_empty_token_list():
"""Agent with no tokens still emits the domain signal."""
class _EmptyAgent(_FixedAgent):
async def handle_stream(self, query: str, context: dict[str, Any]):
return
yield # makes it a generator
agent = _EmptyAgent("task_agent", tokens=[])
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
results = await _collect(gen)
assert results == [("task_agent", "")] # only domain signal
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_full_text_correct():
"""Concatenating all non-domain tokens reconstructs the full response."""
tokens = ["The", " ", "task", " ", "is", " ", "done."]
agent = _FixedAgent("task_agent", tokens=tokens)
reg = _make_registry("task_agent", agent)
with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
results = await _collect(gen)
text = "".join(tok for _, tok in results[1:]) # skip domain signal
assert text == "The task is done."
# ── handle_stream default implementation ─────────────────────────────
@pytest.mark.asyncio
async def test_handle_stream_default_yields_full_response():
"""Default handle_stream yields handle() result as a single chunk."""
class _SimpleAgent(ChatAgent):
def get_name(self) -> str:
return "_simple"
def get_description(self) -> str:
return ""
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return "simple response"
agent = _SimpleAgent()
tokens = [tok async for tok in agent.handle_stream("q", {})]
assert tokens == ["simple response"]
@pytest.mark.asyncio
async def test_handle_stream_override_used_by_stream():
"""_FixedAgent.handle_stream override yields individual tokens."""
agent = _FixedAgent("t", tokens=["a", "b", "c"])
tokens = [tok async for tok in agent.handle_stream("q", {})]
assert tokens == ["a", "b", "c"]

View File

@@ -16,15 +16,15 @@ from app.schemas import (
# ── helpers ─────────────────────────────────────────────────────────────────── # ── helpers ───────────────────────────────────────────────────────────────────
async def _stream(*pairs: tuple[str, str]): async def _stream(*events: tuple[str, object]):
"""Async generator that yields (agent_name, token) pairs.""" """Async generator that yields (event_type, data) tuples."""
for pair in pairs: for event in events:
yield pair yield event
async def collect(formatter, token_stream): async def collect(formatter, event_stream):
frames = [] frames = []
async for frame in formatter.format(token_stream): async for frame in formatter.format(event_stream):
frames.append(frame) frames.append(frame)
return frames return frames
@@ -32,13 +32,14 @@ async def collect(formatter, token_stream):
# ── HomeFormatter ───────────────────────────────────────────────────────────── # ── HomeFormatter ─────────────────────────────────────────────────────────────
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_home_formatter_text_block(): async def test_home_formatter_text_token():
req_id = "req-1" req_id = "req-1"
tokens = [ events = [
("task_agent", '{"type": "text", "content": "Hello world"}'), ("token", "Hello world"),
("mutations", []),
] ]
formatter = HomeFormatter(request_id=req_id, tool_results=[]) formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*tokens)) frames = await collect(formatter, _stream(*events))
assert isinstance(frames[0], WsStreamStart) assert isinstance(frames[0], WsStreamStart)
assert frames[0].request_id == req_id assert frames[0].request_id == req_id
@@ -48,104 +49,94 @@ async def test_home_formatter_text_block():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_home_formatter_chart_block(): async def test_home_formatter_entity_ref_from_tool_end():
req_id = "req-2" req_id = "req-2"
chart_json = ( events = [
'{"type": "chart", "chartType": "bar", ' ("tool_end", {"name": "task_agent", "result": "Found 3 tasks."}),
'"title": "Tasks", "data": [{"x": 1}], ' ("token", "Here are your tasks."),
'"config": {"x": {"label": "X", "color": "#fff"}}}' ("mutations", []),
) ]
formatter = HomeFormatter(request_id=req_id, tool_results=[]) formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("task_agent", chart_json))) frames = await collect(formatter, _stream(*events))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1 assert len(block_frames) == 1
assert block_frames[0].block_type == "chart" assert block_frames[0].block_type == "entity_ref"
assert block_frames[0].data["chartType"] == "bar" assert block_frames[0].data["entity"] == "tasks"
assert block_frames[0].data["result"] == "Found 3 tasks."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_home_formatter_invalid_chart_skipped(): async def test_home_formatter_unknown_agent_no_block():
req_id = "req-3" req_id = "req-3"
bad_chart = '{"type": "chart", "chartType": "unknown", "data": []}' events = [
formatter = HomeFormatter(request_id=req_id, tool_results=[]) ("tool_end", {"name": "unknown_agent", "result": "stuff"}),
frames = await collect(formatter, _stream(("task_agent", bad_chart))) ("mutations", []),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 0 # invalid chart skipped assert len(block_frames) == 0 # unknown agent → no entity mapping
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_home_formatter_entity_ref_resolved(): async def test_home_formatter_mutations_in_stream_end():
req_id = "req-4" req_id = "req-4"
tool_results = [{"entity": "task", "id": "t1", "title": "My Task"}] muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
entity_json = '{"type": "entity_ref", "entity": "task"}' events = [
formatter = HomeFormatter(request_id=req_id, tool_results=tool_results) ("token", "Done"),
frames = await collect(formatter, _stream(("task_agent", entity_json))) ("mutations", muts),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] end_frame = frames[-1]
assert len(block_frames) == 1 assert isinstance(end_frame, WsStreamEnd)
assert block_frames[0].data["entity"] == "task" assert len(end_frame.mutations) == 1
assert block_frames[0].data["items"][0]["id"] == "t1" assert end_frame.mutations[0]["action"] == "insert"
@pytest.mark.asyncio
async def test_home_formatter_entity_ref_missing_skipped():
req_id = "req-5"
entity_json = '{"type": "entity_ref", "entity": "task"}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", entity_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 0 # no tool results → skipped
@pytest.mark.asyncio
async def test_home_formatter_table_block():
req_id = "req-6"
table_json = '{"type": "table", "headers": ["A", "B"], "rows": [["1", "2"]]}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", table_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1
assert block_frames[0].block_type == "table"
@pytest.mark.asyncio
async def test_home_formatter_timeline_block():
req_id = "req-7"
timeline_json = '{"type": "timeline", "timelines": [{"id": "c1", "title": "M1", "date": 123}]}'
formatter = HomeFormatter(request_id=req_id, tool_results=[])
frames = await collect(formatter, _stream(("task_agent", timeline_json)))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 1
assert block_frames[0].block_type == "timeline"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_home_formatter_frame_order(): async def test_home_formatter_frame_order():
"""stream_start is first, stream_end is last.""" """stream_start is first, stream_end is last."""
req_id = "req-8" req_id = "req-5"
formatter = HomeFormatter(request_id=req_id, tool_results=[]) formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("task_agent", '{"type": "text", "content": "Hi"}'))) frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
assert isinstance(frames[0], WsStreamStart) assert isinstance(frames[0], WsStreamStart)
assert isinstance(frames[-1], WsStreamEnd) assert isinstance(frames[-1], WsStreamEnd)
# ── FloatingFormatter ──────────────────────────────────────────────────────────── @pytest.mark.asyncio
async def test_home_formatter_multiple_tool_ends():
req_id = "req-6"
events = [
("tool_end", {"name": "task_agent", "result": "3 tasks"}),
("tool_end", {"name": "project_agent", "result": "2 projects"}),
("token", "Overview done."),
("mutations", []),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
assert len(block_frames) == 2
entities = {b.data["entity"] for b in block_frames}
assert entities == {"tasks", "projects"}
# ── FloatingFormatter ─────────────────────────────────────────────────────────
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_floating_formatter_domain_emitted_first(): async def test_floating_formatter_domain_from_tool_end():
req_id = "pop-1" req_id = "pop-1"
formatter = FloatingFormatter(request_id=req_id) formatter = FloatingFormatter(request_id=req_id)
tokens = [ events = [
("task_agent", ""), # domain signal ("tool_end", {"name": "task_agent", "result": "ok"}),
("task_agent", "Hello"), ("token", "Hello"),
("task_agent", " there"), ("mutations", []),
] ]
frames = await collect(formatter, _stream(*tokens)) frames = await collect(formatter, _stream(*events))
assert isinstance(frames[0], WsFloatingDomain) assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "tasks" assert frames[0].domain == "tasks"
@@ -156,8 +147,12 @@ async def test_floating_formatter_domain_emitted_first():
async def test_floating_formatter_text_only(): async def test_floating_formatter_text_only():
req_id = "pop-2" req_id = "pop-2"
formatter = FloatingFormatter(request_id=req_id) formatter = FloatingFormatter(request_id=req_id)
tokens = [("timeline_agent", ""), ("timeline_agent", "Summary")] events = [
frames = await collect(formatter, _stream(*tokens)) ("tool_end", {"name": "timeline_agent", "result": "done"}),
("token", "Summary"),
("mutations", []),
]
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[0], WsFloatingDomain) assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "timelines" assert frames[0].domain == "timelines"
@@ -171,11 +166,12 @@ async def test_floating_formatter_no_block_frames():
"""FloatingFormatter must never emit WsStreamBlock.""" """FloatingFormatter must never emit WsStreamBlock."""
req_id = "pop-3" req_id = "pop-3"
formatter = FloatingFormatter(request_id=req_id) formatter = FloatingFormatter(request_id=req_id)
tokens = [ events = [
("note_agent", ""), ("tool_end", {"name": "note_agent", "result": "data"}),
("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'), ("token", "some text"),
("mutations", []),
] ]
frames = await collect(formatter, _stream(*tokens)) frames = await collect(formatter, _stream(*events))
assert not any(isinstance(f, WsStreamBlock) for f in frames) assert not any(isinstance(f, WsStreamBlock) for f in frames)
@@ -183,13 +179,37 @@ async def test_floating_formatter_no_block_frames():
async def test_floating_formatter_end_frame(): async def test_floating_formatter_end_frame():
req_id = "pop-4" req_id = "pop-4"
formatter = FloatingFormatter(request_id=req_id) formatter = FloatingFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("project_agent", ""), ("project_agent", "Done"))) events = [
("tool_end", {"name": "project_agent", "result": "ok"}),
("token", "Done"),
("mutations", []),
]
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[-1], WsStreamEnd) assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_floating_formatter_unknown_agent_defaults_to_tasks(): async def test_floating_formatter_default_domain_on_early_token():
"""When the first event is a token (no tool_end yet), default to 'tasks'."""
req_id = "pop-5" req_id = "pop-5"
formatter = FloatingFormatter(request_id=req_id) formatter = FloatingFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi"))) events = [("token", "hi"), ("mutations", [])]
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "tasks" assert frames[0].domain == "tasks"
@pytest.mark.asyncio
async def test_floating_formatter_mutations_in_stream_end():
req_id = "pop-6"
muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}]
events = [
("token", "Updated"),
("mutations", muts),
]
formatter = FloatingFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
end_frame = frames[-1]
assert isinstance(end_frame, WsStreamEnd)
assert len(end_frame.mutations) == 1

View File

@@ -88,7 +88,7 @@ class TestPluginRegistry:
async def test_list_filter_by_query( async def test_list_filter_by_query(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin] self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None: ) -> None:
result = await reg.list_plugins(db_session, query="time") result = await reg.list_plugins(db_session, query="time tracker")
assert result.total == 1 assert result.total == 1
assert result.plugins[0].id == "plugin-time-tracker" assert result.plugins[0].id == "plugin-time-tracker"

View File

@@ -45,14 +45,16 @@ def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
return frames return frames
async def _mock_home_stream(user_id, message, context, reg=None): async def _mock_home_stream(user_id, message, context, db_session_factory=None):
yield "task_agent", "" yield "tool_end", {"name": "task_agent", "result": "Found tasks"}
yield "task_agent", '{"type": "text", "content": "Hello"}' yield "token", "Hello"
yield "mutations", []
async def _mock_floating_stream(user_id, message, context, reg=None): async def _mock_floating_stream(user_id, message, context, scope=None, db_session_factory=None):
yield "task_agent", "" yield "tool_end", {"name": "task_agent", "result": "ok"}
yield "task_agent", "Here is a summary" yield "token", "Here is a summary"
yield "mutations", []
# ── tests ───────────────────────────────────────────────────────────────────── # ── tests ─────────────────────────────────────────────────────────────────────
@@ -61,7 +63,7 @@ def test_home_request_produces_stream_frames(client):
"""home_request → stream_start, stream_text+, stream_end.""" """home_request → stream_start, stream_text+, stream_end."""
token = make_jwt("power", user_id=USER_ID) token = make_jwt("power", user_id=USER_ID)
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_home_stream): with patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_home_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({ ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-1", "agent_ids": [] "type": "device_hello", "device_id": "dev-1", "agent_ids": []
@@ -84,7 +86,7 @@ def test_floating_request_produces_domain_frame(client):
"""floating_request → floating_domain first, then stream_text*, stream_end.""" """floating_request → floating_domain first, then stream_text*, stream_end."""
token = make_jwt("power", user_id=USER_ID) token = make_jwt("power", user_id=USER_ID)
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_floating_stream): with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({ ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-2", "agent_ids": [] "type": "device_hello", "device_id": "dev-2", "agent_ids": []
@@ -112,11 +114,12 @@ def test_home_request_request_id_propagated(client):
token = make_jwt("power", user_id=USER_ID) token = make_jwt("power", user_id=USER_ID)
req_id = "my-unique-req-id" req_id = "my-unique-req-id"
async def _stream(user_id, message, context, reg=None): async def _stream(user_id, message, context, db_session_factory=None):
yield "note_agent", "" yield "tool_end", {"name": "note_agent", "result": "ok"}
yield "note_agent", '{"type": "text", "content": "ok"}' yield "token", "ok"
yield "mutations", []
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_stream): with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({ ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-3", "agent_ids": [] "type": "device_hello", "device_id": "dev-3", "agent_ids": []