diff --git a/app/agents/__init__.py b/app/agents/__init__.py
index 6a202c1..8b2e848 100644
--- a/app/agents/__init__.py
+++ b/app/agents/__init__.py
@@ -1,4 +1,4 @@
-"""Import all agent modules to trigger @registry.register decorators."""
+"""Expose tool modules used by deep orchestrator-worker graphs."""
from app.agents import timeline_agent, note_agent, project_agent, task_agent
diff --git a/app/agents/note_agent.py b/app/agents/note_agent.py
index e5c648a..b8a6f18 100644
--- a/app/agents/note_agent.py
+++ b/app/agents/note_agent.py
@@ -2,17 +2,14 @@
from __future__ import annotations
-import json
from typing import Any
-from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
-from app.core.agent_registry import ChatAgent, registry
-from app.core.llm import embed, get_llm
+from app.core.llm import embed
from app.core.ws_context import execute_on_client
-_SYSTEM_PROMPT = (
+NOTE_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n"
"and delete Markdown notes in their workspace.\n\n"
"Rules:\n"
@@ -122,23 +119,10 @@ async def delete_note(note_id: str) -> str:
return f"Note {note_id} deleted."
-@registry.register
-class NoteAgent(ChatAgent):
- def get_name(self) -> str:
- return "note_agent"
-
- def get_description(self) -> str:
- return "Manages notes: list, get, create, update, delete"
-
- def get_tools(self) -> list[Any]:
- return [list_notes, get_note, create_note, update_note, delete_note]
-
- async def handle(self, query: str, context: dict[str, Any]) -> str:
- llm = get_llm()
- messages = [
- SystemMessage(content=_SYSTEM_PROMPT),
- HumanMessage(
- content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
- ),
- ]
- return await self._tool_loop(llm, messages, self.get_tools())
+NOTE_TOOLS: list[Any] = [
+ list_notes,
+ get_note,
+ create_note,
+ update_note,
+ delete_note,
+]
diff --git a/app/agents/project_agent.py b/app/agents/project_agent.py
index ccd2ea6..a07da0e 100644
--- a/app/agents/project_agent.py
+++ b/app/agents/project_agent.py
@@ -2,17 +2,13 @@
from __future__ import annotations
-import json
from typing import Any
-from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
-from app.core.agent_registry import ChatAgent, registry
-from app.core.llm import get_llm
from app.core.ws_context import execute_on_client
-_SYSTEM_PROMPT = (
+PROJECT_SYSTEM_PROMPT = (
"You are a project management assistant. You help users create, find,\n"
"update, and archive projects in their workspace.\n\n"
"Rules:\n"
@@ -137,30 +133,11 @@ async def delete_project(project_id: str) -> str:
return f"Project {project_id} permanently deleted."
-@registry.register
-class ProjectAgent(ChatAgent):
- def get_name(self) -> str:
- return "project_agent"
-
- def get_description(self) -> str:
- return "Manages projects: list, get, create, update, archive, delete"
-
- def get_tools(self) -> list[Any]:
- return [
- list_projects,
- list_all_projects,
- get_project,
- create_project,
- update_project,
- delete_project,
- ]
-
- async def handle(self, query: str, context: dict[str, Any]) -> str:
- llm = get_llm()
- messages = [
- SystemMessage(content=_SYSTEM_PROMPT),
- HumanMessage(
- content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
- ),
- ]
- return await self._tool_loop(llm, messages, self.get_tools())
+PROJECT_TOOLS: list[Any] = [
+ list_projects,
+ list_all_projects,
+ get_project,
+ create_project,
+ update_project,
+ delete_project,
+]
diff --git a/app/agents/task_agent.py b/app/agents/task_agent.py
index 1d6e32d..3f8ab95 100644
--- a/app/agents/task_agent.py
+++ b/app/agents/task_agent.py
@@ -2,18 +2,14 @@
from __future__ import annotations
-import json
from datetime import datetime, timezone
from typing import Any
-from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
-from app.core.agent_registry import ChatAgent, registry
-from app.core.llm import get_llm
from app.core.ws_context import execute_on_client
-_SYSTEM_PROMPT = (
+TASK_SYSTEM_PROMPT = (
"You are a task management assistant for a project workspace.\n"
"You create, update, list, and track tasks and their comments.\n\n"
"Rules:\n"
@@ -223,32 +219,13 @@ async def delete_task_comment(comment_id: str) -> str:
# ── Agent ─────────────────────────────────────────────────────────────
-@registry.register
-class TaskAgent(ChatAgent):
- def get_name(self) -> str:
- return "task_agent"
-
- def get_description(self) -> str:
- return "Manages tasks and comments: list, create, update, delete, due-today, comments"
-
- def get_tools(self) -> list[Any]:
- return [
- list_tasks,
- create_task,
- update_task,
- delete_task,
- list_tasks_due_today,
- list_task_comments,
- add_task_comment,
- delete_task_comment,
- ]
-
- async def handle(self, query: str, context: dict[str, Any]) -> str:
- llm = get_llm()
- messages = [
- SystemMessage(content=_SYSTEM_PROMPT),
- HumanMessage(
- content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
- ),
- ]
- return await self._tool_loop(llm, messages, self.get_tools())
+TASK_TOOLS: list[Any] = [
+ list_tasks,
+ create_task,
+ update_task,
+ delete_task,
+ list_tasks_due_today,
+ list_task_comments,
+ add_task_comment,
+ delete_task_comment,
+]
diff --git a/app/agents/timeline_agent.py b/app/agents/timeline_agent.py
index 6e85357..19708e9 100644
--- a/app/agents/timeline_agent.py
+++ b/app/agents/timeline_agent.py
@@ -2,17 +2,13 @@
from __future__ import annotations
-import json
from typing import Any
-from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
-from app.core.agent_registry import ChatAgent, registry
-from app.core.llm import get_llm
from app.core.ws_context import execute_on_client
-_SYSTEM_PROMPT = (
+TIMELINE_SYSTEM_PROMPT = (
"You are a project timeline assistant. Timelines are milestone dates that\n"
"track progress on a project — they are not calendar events.\n\n"
"Rules:\n"
@@ -106,23 +102,9 @@ async def delete_timeline(timeline_id: str) -> str:
return f"Timeline {timeline_id} deleted."
-@registry.register
-class TimelineAgent(ChatAgent):
- def get_name(self) -> str:
- return "timeline_agent"
-
- def get_description(self) -> str:
- return "Manages project timelines (milestones): list, create, update, delete"
-
- def get_tools(self) -> list[Any]:
- return [list_timelines, create_timeline, update_timeline, delete_timeline]
-
- async def handle(self, query: str, context: dict[str, Any]) -> str:
- llm = get_llm()
- messages = [
- SystemMessage(content=_SYSTEM_PROMPT),
- HumanMessage(
- content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
- ),
- ]
- return await self._tool_loop(llm, messages, self.get_tools())
+TIMELINE_TOOLS: list[Any] = [
+ list_timelines,
+ create_timeline,
+ update_timeline,
+ delete_timeline,
+]
diff --git a/app/api/routes/chat.py b/app/api/routes/chat.py
index 1cd0fa4..6270d0e 100644
--- a/app/api/routes/chat.py
+++ b/app/api/routes/chat.py
@@ -9,7 +9,7 @@ from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from app.api.deps import get_current_user
-from app.core.orchestrator import orchestrate
+from app.core.deep_agent import run_home
from app.schemas import ChatRequest, UserProfile
router = APIRouter(prefix="/chat", tags=["chat"])
@@ -20,10 +20,10 @@ async def chat(
body: ChatRequest,
current_user: UserProfile = Depends(get_current_user),
) -> JSONResponse:
- """Route a chat message through the orchestrator.
-
- Returns ``ChatResponse`` for ``execution_mode='direct'``,
- or ``ExecutionPlan`` for ``execution_mode='plan'``.
- """
- result = await orchestrate(body)
- return JSONResponse(content=result.model_dump())
+ """REST fallback for home chat when websocket streaming is unavailable."""
+ response = await run_home(
+ user_id=current_user.id,
+ message=body.message,
+ context=body.context.model_dump(),
+ )
+ return JSONResponse(content={"response": response})
diff --git a/app/api/routes/device_ws.py b/app/api/routes/device_ws.py
index 771b696..1257e13 100644
--- a/app/api/routes/device_ws.py
+++ b/app/api/routes/device_ws.py
@@ -41,10 +41,10 @@ from sqlalchemy import update
from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs
+from app.core.deep_agent import run_floating_stream, run_home_stream
from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware
-from app.core.orchestrator import orchestrate_v3_stream
-from app.core.output_formatter import HomeFormatter, FloatingFormatter
+from app.core.output_formatter import StreamFormatter
from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session
from app.models import AgentRunLog
@@ -233,19 +233,10 @@ async def _handle_home_request(
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
- agent_holder: list = []
try:
- token_stream = orchestrate_v3_stream(
- user_id, message, context, agent_holder=agent_holder
- )
- formatter = HomeFormatter(request_id=request_id, tool_results=[])
- async for ws_frame in formatter.format(token_stream):
- # Inject mutations from agent tool_results into stream_end
- if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr]
- ws_frame.mutations = [ # type: ignore[union-attr]
- {"action": r["action"], "table": r["table"], "data": r["data"]}
- for r in getattr(agent_holder[0], "tool_results", [])
- ]
+ event_stream = run_home_stream(user_id, message, context)
+ formatter = StreamFormatter(request_id=request_id)
+ async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
# Collect text chunks to build the full response for episode storage
if ws_frame.type == "stream_text": # type: ignore[union-attr]
@@ -287,18 +278,10 @@ async def _handle_floating_request(
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
- agent_holder: list = []
try:
- token_stream = orchestrate_v3_stream(
- user_id, message, context, agent_holder=agent_holder
- )
- formatter = FloatingFormatter(request_id=request_id)
- async for ws_frame in formatter.format(token_stream):
- if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr]
- ws_frame.mutations = [ # type: ignore[union-attr]
- {"action": r["action"], "table": r["table"], "data": r["data"]}
- for r in getattr(agent_holder[0], "tool_results", [])
- ]
+ event_stream = run_floating_stream(user_id, message, context)
+ formatter = StreamFormatter(request_id=request_id)
+ async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
diff --git a/app/api/routes/plans.py b/app/api/routes/plans.py
deleted file mode 100644
index ed27272..0000000
--- a/app/api/routes/plans.py
+++ /dev/null
@@ -1,37 +0,0 @@
-"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
-
-from __future__ import annotations
-
-from fastapi import APIRouter, Depends, HTTPException, status
-
-from app.api.deps import get_current_user
-from app.core.execution_plan import plan_cache
-from app.schemas import ExecutionPlan, UserProfile
-
-router = APIRouter(prefix="/plans", tags=["plans"])
-
-
-@router.get("/playbook", response_model=list[ExecutionPlan])
-async def list_playbooks(
- current_user: UserProfile = Depends(get_current_user),
-) -> list[ExecutionPlan]:
- """Return all cached execution plan playbooks for the authenticated user.
-
- TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
- """
- return plan_cache.get_all_playbooks()
-
-
-@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
-async def get_playbook(
- plan_id: str,
- current_user: UserProfile = Depends(get_current_user),
-) -> ExecutionPlan:
- """Return a specific execution plan playbook by ID."""
- plan = plan_cache.get_plan(plan_id)
- if plan is None:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail=f"Plan not found: {plan_id}",
- )
- return plan
diff --git a/app/core/agent_registry.py b/app/core/agent_registry.py
index 9a4930d..95c2033 100644
--- a/app/core/agent_registry.py
+++ b/app/core/agent_registry.py
@@ -1,14 +1,13 @@
-"""Agent Registry — base classes and singleton registry for chat agents."""
+"""Minimal agent base types retained for compatibility with batch runners."""
from __future__ import annotations
from abc import ABC, abstractmethod
-from collections.abc import AsyncGenerator
from typing import Any
class BaseAgent(ABC):
- """Common base for all agents."""
+ """Common base for non-chat agents still using the old base contract."""
def __init__(
self,
@@ -28,190 +27,4 @@ class BaseAgent(ABC):
@property
def skills(self) -> list[str]:
- """Override in subclasses to advertise capabilities."""
return []
-
-
-class ChatAgent(BaseAgent):
- """Base class for LLM-powered chat agents."""
-
- def __init__(self, **kwargs: Any) -> None:
- super().__init__(**kwargs)
- # Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results.
- self.tool_results: list[dict] = []
-
- @abstractmethod
- async def handle(self, query: str, context: dict[str, Any]) -> str:
- """Process a user query and return a text response."""
- ...
-
- async def handle_stream(
- self, query: str, context: dict[str, Any]
- ) -> AsyncGenerator[str, None]:
- """Streaming variant of handle().
-
- Default: calls handle() and yields the full response as one chunk.
- Override in subclasses for true token-level streaming via _tool_loop_stream.
- """
- yield await self.handle(query, context)
-
- @abstractmethod
- def get_tools(self) -> list[Any]:
- """Return LangChain tool definitions available to this agent."""
- ...
-
- async def _tool_loop(
- self,
- llm: Any,
- messages: list[Any],
- tools: list[Any],
- max_iter: int = 5,
- ) -> str:
- """Shared tool-calling loop.
-
- Binds *tools* to *llm*, invokes iteratively until the model stops
- requesting tool calls or *max_iter* is reached, and returns the
- final text response. Captures raw execute_on_client results in
- ``self.tool_results``.
- """
- from langchain_core.messages import AIMessage, ToolMessage
-
- from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
-
- collector: list[dict] = []
- set_tool_result_collector(collector)
- try:
- llm_with_tools = llm.bind_tools(tools) if tools else llm
-
- for _ in range(max_iter):
- response: AIMessage = await llm_with_tools.ainvoke(messages)
- messages.append(response)
-
- if not response.tool_calls:
- return str(response.content)
-
- # Execute each requested tool call
- tool_map = {t.name: t for t in tools}
- for call in response.tool_calls:
- tool_fn = tool_map.get(call["name"])
- if tool_fn is None:
- result = f"Unknown tool: {call['name']}"
- else:
- result = await tool_fn.ainvoke(call["args"])
- messages.append(
- ToolMessage(content=str(result), tool_call_id=call["id"])
- )
-
- # Exhausted iterations — ask model for a final answer without tools
- response = await llm.ainvoke(messages)
- return str(response.content)
- finally:
- clear_tool_result_collector()
- self.tool_results = collector
-
- async def _tool_loop_stream(
- self,
- llm: Any,
- messages: list[Any],
- tools: list[Any],
- max_iter: int = 5,
- ) -> AsyncGenerator[str, None]:
- """Streaming variant of ``_tool_loop``.
-
- Behaves identically for tool-calling iterations (uses ainvoke to parse
- tool calls). For the final response — when the model produces no further
- tool calls — switches to ``llm.astream()`` and yields text tokens.
- Captures raw execute_on_client results in ``self.tool_results``.
- """
- from langchain_core.messages import AIMessage, ToolMessage
-
- from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
-
- collector: list[dict] = []
- set_tool_result_collector(collector)
- try:
- llm_with_tools = llm.bind_tools(tools) if tools else llm
-
- for _ in range(max_iter):
- response: AIMessage = await llm_with_tools.ainvoke(messages)
-
- if not response.tool_calls:
- # Stream the final answer — don't keep the ainvoke result.
- async for chunk in llm.astream(messages):
- if chunk.content:
- yield str(chunk.content)
- return
-
- messages.append(response)
-
- # Execute each requested tool call
- tool_map = {t.name: t for t in tools}
- for call in response.tool_calls:
- tool_fn = tool_map.get(call["name"])
- if tool_fn is None:
- result = f"Unknown tool: {call['name']}"
- else:
- result = await tool_fn.ainvoke(call["args"])
- messages.append(
- ToolMessage(content=str(result), tool_call_id=call["id"])
- )
-
- # Exhausted iterations — stream a final answer without tools
- async for chunk in llm.astream(messages):
- if chunk.content:
- yield str(chunk.content)
- finally:
- clear_tool_result_collector()
- self.tool_results = collector
-
-
-class AgentRegistry:
- """Singleton registry for ChatAgent subclasses."""
-
- _instance: AgentRegistry | None = None
-
- def __init__(self) -> None:
- self._agents: dict[str, type[ChatAgent]] = {}
-
- def __new__(cls) -> AgentRegistry:
- if cls._instance is None:
- cls._instance = super().__new__(cls)
- cls._instance._agents = {}
- return cls._instance
-
- # ── public API ───────────────────────────────────────────────────
-
- def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
- """Class decorator — registers an agent by its name."""
- instance = agent_class()
- name = instance.get_name()
- self._agents[name] = agent_class
- return agent_class
-
- def get(self, name: str) -> ChatAgent:
- """Return a fresh instance of the named agent."""
- cls = self._agents.get(name)
- if cls is None:
- raise KeyError(f"Agent not found: {name}")
- return cls()
-
- def list_agents(self) -> list[dict[str, str]]:
- """Return ``[{name, description}]`` for the orchestrator prompt."""
- result: list[dict[str, str]] = []
- for cls in self._agents.values():
- inst = cls()
- result.append(
- {"name": inst.get_name(), "description": inst.get_description()}
- )
- return result
-
- async def call_agent(
- self, name: str, query: str, context: dict[str, Any]
- ) -> str:
- """Instantiate the named agent and call its ``handle`` method."""
- agent = self.get(name)
- return await agent.handle(query, context)
-
-
-# Module-level singleton
-registry = AgentRegistry()
diff --git a/app/core/deep_agent.py b/app/core/deep_agent.py
new file mode 100644
index 0000000..d388ca4
--- /dev/null
+++ b/app/core/deep_agent.py
@@ -0,0 +1,576 @@
+"""Deep orchestrator-worker graphs for home and floating chat contexts."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import operator
+from collections.abc import AsyncGenerator, Awaitable, Callable
+from typing import Any, Literal, TypedDict
+
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
+from langchain_core.tools import tool
+from langgraph.constants import END, START
+from langgraph.graph import StateGraph
+from langgraph.types import Send
+from pydantic import BaseModel, Field
+
+from app.agents.note_agent import NOTE_SYSTEM_PROMPT, NOTE_TOOLS
+from app.agents.project_agent import PROJECT_SYSTEM_PROMPT, PROJECT_TOOLS
+from app.agents.task_agent import TASK_SYSTEM_PROMPT, TASK_TOOLS
+from app.agents.timeline_agent import TIMELINE_SYSTEM_PROMPT, TIMELINE_TOOLS
+from app.core.llm import get_llm
+from app.core.memory_middleware import MemoryMiddleware
+from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
+from app.db import async_session
+
+logger = logging.getLogger(__name__)
+
+WorkerName = Literal["task_agent", "project_agent", "note_agent", "timeline_agent"]
+FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]
+
+
+class WorkerTask(BaseModel):
+ worker: WorkerName
+ instruction: str
+
+
+class WorkerPlan(BaseModel):
+ tasks: list[WorkerTask] = Field(default_factory=list)
+ floating_domain: FloatingDomain | None = None
+
+
+class WorkerResult(TypedDict):
+ worker: WorkerName
+ instruction: str
+ response: str
+ entity_ids: dict[str, list[str]]
+
+
+class OrchestratorState(TypedDict, total=False):
+ user_id: str
+ user_message: str
+ context: dict[str, Any]
+ memory_context: dict[str, Any]
+ plan: list[dict[str, Any]]
+ floating_domain: FloatingDomain
+ task: dict[str, Any]
+ worker_results: list[WorkerResult]
+ final_response: str
+ stream_callback: Callable[[str], Awaitable[None]] | None
+
+
+class GraphState(OrchestratorState):
+ worker_results: list[WorkerResult]
+
+
+class ReducerState(OrchestratorState):
+ worker_results: list[WorkerResult]
+
+
+class AggregatedState(TypedDict, total=False):
+ worker_results: list[WorkerResult]
+
+
+WORKER_CONFIG: dict[WorkerName, dict[str, Any]] = {
+ "task_agent": {
+ "prompt": TASK_SYSTEM_PROMPT,
+ "tools": TASK_TOOLS,
+ "tag": "task",
+ "table": "tasks",
+ "floating_domain": "tasks",
+ },
+ "project_agent": {
+ "prompt": PROJECT_SYSTEM_PROMPT,
+ "tools": PROJECT_TOOLS,
+ "tag": "project",
+ "table": "projects",
+ "floating_domain": "projects",
+ },
+ "note_agent": {
+ "prompt": NOTE_SYSTEM_PROMPT,
+ "tools": NOTE_TOOLS,
+ "tag": "note",
+ "table": "notes",
+ "floating_domain": "notes",
+ },
+ "timeline_agent": {
+ "prompt": TIMELINE_SYSTEM_PROMPT,
+ "tools": TIMELINE_TOOLS,
+ "tag": "timeline",
+ "table": "timelines",
+ "floating_domain": "timelines",
+ },
+}
+
+_HOME_ORCHESTRATOR_SYSTEM = (
+ "You are an orchestrator. Plan which workers should be invoked for the user request. "
+ "Workers: task_agent, project_agent, note_agent, timeline_agent. "
+ "Return only the workers needed."
+)
+
+_FLOATING_ORCHESTRATOR_SYSTEM = (
+ "You are an orchestrator for floating context. Pick focused workers and set floating_domain "
+ "as one of: tasks, projects, notes, timelines."
+)
+
+_HOME_SYNTH_SYSTEM = (
+ "You are the final response synthesizer. Return markdown only. "
+ "Embed inline component tags when relevant: [ids], [ids], "
+ "[ids], [ids], and {json}. "
+ "Only include IDs that are truly relevant to the request."
+)
+
+_FLOATING_SYNTH_SYSTEM = (
+ "You are the final response synthesizer for floating UI context. "
+ "Return concise markdown and stay focused on the requested scope."
+)
+
+
+def _as_text(content: Any) -> str:
+ if content is None:
+ return ""
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ parts: list[str] = []
+ for item in content:
+ if isinstance(item, str):
+ parts.append(item)
+ elif isinstance(item, dict):
+ text = item.get("text")
+ if isinstance(text, str):
+ parts.append(text)
+ return "".join(parts)
+ return str(content)
+
+
+def _fallback_plan(message: str, floating: bool) -> WorkerPlan:
+ lowered = message.lower()
+ tasks: list[WorkerTask] = []
+
+ if any(k in lowered for k in ["task", "todo", "deadline", "due"]):
+ tasks.append(WorkerTask(worker="task_agent", instruction=message))
+ if any(k in lowered for k in ["project", "client", "milestone"]):
+ tasks.append(WorkerTask(worker="project_agent", instruction=message))
+ if any(k in lowered for k in ["note", "document", "memo"]):
+ tasks.append(WorkerTask(worker="note_agent", instruction=message))
+ if any(k in lowered for k in ["timeline", "event", "schedule", "release"]):
+ tasks.append(WorkerTask(worker="timeline_agent", instruction=message))
+
+ if not tasks:
+ tasks = [WorkerTask(worker="task_agent", instruction=message)]
+
+ domain: FloatingDomain | None = None
+ if floating:
+ domain = WORKER_CONFIG[tasks[0].worker]["floating_domain"]
+
+ return WorkerPlan(tasks=tasks, floating_domain=domain)
+
+
+async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan:
+ llm = get_llm()
+ system = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM
+
+ prompt_payload = {
+ "message": message,
+ "context": context,
+ "workers": list(WORKER_CONFIG.keys()),
+ }
+ messages = [
+ SystemMessage(content=system),
+ HumanMessage(content=json.dumps(prompt_payload, ensure_ascii=True)),
+ ]
+
+ try:
+ structured_llm = llm.with_structured_output(WorkerPlan)
+ plan = await structured_llm.ainvoke(messages)
+ if isinstance(plan, WorkerPlan):
+ if not plan.tasks:
+ return _fallback_plan(message, floating)
+ return plan
+ except Exception as exc:
+ logger.warning("deep_agent: structured planner failed, using fallback: %s", exc)
+
+ return _fallback_plan(message, floating)
+
+
+def _extract_entity_ids(tool_results: list[dict[str, Any]]) -> dict[str, list[str]]:
+ out: dict[str, list[str]] = {
+ "task": [],
+ "project": [],
+ "note": [],
+ "timeline": [],
+ }
+ table_to_tag = {
+ "tasks": "task",
+ "projects": "project",
+ "notes": "note",
+ "timelines": "timeline",
+ }
+
+ for item in tool_results:
+ table = item.get("table")
+ tag = table_to_tag.get(table)
+ if tag is None:
+ continue
+
+ payload = item.get("data") or {}
+ rows: list[dict[str, Any]] = []
+ row = payload.get("row")
+ if isinstance(row, dict):
+ rows.append(row)
+ if isinstance(payload.get("rows"), list):
+ rows.extend([r for r in payload["rows"] if isinstance(r, dict)])
+ if isinstance(payload.get("results"), list):
+ rows.extend([r for r in payload["results"] if isinstance(r, dict)])
+
+ for r in rows:
+ entity_id = r.get("id")
+ if isinstance(entity_id, str) and entity_id not in out[tag]:
+ out[tag].append(entity_id)
+
+ return out
+
+
+async def _run_tool_loop(
+ worker: WorkerName,
+ instruction: str,
+ context: dict[str, Any],
+) -> tuple[str, list[dict[str, Any]]]:
+ worker_prompt = WORKER_CONFIG[worker]["prompt"]
+ tools = WORKER_CONFIG[worker]["tools"]
+
+ llm = get_llm()
+ llm_with_tools = llm.bind_tools(tools) if tools else llm
+
+ messages: list[Any] = [
+ SystemMessage(content=worker_prompt),
+ HumanMessage(
+ content=(
+ "Worker instruction:\n"
+ f"{instruction}\n\n"
+ "Conversation context:\n"
+ f"{json.dumps(context, ensure_ascii=True)[:2000]}"
+ )
+ ),
+ ]
+
+ collected: list[dict[str, Any]] = []
+ set_tool_result_collector(collected)
+ try:
+ for _ in range(6):
+ response: AIMessage = await llm_with_tools.ainvoke(messages)
+ messages.append(response)
+
+ if not response.tool_calls:
+ return _as_text(response.content), collected
+
+ tool_map = {t.name: t for t in tools}
+ for call in response.tool_calls:
+ tool_fn = tool_map.get(call["name"])
+ if tool_fn is None:
+ tool_output = f"Unknown tool: {call['name']}"
+ else:
+ tool_output = await tool_fn.ainvoke(call.get("args", {}))
+ messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
+
+ final = await llm.ainvoke(messages)
+ return _as_text(final.content), collected
+ finally:
+ clear_tool_result_collector()
+
+
+def _worker_node(worker: WorkerName):
+ async def _node(state: GraphState) -> AggregatedState:
+ task_payload = state.get("task") or {}
+ if task_payload.get("worker") != worker:
+ return {"worker_results": []}
+
+ instruction = str(task_payload.get("instruction") or state.get("user_message") or "")
+ worker_context = {
+ "memory": state.get("memory_context", {}),
+ "context": state.get("context", {}),
+ }
+ response, tool_results = await _run_tool_loop(worker, instruction, worker_context)
+
+ return {
+ "worker_results": [
+ {
+ "worker": worker,
+ "instruction": instruction,
+ "response": response,
+ "entity_ids": _extract_entity_ids(tool_results),
+ }
+ ]
+ }
+
+ return _node
+
+
+def _build_synthesis_prompt(state: GraphState, floating: bool) -> str:
+ worker_results = state.get("worker_results", [])
+ formatted_results = []
+ for result in worker_results:
+ formatted_results.append(
+ {
+ "worker": result.get("worker"),
+ "instruction": result.get("instruction"),
+ "response": result.get("response"),
+ "entity_ids": result.get("entity_ids", {}),
+ }
+ )
+
+ payload = {
+ "user_message": state.get("user_message", ""),
+ "memory_context": state.get("memory_context", {}),
+ "worker_results": formatted_results,
+ "floating_domain": state.get("floating_domain") if floating else None,
+ }
+ return json.dumps(payload, ensure_ascii=True)
+
+
+async def _stream_with_memory_tool(
+ *,
+ user_id: str,
+ system_prompt: str,
+ user_prompt: str,
+ stream_callback: Callable[[str], Awaitable[None]] | None,
+) -> str:
+ @tool
+ async def update_core_memory(key: str, value: str) -> str:
+ """Save stable user preference/profile data to core memory."""
+ async with async_session() as db:
+ memory = MemoryMiddleware(db)
+ await memory.update_core(user_id, key, value)
+ return f"Saved core memory key '{key}'."
+
+ llm = get_llm()
+ messages: list[Any] = [
+ SystemMessage(content=system_prompt),
+ HumanMessage(content=user_prompt),
+ ]
+
+ llm_with_tools = llm.bind_tools([update_core_memory])
+
+ for _ in range(2):
+ response: AIMessage = await llm_with_tools.ainvoke(messages)
+ messages.append(response)
+
+ if not response.tool_calls:
+ break
+
+ for call in response.tool_calls:
+ if call["name"] != "update_core_memory":
+ messages.append(ToolMessage(content="Unsupported tool.", tool_call_id=call["id"]))
+ continue
+
+ tool_output = await update_core_memory.ainvoke(call.get("args", {}))
+ messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
+
+ chunks: list[str] = []
+ async for chunk in llm.astream(messages):
+ token = _as_text(getattr(chunk, "content", ""))
+ if not token:
+ continue
+ chunks.append(token)
+ if stream_callback is not None:
+ await stream_callback(token)
+
+ return "".join(chunks)
+
+
+def _synthesizer_node(floating: bool):
+ async def _node(state: GraphState) -> GraphState:
+ prompt = _build_synthesis_prompt(state, floating=floating)
+ system_prompt = _FLOATING_SYNTH_SYSTEM if floating else _HOME_SYNTH_SYSTEM
+
+ final_response = await _stream_with_memory_tool(
+ user_id=str(state.get("user_id", "")),
+ system_prompt=system_prompt,
+ user_prompt=prompt,
+ stream_callback=state.get("stream_callback"),
+ )
+
+ return {"final_response": final_response}
+
+ return _node
+
+
+async def _orchestrator_node_home(state: GraphState) -> GraphState:
+ if state.get("plan"):
+ return {}
+
+ context = {**state.get("context", {}), **state.get("memory_context", {})}
+ plan = await _plan_with_llm(str(state.get("user_message", "")), context, floating=False)
+ return {"plan": [task.model_dump() for task in plan.tasks]}
+
+
+async def _orchestrator_node_floating(state: GraphState) -> GraphState:
+ if state.get("plan"):
+ return {}
+
+ context = {**state.get("context", {}), **state.get("memory_context", {})}
+ plan = await _plan_with_llm(str(state.get("user_message", "")), context, floating=True)
+ floating_domain = plan.floating_domain
+ if floating_domain is None and plan.tasks:
+ floating_domain = WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
+
+ return {
+ "plan": [task.model_dump() for task in plan.tasks],
+ "floating_domain": floating_domain or "tasks",
+ }
+
+
+def _route_workers(state: GraphState) -> list[Send] | str:
+ plan = state.get("plan", [])
+ if not plan:
+ return "synthesizer"
+
+ sends: list[Send] = []
+ for task in plan:
+ worker = task.get("worker")
+ if worker in WORKER_CONFIG:
+ sends.append(Send(worker, {"task": task}))
+
+ return sends or "synthesizer"
+
+
+def _build_graph(*, floating: bool):
+ builder = StateGraph(GraphState)
+
+ orchestrator_node = _orchestrator_node_floating if floating else _orchestrator_node_home
+ builder.add_node("orchestrator", orchestrator_node)
+ for worker in WORKER_CONFIG:
+ builder.add_node(worker, _worker_node(worker))
+ builder.add_node("synthesizer", _synthesizer_node(floating=floating))
+
+ builder.add_edge(START, "orchestrator")
+ builder.add_conditional_edges(
+ "orchestrator",
+ _route_workers,
+ ["task_agent", "project_agent", "note_agent", "timeline_agent", "synthesizer"],
+ )
+ for worker in WORKER_CONFIG:
+ builder.add_edge(worker, "synthesizer")
+ builder.add_edge("synthesizer", END)
+
+ return builder.compile()
+
+
+HOME_GRAPH = _build_graph(floating=False)
+FLOATING_GRAPH = _build_graph(floating=True)
+
+
+async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
+ state = await HOME_GRAPH.ainvoke(
+ {
+ "user_id": user_id,
+ "user_message": message,
+ "context": context,
+ "memory_context": context,
+ "worker_results": [],
+ "stream_callback": None,
+ }
+ )
+ return str(state.get("final_response", ""))
+
+
+async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, str]:
+ plan = await _plan_with_llm(message, context, floating=True)
+ domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
+
+ state = await FLOATING_GRAPH.ainvoke(
+ {
+ "user_id": user_id,
+ "user_message": message,
+ "context": context,
+ "memory_context": context,
+ "plan": [task.model_dump() for task in plan.tasks],
+ "floating_domain": domain,
+ "worker_results": [],
+ "stream_callback": None,
+ }
+ )
+ return str(state.get("final_response", "")), str(domain)
+
+
+async def run_home_stream(
+ user_id: str,
+ message: str,
+ context: dict[str, Any],
+) -> AsyncGenerator[tuple[str, Any], None]:
+ queue: asyncio.Queue[str] = asyncio.Queue()
+
+ async def _on_token(token: str) -> None:
+ await queue.put(token)
+
+ task = asyncio.create_task(
+ HOME_GRAPH.ainvoke(
+ {
+ "user_id": user_id,
+ "user_message": message,
+ "context": context,
+ "memory_context": context,
+ "worker_results": [],
+ "stream_callback": _on_token,
+ }
+ )
+ )
+
+ emitted = False
+ while not task.done() or not queue.empty():
+ try:
+ token = await asyncio.wait_for(queue.get(), timeout=0.15)
+ emitted = True
+ yield "token", token
+ except asyncio.TimeoutError:
+ continue
+
+ final_state = await task
+ if not emitted and final_state.get("final_response"):
+ yield "token", str(final_state["final_response"])
+
+
+async def run_floating_stream(
+ user_id: str,
+ message: str,
+ context: dict[str, Any],
+) -> AsyncGenerator[tuple[str, Any], None]:
+ plan = await _plan_with_llm(message, context, floating=True)
+ domain = plan.floating_domain or WORKER_CONFIG[plan.tasks[0].worker]["floating_domain"]
+ yield "floating_domain", domain
+
+ queue: asyncio.Queue[str] = asyncio.Queue()
+
+ async def _on_token(token: str) -> None:
+ await queue.put(token)
+
+ task = asyncio.create_task(
+ FLOATING_GRAPH.ainvoke(
+ {
+ "user_id": user_id,
+ "user_message": message,
+ "context": context,
+ "memory_context": context,
+ "plan": [t.model_dump() for t in plan.tasks],
+ "floating_domain": domain,
+ "worker_results": [],
+ "stream_callback": _on_token,
+ }
+ )
+ )
+
+ emitted = False
+ while not task.done() or not queue.empty():
+ try:
+ token = await asyncio.wait_for(queue.get(), timeout=0.15)
+ emitted = True
+ yield "token", token
+ except asyncio.TimeoutError:
+ continue
+
+ final_state = await task
+ if not emitted and final_state.get("final_response"):
+ yield "token", str(final_state["final_response"])
diff --git a/app/core/execution_plan.py b/app/core/execution_plan.py
deleted file mode 100644
index a98879f..0000000
--- a/app/core/execution_plan.py
+++ /dev/null
@@ -1,222 +0,0 @@
-"""Execution Plan generator — builder, template registry, and LRU plan cache."""
-
-from __future__ import annotations
-
-from collections import OrderedDict
-from typing import Any
-
-from app.schemas import ExecutionPlan, PlanStep
-
-
-# ── Prompt Template Registry ──────────────────────────────────────────
-
-
-class PromptTemplateRegistry:
- """Server-side store mapping template IDs to prompt text.
-
- Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``).
- The actual prompt text is resolved here on the server, keeping prompt IP
- out of API responses.
- """
-
- def __init__(self) -> None:
- self._templates: dict[str, str] = {}
-
- def register(self, template_id: str, prompt_text: str) -> None:
- self._templates[template_id] = prompt_text
-
- def get(self, template_id: str) -> str:
- """Resolve a template ID to its prompt text.
-
- Raises ``KeyError`` if the template is not registered.
- """
- text = self._templates.get(template_id)
- if text is None:
- raise KeyError(f"Template not found: {template_id!r}")
- return text
-
- def has(self, template_id: str) -> bool:
- return template_id in self._templates
-
- def list_ids(self) -> list[str]:
- """Return all registered template IDs (never the text)."""
- return list(self._templates.keys())
-
-
-# ── Execution Plan Builder ────────────────────────────────────────────
-
-
-class ExecutionPlanBuilder:
- """Fluent builder for ``ExecutionPlan`` objects.
-
- Example::
-
- plan = (
- ExecutionPlanBuilder("task_agent")
- .add_llm_step("tpl_task_agent_default", {"message": user_msg})
- .add_data_step("create_record", data_from_step=0)
- .build()
- )
- """
-
- def __init__(self, agent: str) -> None:
- self._agent = agent
- self._steps: list[PlanStep] = []
-
- # ── step adders ──────────────────────────────────────────────────
-
- def add_step(
- self, action: str, params: dict[str, Any] | None = None
- ) -> ExecutionPlanBuilder:
- """Append a generic action step with optional parameters."""
- self._steps.append(PlanStep(action=action, variables=params))
- return self
-
- def add_llm_step(
- self, template_id: str, variables: dict[str, Any] | None = None
- ) -> ExecutionPlanBuilder:
- """Append an LLM step referencing a server-side template by ID."""
- self._steps.append(
- PlanStep(action="llm", prompt_template=template_id, variables=variables)
- )
- return self
-
- def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
- """Append a step whose input comes from the output of an earlier step."""
- self._steps.append(PlanStep(action=action, data_from_step=data_from_step))
- return self
-
- # ── build ────────────────────────────────────────────────────────
-
- def build(self) -> ExecutionPlan:
- """Validate step references and return the ``ExecutionPlan``.
-
- Raises ``ValueError`` if any ``data_from_step`` references a
- non-existent or future step index.
- """
- for i, step in enumerate(self._steps):
- if step.data_from_step is not None:
- if not (0 <= step.data_from_step < i):
- raise ValueError(
- f"Step {i}: data_from_step={step.data_from_step} must "
- f"reference a preceding step index in range 0..{i - 1}"
- )
- return ExecutionPlan(agent=self._agent, steps=list(self._steps))
-
-
-# ── Plan Cache (LRU) ──────────────────────────────────────────────────
-
-
-class PlanCache:
- """In-memory LRU cache for ``ExecutionPlan`` objects.
-
- Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
- The cache also serves as a runtime memoisation layer so that repeated
- identical intent classifications can skip re-building the plan.
- """
-
- def __init__(self, maxsize: int = 1000) -> None:
- self._maxsize = maxsize
- self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()
-
- def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
- """Store *plan* under *key*, evicting the LRU entry if at capacity."""
- if key in self._cache:
- del self._cache[key] # remove so re-insertion places it at the end
- elif len(self._cache) >= self._maxsize:
- self._cache.popitem(last=False) # evict least-recently-used
- self._cache[key] = plan
-
- def get_plan(self, key: str) -> ExecutionPlan | None:
- """Return the cached plan for *key*, or ``None`` if not present.
-
- Accessing a plan marks it as most-recently used.
- """
- if key not in self._cache:
- return None
- self._cache.move_to_end(key)
- return self._cache[key]
-
- def get_all_playbooks(self) -> list[ExecutionPlan]:
- """Return all cached plans (most-recently used last)."""
- return list(self._cache.values())
-
-
-# ── Module-level singletons ───────────────────────────────────────────
-
-template_registry = PromptTemplateRegistry()
-plan_cache = PlanCache()
-
-
-def _register_builtin_templates() -> None:
- """Register the built-in server-side prompt templates.
-
- These strings never leave the server. Clients only receive the IDs.
- """
- _tpls: dict[str, str] = {
- "tpl_task_agent_default": (
- "You are a task management assistant. Help the user create, update, "
- "list, and track tasks. Use correct status values (todo, in_progress, "
- "done) and priority values (high, medium, low) from the workspace model."
- ),
- "tpl_timeline_agent_default": (
- "You are a project timeline assistant. Help the user create and manage "
- "milestone timelines on their projects. Every timeline requires a "
- "project_id and a date expressed as a Unix timestamp in milliseconds."
- ),
- "tpl_project_agent_default": (
- "You are a project management assistant. Help the user create, find, "
- "update, and archive projects. Projects have a name, an optional client, "
- "and a status of either active or archived."
- ),
- "tpl_note_agent_default": (
- "You are a note-taking assistant. Help the user create, retrieve, update, "
- "and delete Markdown notes. Notes can optionally be linked to a project."
- ),
- "tpl_task_extract_from_project": (
- "Extract all actionable tasks from the provided project context. "
- "Return a structured list of tasks, each with a title, inferred priority "
- "(high, medium, or low), suggested status (todo), and a due_date in "
- "milliseconds where a deadline can be inferred."
- ),
- "tpl_note_weekly_summary": (
- "Generate a weekly project summary note from the provided workspace data. "
- "Include: tasks completed this week, tasks due soon, active projects, "
- "and upcoming timelines. Format the output as clean Markdown."
- ),
- }
- for tid, text in _tpls.items():
- template_registry.register(tid, text)
-
-
-def _load_playbooks() -> None:
- """Pre-build and cache the built-in playbooks."""
- playbooks: list[tuple[str, ExecutionPlan]] = [
- (
- "create_tasks_from_project",
- ExecutionPlanBuilder("project_agent")
- .add_llm_step(
- "tpl_task_extract_from_project",
- {"source": "project_context"},
- )
- .add_data_step("create_record", data_from_step=0)
- .build(),
- ),
- (
- "generate_weekly_note",
- ExecutionPlanBuilder("note_agent")
- .add_llm_step(
- "tpl_note_weekly_summary",
- {"period": "last_7_days"},
- )
- .add_data_step("create_record", data_from_step=0)
- .build(),
- ),
- ]
- for key, plan in playbooks:
- plan_cache.cache_plan(key, plan)
-
-
-# Initialise on module load
-_register_builtin_templates()
-_load_playbooks()
diff --git a/app/core/orchestrator.py b/app/core/orchestrator.py
deleted file mode 100644
index 7765704..0000000
--- a/app/core/orchestrator.py
+++ /dev/null
@@ -1,210 +0,0 @@
-"""Orchestrator — LLM-based intent router and agent pipeline."""
-
-from __future__ import annotations
-
-import json
-from typing import Any, AsyncGenerator
-
-from langchain_core.messages import HumanMessage, SystemMessage
-
-from app.core.agent_registry import AgentRegistry, ChatAgent
-from app.core.llm import get_router_llm
-from app.core.agent_registry import registry as _default_registry
-from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
-
-_FALLBACK_AGENT = "task_agent"
-
-_CLASSIFY_SYSTEM = (
- "You are an intent classifier. Given the user message and context, decide "
- "which agent to route to.\n"
- "Available agents: {agents}\n"
- "Respond with just the agent name, nothing else."
-)
-
-_SYNTHESIZE_HUMAN = (
- "Combine the following agent results into one coherent response.\n\n"
- "Agent results:\n{results}\n\n"
- "Original message: {message}"
-)
-
-
-def _make_llm():
- return get_router_llm()
-
-
-async def classify_intent(
- message: str,
- context: dict[str, Any],
- reg: AgentRegistry,
-) -> str:
- """Use gpt-4o-mini to classify intent and return the matching agent name.
-
- Falls back to ``task_agent`` when the registry is empty or the model
- returns a name that is not registered.
- """
- agents = reg.list_agents()
- if not agents:
- return _FALLBACK_AGENT
-
- system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
- # Truncate context to keep the classification prompt short
- human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}"
-
- llm = _make_llm()
- response = await llm.ainvoke(
- [SystemMessage(content=system), HumanMessage(content=human)]
- )
-
- agent_name = str(response.content).strip().lower()
- known = {a["name"] for a in agents}
- return agent_name if agent_name in known else _FALLBACK_AGENT
-
-
-async def route_single(
- agent_name: str,
- message: str,
- context: dict[str, Any],
- reg: AgentRegistry,
-) -> ChatResponse:
- """Route to a single agent and wrap the result in a ``ChatResponse``."""
- response_text = await reg.call_agent(agent_name, message, context)
- return ChatResponse(response=response_text)
-
-
-async def route_pipeline(
- agent_names: list[str],
- message: str,
- context: dict[str, Any],
- reg: AgentRegistry,
-) -> ChatResponse:
- """Execute agents sequentially; each agent receives previous results in context.
-
- A final LLM synthesis call merges all results into one coherent response.
- """
- previous_results: list[str] = []
-
- for agent_name in agent_names:
- ctx = {**context, "previous_results": list(previous_results)}
- result = await reg.call_agent(agent_name, message, ctx)
- previous_results.append(result)
-
- results_str = "\n\n".join(
- f"[{name}]: {res}" for name, res in zip(agent_names, previous_results)
- )
- human = _SYNTHESIZE_HUMAN.format(results=results_str, message=message)
- llm = _make_llm()
- synthesis = await llm.ainvoke([HumanMessage(content=human)])
- return ChatResponse(response=str(synthesis.content))
-
-
-def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
- """Build an ``ExecutionPlan`` for the resolved agent.
-
- Uses ``ExecutionPlanBuilder`` with the server-side template registry.
- If a default template exists for the agent, an LLM step is emitted;
- otherwise a plain ``handle`` action step is used.
- """
- from app.core.execution_plan import ExecutionPlanBuilder, template_registry
-
- template_id = f"tpl_{agent_name}_default"
- builder = ExecutionPlanBuilder(agent_name)
- if template_registry.has(template_id):
- builder.add_llm_step(template_id, {"message": message})
- else:
- builder.add_step("handle", {"message": message})
- return builder.build()
-
-
-async def orchestrate(
- request: ChatRequest,
- reg: AgentRegistry | None = None,
-) -> ChatResponse | ExecutionPlan:
- """Main orchestration entry point.
-
- * Classifies the user's intent to select an agent.
- * ``execution_mode == 'direct'``: routes to the agent and returns a
- ``ChatResponse``.
- * ``execution_mode == 'plan'``: returns an ``ExecutionPlan`` with the
- resolved agent and a template-ID-only step (prompt IP stays server-side).
- """
- if reg is None:
- reg = _default_registry
-
- context = request.context.model_dump()
- agent_name = await classify_intent(request.message, context, reg)
-
- if request.execution_mode == "direct":
- return await route_single(agent_name, request.message, context, reg)
-
- # plan mode — return plan, do not execute
- return _build_plan(agent_name, request.message)
-
-
-async def orchestrate_v3(
- user_id: str,
- message: str,
- context: dict[str, Any],
- reg: AgentRegistry | None = None,
-) -> tuple[str, ChatAgent]:
- """v3 orchestration — returns (agent_name, agent_instance); caller drives execution.
-
- Classifies intent and instantiates the matching agent. The caller is responsible
- for invoking handle(), handle_stream(), or _tool_loop_stream() as needed.
- """
- if reg is None:
- reg = _default_registry
- agent_name = await classify_intent(message, context, reg)
- return agent_name, reg.get(agent_name)
-
-
-async def orchestrate_v3_stream(
- user_id: str,
- message: str,
- context: dict[str, Any],
- reg: AgentRegistry | None = None,
- agent_holder: list | None = None,
-) -> AsyncGenerator[tuple[str, str], None]:
- """v3 streaming orchestration — yields (agent_name, token) pairs.
-
- The first yield always carries the agent_name with an empty token so that
- callers (e.g. FloatingFormatter) can detect the routing domain before any text
- tokens arrive.
-
- If *agent_holder* is provided (a list), the agent instance is appended so
- callers can access ``agent.tool_results`` after the stream completes.
- """
- if reg is None:
- reg = _default_registry
- agent_name = await classify_intent(message, context, reg)
- agent = reg.get(agent_name)
- if agent_holder is not None:
- agent_holder.append(agent)
- yield agent_name, "" # domain signal — no token yet
- async for token in agent.handle_stream(message, context):
- yield agent_name, token
-
-
-async def orchestrate_stream(
- request: ChatRequest,
- reg: AgentRegistry | None = None,
-) -> AsyncGenerator[str, None]:
- """Streaming orchestration — yields plain text chunks only.
-
- The WebSocket handler in ``app/api/routes/chat.py`` is responsible for
- wrapping each chunk in a ``text_chunk`` frame and sending the final
- ``final`` frame once the generator is exhausted.
-
- Agents do not yet support token-level streaming; the full response is
- fetched first (which may involve multiple WS round-trips for tool calls),
- then emitted in fixed-size chunks.
- """
- if reg is None:
- reg = _default_registry
-
- context = request.context.model_dump()
- agent_name = await classify_intent(request.message, context, reg)
- response_text = await reg.call_agent(agent_name, request.message, context)
-
- chunk_size = 50
- for i in range(0, len(response_text), chunk_size):
- yield response_text[i : i + chunk_size]
diff --git a/app/core/output_formatter.py b/app/core/output_formatter.py
index a8e44fb..429a2ce 100644
--- a/app/core/output_formatter.py
+++ b/app/core/output_formatter.py
@@ -1,244 +1,43 @@
-"""Output Formatter — transforms orchestrator token streams into WS frame sequences.
-
-HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
-FloatingFormatter: produces floating_domain, stream_text, stream_end
-"""
+"""Output formatter for deep-agent stream events."""
from __future__ import annotations
-import json
-import logging
from collections.abc import AsyncGenerator
from typing import Any
-from app.schemas import (
- WsFloatingDomain,
- WsStreamBlock,
- WsStreamEnd,
- WsStreamStart,
- WsStreamText,
-)
+from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
-logger = logging.getLogger(__name__)
-
-# Valid chart types (matching shadcn/ui Recharts wrappers in Electron)
-_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}
-
-# Map agent name → floating domain
-_AGENT_DOMAIN: dict[str, str] = {
- "task_agent": "tasks",
- "timeline_agent": "timelines",
- "note_agent": "notes",
- "project_agent": "projects",
-}
-
-WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatingDomain
+WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
-class HomeFormatter:
- """Parses a token stream from orchestrate_v3_stream and yields WS frames.
-
- The LLM is expected to output a newline-delimited sequence of JSON objects,
- each with a ``type`` field:
- - ``text`` → yields WsStreamText immediately (word-by-word)
- - ``chart`` → buffers full JSON, validates, yields WsStreamBlock
- - ``entity_ref`` → resolves from tool_results, yields WsStreamBlock
- - ``table`` → buffers full JSON, validates, yields WsStreamBlock
- - ``timeline`` → buffers full JSON, validates, yields WsStreamBlock
-
- Invalid or unknown blocks are logged and skipped — stream never crashes.
- """
-
- def __init__(self, request_id: str, tool_results: list[dict]) -> None:
- self.request_id = request_id
- self.tool_results = tool_results
-
- async def format(
- self,
- token_stream: AsyncGenerator[tuple[str, str], None],
- ) -> AsyncGenerator[WsFrame, None]:
- yield WsStreamStart(request_id=self.request_id)
-
- buffer = ""
- async for _agent_name, token in token_stream:
- if not token:
- continue
- buffer += token
- # Flush any complete JSON objects from the buffer
- async for frame in self._flush_complete_objects(buffer):
- buffer = "" # reset after flush
- yield frame
- break # only one flush per iteration; rest accumulates
-
- # Flush any remaining content
- if buffer.strip():
- async for frame in self._flush_complete_objects(buffer, final=True):
- yield frame
-
- yield WsStreamEnd(request_id=self.request_id)
-
- async def _flush_complete_objects(
- self, text: str, final: bool = False
- ) -> AsyncGenerator[WsFrame, None]:
- """Try to parse and yield all complete JSON objects from *text*.
-
- Yields nothing if text is incomplete JSON (unless *final* is True,
- in which case remaining text is emitted as plain stream_text).
- """
- remaining = text.strip()
- while remaining:
- # Fast path: plain text (not JSON)
- if not remaining.startswith("{"):
- # Yield as plain text chunk
- newline_idx = remaining.find("\n")
- if newline_idx == -1:
- if final:
- yield WsStreamText(request_id=self.request_id, chunk=remaining)
- remaining = ""
- else:
- return # accumulate more
- else:
- line = remaining[:newline_idx].strip()
- remaining = remaining[newline_idx + 1:].strip()
- if line:
- yield WsStreamText(request_id=self.request_id, chunk=line)
- continue
-
- # Try to decode a JSON object
- try:
- obj, end_idx = _try_parse_json(remaining)
- except ValueError:
- if final:
- # Emit as raw text if we can't parse
- yield WsStreamText(request_id=self.request_id, chunk=remaining)
- remaining = ""
- return
-
- if obj is None:
- if final:
- yield WsStreamText(request_id=self.request_id, chunk=remaining)
- remaining = ""
- return # incomplete — need more tokens
-
- remaining = remaining[end_idx:].strip()
- block_type = obj.get("type")
-
- frame = self._dispatch_block(obj, block_type)
- if frame is not None:
- yield frame
-
- def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
- if block_type == "text":
- content = obj.get("content", "")
- if content:
- return WsStreamText(request_id=self.request_id, chunk=str(content))
- return None
-
- if block_type == "chart":
- chart_type = obj.get("chartType")
- if chart_type not in _VALID_CHART_TYPES:
- logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
- return None
- if not isinstance(obj.get("data"), list):
- logger.warning("HomeFormatter: chart missing data array — skipping")
- return None
- return WsStreamBlock(
- request_id=self.request_id,
- block_type="chart",
- data=obj,
- )
-
- if block_type == "entity_ref":
- entity = obj.get("entity")
- resolved = self._resolve_entity(entity)
- if resolved is None:
- logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity)
- return None
- return WsStreamBlock(
- request_id=self.request_id,
- block_type="entity_ref",
- data={"entity": entity, "items": resolved},
- )
-
- if block_type == "table":
- if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
- logger.warning("HomeFormatter: table missing headers/rows — skipping")
- return None
- return WsStreamBlock(
- request_id=self.request_id,
- block_type="table",
- data=obj,
- )
-
- if block_type == "timeline":
- if not isinstance(obj.get("timelines"), list):
- logger.warning("HomeFormatter: timeline missing timelines — skipping")
- return None
- return WsStreamBlock(
- request_id=self.request_id,
- block_type="timeline",
- data=obj,
- )
-
- logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
- return None
-
- def _resolve_entity(self, entity: str | None) -> list[dict] | None:
- """Find matching items in tool_results by entity type."""
- if not entity:
- return None
- matches = [r for r in self.tool_results if r.get("entity") == entity]
- return matches if matches else None
-
-
-class FloatingFormatter:
- """Parses a token stream from orchestrate_v3_stream and yields WS frames.
-
- Emits floating_domain immediately (from agent_name), then streams all tokens
- as plain stream_text — no block parsing for floating context.
- """
+class StreamFormatter:
+ """Convert `(event_type, data)` stream events into websocket frame models."""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
async def format(
self,
- token_stream: AsyncGenerator[tuple[str, str], None],
+ event_stream: AsyncGenerator[tuple[str, Any], None],
) -> AsyncGenerator[WsFrame, None]:
- domain_sent = False
+ started = False
- async for agent_name, token in token_stream:
- if not domain_sent:
- domain = _AGENT_DOMAIN.get(agent_name, "tasks")
- yield WsFloatingDomain(
- request_id=self.request_id,
- domain=domain, # type: ignore[arg-type]
- )
+ async for event_type, data in event_stream:
+ if event_type == "floating_domain":
+ yield WsFloatingDomain(request_id=self.request_id, domain=str(data))
+ continue
+
+ if event_type != "token":
+ continue
+
+ if not started:
yield WsStreamStart(request_id=self.request_id)
- domain_sent = True
+ started = True
- if token:
- yield WsStreamText(request_id=self.request_id, chunk=token)
+ text = str(data or "")
+ if text:
+ yield WsStreamText(request_id=self.request_id, chunk=text)
+ if not started:
+ yield WsStreamStart(request_id=self.request_id)
yield WsStreamEnd(request_id=self.request_id)
-
-
-# ── helpers ───────────────────────────────────────────────────────────────────
-
-def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]:
- """Attempt to parse the first complete JSON object from *text*.
-
- Returns ``(parsed_dict, end_index)`` on success, ``(None, 0)`` when the
- object is incomplete, and raises ``ValueError`` when text is not JSON.
- """
- decoder = json.JSONDecoder()
- try:
- obj, end_idx = decoder.raw_decode(text)
- if not isinstance(obj, dict):
- raise ValueError("Expected JSON object")
- return obj, end_idx
- except json.JSONDecodeError as exc:
- # Incomplete JSON — need more tokens
- if "Unterminated" in str(exc) or exc.pos == len(text):
- return None, 0
- raise ValueError(str(exc)) from exc
diff --git a/app/main.py b/app/main.py
index 74c25ee..957512b 100644
--- a/app/main.py
+++ b/app/main.py
@@ -18,9 +18,8 @@ from app.config.settings import settings
@asynccontextmanager
async def lifespan(app: FastAPI):
- # Startup: initialise DB connection pool and agent registry
- from app.core.agent_registry import registry # noqa: F401 — triggers module load
- import app.agents # noqa: F401 — triggers @registry.register decorators
+ # Startup: ensure agent tool modules are loaded.
+ import app.agents # noqa: F401
yield
@@ -51,11 +50,10 @@ def create_app() -> FastAPI:
app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware)
- from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plans, plugins, storage, vectors
+ from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
- app.include_router(plans.router, prefix="/api/v1")
app.include_router(storage.router, prefix="/api/v1")
app.include_router(vectors.router, prefix="/api/v1")
app.include_router(backup.router, prefix="/api/v1")
diff --git a/app/schemas.py b/app/schemas.py
index f3a281b..3005169 100644
--- a/app/schemas.py
+++ b/app/schemas.py
@@ -41,41 +41,13 @@ class ChatContext(BaseModel):
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
-class PlanAction(BaseModel):
- type: Literal[
- "create_record",
- "update_record",
- "delete_record",
- "index_document",
- "send_notification",
- ]
- table: str | None = None
- data: dict[str, Any] | None = None
-
-
class ChatRequest(BaseModel):
message: str
context: ChatContext = Field(default_factory=ChatContext)
- execution_mode: Literal["direct", "plan"] = "direct"
class ChatResponse(BaseModel):
response: str
- actions: list[PlanAction] = Field(default_factory=list)
-
-
-# ── Execution Plans ──────────────────────────────────────────────────
-
-class PlanStep(BaseModel):
- action: str
- prompt_template: str | None = None
- variables: dict[str, Any] | None = None
- data_from_step: int | None = None
-
-
-class ExecutionPlan(BaseModel):
- agent: str
- steps: list[PlanStep] = Field(default_factory=list)
# ── Backup ───────────────────────────────────────────────────────────
@@ -179,7 +151,6 @@ class WsFrameType(str, Enum):
floating_request = "floating_request"
stream_start = "stream_start"
stream_text = "stream_text"
- stream_block = "stream_block"
stream_end = "stream_end"
floating_domain = "floating_domain"
data_request = "data_request"
@@ -303,21 +274,11 @@ class WsStreamText(BaseModel):
chunk: str
-class WsStreamBlock(BaseModel):
- """Server → Client: structured block (chart, table, entity, timeline)."""
-
- type: Literal[WsFrameType.stream_block] = WsFrameType.stream_block
- request_id: str
- block_type: Literal["chart", "entity_ref", "table", "timeline"]
- data: dict[str, Any]
-
-
class WsStreamEnd(BaseModel):
"""Server → Client: signals end of a streaming response."""
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
request_id: str
- mutations: list[dict[str, Any]] = Field(default_factory=list)
class WsFloatingDomain(BaseModel):
diff --git a/requirements.txt b/requirements.txt
index ea10f59..8202519 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,7 @@ langchain>=0.3.0
langchain-openai>=0.3.0
langchain-litellm>=0.1.0
litellm>=1.50.0
+langgraph>=0.4.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
python-jose[cryptography]>=3.3.0
diff --git a/tests/test_agent_registry.py b/tests/test_agent_registry.py
deleted file mode 100644
index 9fd9381..0000000
--- a/tests/test_agent_registry.py
+++ /dev/null
@@ -1,214 +0,0 @@
-"""Unit tests for the agent registry, base classes, and tool loop."""
-
-from __future__ import annotations
-
-from typing import Any
-from unittest.mock import AsyncMock, MagicMock
-
-import pytest
-
-from app.core.agent_registry import AgentRegistry, ChatAgent
-
-
-# ── Helpers ──────────────────────────────────────────────────────────
-
-class _StubAgent(ChatAgent):
- """Minimal concrete agent for testing."""
-
- def get_name(self) -> str:
- return "stub"
-
- def get_description(self) -> str:
- return "A stub agent for tests"
-
- def get_tools(self) -> list[Any]:
- return []
-
- async def handle(self, query: str, context: dict[str, Any]) -> str:
- return f"echo: {query}"
-
-
-class _AnotherAgent(ChatAgent):
- def get_name(self) -> str:
- return "another"
-
- def get_description(self) -> str:
- return "Another stub"
-
- def get_tools(self) -> list[Any]:
- return []
-
- async def handle(self, query: str, context: dict[str, Any]) -> str:
- return "another"
-
-
-# ── Fixtures ─────────────────────────────────────────────────────────
-
-@pytest.fixture(autouse=True)
-def _fresh_registry():
- """Reset the singleton between tests."""
- AgentRegistry._instance = None
- yield
- AgentRegistry._instance = None
-
-
-@pytest.fixture()
-def reg() -> AgentRegistry:
- return AgentRegistry()
-
-
-# ── Tests ────────────────────────────────────────────────────────────
-
-class TestRegisterAndGet:
- def test_register_decorator(self, reg: AgentRegistry) -> None:
- reg.register(_StubAgent)
- agent = reg.get("stub")
- assert isinstance(agent, _StubAgent)
-
- def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
- with pytest.raises(KeyError, match="not found"):
- reg.get("nonexistent")
-
- def test_register_multiple(self, reg: AgentRegistry) -> None:
- reg.register(_StubAgent)
- reg.register(_AnotherAgent)
- assert reg.get("stub").get_name() == "stub"
- assert reg.get("another").get_name() == "another"
-
-
-class TestListAgents:
- def test_empty(self, reg: AgentRegistry) -> None:
- assert reg.list_agents() == []
-
- def test_list_after_register(self, reg: AgentRegistry) -> None:
- reg.register(_StubAgent)
- agents = reg.list_agents()
- assert len(agents) == 1
- assert agents[0] == {"name": "stub", "description": "A stub agent for tests"}
-
- def test_list_multiple(self, reg: AgentRegistry) -> None:
- reg.register(_StubAgent)
- reg.register(_AnotherAgent)
- names = {a["name"] for a in reg.list_agents()}
- assert names == {"stub", "another"}
-
-
-class TestCallAgent:
- @pytest.mark.asyncio
- async def test_call_agent(self, reg: AgentRegistry) -> None:
- reg.register(_StubAgent)
- result = await reg.call_agent("stub", "hello", {})
- assert result == "echo: hello"
-
- @pytest.mark.asyncio
- async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
- with pytest.raises(KeyError):
- await reg.call_agent("nope", "hi", {})
-
-
-class TestSingleton:
- def test_singleton_identity(self) -> None:
- a = AgentRegistry()
- b = AgentRegistry()
- assert a is b
-
-
-class TestToolLoop:
- @pytest.mark.asyncio
- async def test_no_tool_calls(self) -> None:
- """When the LLM responds without tool calls, return content directly."""
- agent = _StubAgent()
-
- ai_msg = MagicMock()
- ai_msg.content = "final answer"
- ai_msg.tool_calls = []
-
- llm = AsyncMock()
- llm.bind_tools = MagicMock(return_value=llm)
- llm.ainvoke = AsyncMock(return_value=ai_msg)
-
- result = await agent._tool_loop(llm, [], [])
- assert result == "final answer"
-
- @pytest.mark.asyncio
- async def test_tool_call_then_answer(self) -> None:
- """LLM requests one tool call, gets result, then answers."""
- agent = _StubAgent()
-
- # First response: tool call
- tool_call_msg = MagicMock()
- tool_call_msg.content = ""
- tool_call_msg.tool_calls = [
- {"id": "call_1", "name": "my_tool", "args": {"x": 1}}
- ]
-
- # Second response: final answer
- final_msg = MagicMock()
- final_msg.content = "done"
- final_msg.tool_calls = []
-
- llm = AsyncMock()
- llm.bind_tools = MagicMock(return_value=llm)
- llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
-
- # Mock tool
- tool = AsyncMock()
- tool.name = "my_tool"
- tool.ainvoke = AsyncMock(return_value="tool_result")
-
- result = await agent._tool_loop(llm, [], [tool])
- assert result == "done"
- tool.ainvoke.assert_called_once_with({"x": 1})
-
- @pytest.mark.asyncio
- async def test_unknown_tool_handled(self) -> None:
- """Unknown tool names produce an error message instead of crashing."""
- agent = _StubAgent()
-
- tool_call_msg = MagicMock()
- tool_call_msg.content = ""
- tool_call_msg.tool_calls = [
- {"id": "call_1", "name": "missing", "args": {}}
- ]
-
- final_msg = MagicMock()
- final_msg.content = "recovered"
- final_msg.tool_calls = []
-
- llm = AsyncMock()
- llm.bind_tools = MagicMock(return_value=llm)
- llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
-
- result = await agent._tool_loop(llm, [], [])
- assert result == "recovered"
-
- @pytest.mark.asyncio
- async def test_max_iter_reached(self) -> None:
- """When max iterations are exhausted, a final no-tools call is made."""
- agent = _StubAgent()
-
- # Every response requests a tool call
- loop_msg = MagicMock()
- loop_msg.content = ""
- loop_msg.tool_calls = [
- {"id": "call_x", "name": "t", "args": {}}
- ]
-
- final_msg = MagicMock()
- final_msg.content = "gave up"
- final_msg.tool_calls = []
-
- tool = AsyncMock()
- tool.name = "t"
- tool.ainvoke = AsyncMock(return_value="ok")
-
- llm_with_tools = AsyncMock()
- llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg)
-
- llm = AsyncMock()
- llm.bind_tools = MagicMock(return_value=llm_with_tools)
- llm.ainvoke = AsyncMock(return_value=final_msg)
-
- result = await agent._tool_loop(llm, [], [tool], max_iter=2)
- assert result == "gave up"
- assert llm_with_tools.ainvoke.call_count == 2
diff --git a/tests/test_agent_streaming.py b/tests/test_agent_streaming.py
deleted file mode 100644
index 59a8232..0000000
--- a/tests/test_agent_streaming.py
+++ /dev/null
@@ -1,416 +0,0 @@
-"""Tests for ChatAgent streaming and tool result capture (Step 2)."""
-
-from __future__ import annotations
-
-import pytest
-from unittest.mock import AsyncMock, MagicMock, patch
-from typing import Any
-
-from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
-
-from app.core.agent_registry import ChatAgent, registry
-
-
-# ── Minimal concrete agent for testing ───────────────────────────────
-
-
-class _EchoAgent(ChatAgent):
- def get_name(self) -> str:
- return "_echo"
-
- def get_description(self) -> str:
- return "Echo agent for tests"
-
- def get_tools(self) -> list[Any]:
- return []
-
- async def handle(self, query: str, context: dict[str, Any]) -> str:
- return query
-
-
-# ── Helpers ───────────────────────────────────────────────────────────
-
-
-def _make_ai_message(content: str = "", tool_calls: list | None = None) -> AIMessage:
- msg = AIMessage(content=content)
- if tool_calls:
- msg.tool_calls = tool_calls
- else:
- msg.tool_calls = []
- return msg
-
-
-def _make_tool(name: str, return_value: Any) -> MagicMock:
- t = MagicMock()
- t.name = name
- t.ainvoke = AsyncMock(return_value=return_value)
- return t
-
-
-def _make_stream_chunks(tokens: list[str]) -> list[MagicMock]:
- chunks = []
- for tok in tokens:
- c = MagicMock()
- c.content = tok
- chunks.append(c)
- return chunks
-
-
-async def _collect_stream(agent: ChatAgent, llm: Any, messages: list, tools: list) -> list[str]:
- tokens: list[str] = []
- async for tok in agent._tool_loop_stream(llm, messages, tools):
- tokens.append(tok)
- return tokens
-
-
-# ── tool_results initialised ─────────────────────────────────────────
-
-
-def test_tool_results_init():
- agent = _EchoAgent()
- assert agent.tool_results == []
-
-
-# ── _tool_loop: no tool calls ────────────────────────────────────────
-
-
-@pytest.mark.asyncio
-async def test_tool_loop_no_tools():
- agent = _EchoAgent()
- llm = AsyncMock()
- llm.ainvoke = AsyncMock(return_value=_make_ai_message("Hello!"))
-
- result = await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
- assert result == "Hello!"
- assert agent.tool_results == []
-
-
-# ── _tool_loop: with one tool call + result capture ──────────────────
-
-
-@pytest.mark.asyncio
-async def test_tool_loop_captures_tool_results():
- agent = _EchoAgent()
-
- # Mock execute_on_client to return structured data via the tool
- raw_result = {"rows": [{"id": "t-1", "title": "Fix bug", "status": "todo"}]}
-
- async def fake_executor(payload: dict) -> dict:
- return raw_result
-
- # AIMessage with a tool call, then a final answer
- tool_call_msg = _make_ai_message(
- tool_calls=[{"name": "list_tasks", "args": {}, "id": "call-1", "type": "tool_call"}]
- )
- final_msg = _make_ai_message("Here are your tasks.")
-
- llm = MagicMock()
- llm_with_tools = MagicMock()
- llm.bind_tools = MagicMock(return_value=llm_with_tools)
- llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
- llm.ainvoke = AsyncMock(return_value=final_msg)
-
- mock_tool = _make_tool("list_tasks", "- Fix bug (todo)")
-
- from app.core.ws_context import set_client_executor, clear_client_executor
- set_client_executor(fake_executor)
- try:
- # Patch the tool to actually call execute_on_client
- async def tool_side_effect(args: dict) -> str:
- from app.core.ws_context import execute_on_client
- res = await execute_on_client(action="select", table="tasks")
- rows = res.get("rows", [])
- return "\n".join(r["title"] for r in rows)
-
- mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
-
- result = await agent._tool_loop(
- llm, [HumanMessage(content="list my tasks")], [mock_tool]
- )
- finally:
- clear_client_executor()
-
- assert result == "Here are your tasks."
- assert len(agent.tool_results) == 1
- assert agent.tool_results[0] == raw_result
-
-
-# ── _tool_loop: tool_results reset on each call ──────────────────────
-
-
-@pytest.mark.asyncio
-async def test_tool_loop_resets_tool_results():
- agent = _EchoAgent()
- agent.tool_results = [{"stale": True}] # pre-populated from a previous call
-
- llm = AsyncMock()
- llm.ainvoke = AsyncMock(return_value=_make_ai_message("Done."))
-
- await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
- assert agent.tool_results == []
-
-
-# ── _tool_loop: unknown tool name ────────────────────────────────────
-
-
-@pytest.mark.asyncio
-async def test_tool_loop_unknown_tool():
- agent = _EchoAgent()
-
- # No known tools — model still calls a non-existent one; loop handles gracefully
- tool_call_msg = _make_ai_message(
- tool_calls=[{"name": "nonexistent", "args": {}, "id": "c1", "type": "tool_call"}]
- )
- final_msg = _make_ai_message("Handled.")
-
- mock_tool = _make_tool("known", "ok") # a different tool, not "nonexistent"
- llm = MagicMock()
- llm_with_tools = MagicMock()
- llm.bind_tools = MagicMock(return_value=llm_with_tools)
- llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
-
- result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool])
- assert result == "Handled."
-
-
-# ── _tool_loop: max_iter exhaustion ──────────────────────────────────
-
-
-@pytest.mark.asyncio
-async def test_tool_loop_max_iter():
- agent = _EchoAgent()
-
- always_tool = _make_ai_message(
- tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
- )
- fallback = _make_ai_message("Fallback.")
-
- llm = MagicMock()
- llm_with_tools = MagicMock()
- llm.bind_tools = MagicMock(return_value=llm_with_tools)
- # Returns tool_call_msg on every iteration
- llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
- llm.ainvoke = AsyncMock(return_value=fallback)
-
- mock_tool = _make_tool("t", "ok")
-
- result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool], max_iter=2)
- assert result == "Fallback."
- assert llm_with_tools.ainvoke.call_count == 2
-
-
-# ── _tool_loop_stream: no tool calls — yields tokens ─────────────────
-
-
-@pytest.mark.asyncio
-async def test_tool_loop_stream_no_tools_yields_tokens():
- agent = _EchoAgent()
-
- # No tools → llm used directly; ainvoke returns no tool calls → stream is used
- no_tool_msg = _make_ai_message("irrelevant")
- llm = AsyncMock()
- llm.ainvoke = AsyncMock(return_value=no_tool_msg)
-
- async def fake_astream(msgs):
- for tok in ["Hello", " ", "world"]:
- c = MagicMock()
- c.content = tok
- yield c
-
- llm.astream = fake_astream
-
- tokens = await _collect_stream(agent, llm, [HumanMessage(content="hi")], [])
- assert tokens == ["Hello", " ", "world"]
- assert agent.tool_results == []
-
-
-# ── _tool_loop_stream: one tool call then streaming final ─────────────
-
-
-@pytest.mark.asyncio
-async def test_tool_loop_stream_with_tool_call():
- agent = _EchoAgent()
-
- raw_result = {"row": {"id": "t-2", "title": "Deploy", "status": "in_progress"}}
-
- async def fake_executor(payload: dict) -> dict:
- return raw_result
-
- tool_call_msg = _make_ai_message(
- tool_calls=[{"name": "get_task", "args": {"id": "t-2"}, "id": "c1", "type": "tool_call"}]
- )
- # After tools run, ainvoke returns no more tool calls
- no_more_tools_msg = _make_ai_message("Task found.")
-
- llm = MagicMock()
- llm_with_tools = MagicMock()
- llm.bind_tools = MagicMock(return_value=llm_with_tools)
- llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
-
- async def fake_astream(msgs):
- for tok in ["Task", " ", "found."]:
- c = MagicMock()
- c.content = tok
- yield c
-
- llm.astream = fake_astream
-
- async def tool_side_effect(args: dict) -> str:
- from app.core.ws_context import execute_on_client
- res = await execute_on_client(action="select", table="tasks", filters={"id": args.get("id")})
- return res.get("row", {}).get("title", "")
-
- mock_tool = _make_tool("get_task", "Deploy")
- mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
-
- from app.core.ws_context import set_client_executor, clear_client_executor
- set_client_executor(fake_executor)
- try:
- tokens = await _collect_stream(
- agent, llm, [HumanMessage(content="get task t-2")], [mock_tool]
- )
- finally:
- clear_client_executor()
-
- assert tokens == ["Task", " ", "found."]
- assert len(agent.tool_results) == 1
- assert agent.tool_results[0] == raw_result
-
-
-# ── _tool_loop_stream: tool_results reset on each call ───────────────
-
-
-@pytest.mark.asyncio
-async def test_tool_loop_stream_resets_tool_results():
- agent = _EchoAgent()
- agent.tool_results = [{"old": True}]
-
- no_tool_msg = _make_ai_message("")
- llm = AsyncMock()
- llm.ainvoke = AsyncMock(return_value=no_tool_msg)
-
- async def fake_astream(msgs):
- c = MagicMock()
- c.content = "ok"
- yield c
-
- llm.astream = fake_astream
-
- await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
- assert agent.tool_results == []
-
-
-# ── _tool_loop_stream: empty chunk content is skipped ────────────────
-
-
-@pytest.mark.asyncio
-async def test_tool_loop_stream_skips_empty_chunks():
- agent = _EchoAgent()
- no_tool_msg = _make_ai_message("")
-
- llm = AsyncMock()
- llm.ainvoke = AsyncMock(return_value=no_tool_msg)
-
- async def fake_astream(msgs):
- for tok in ["", "hello", "", " world", ""]:
- c = MagicMock()
- c.content = tok
- yield c
-
- llm.astream = fake_astream
-
- tokens = await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
- assert tokens == ["hello", " world"]
-
-
-# ── _tool_loop_stream: max_iter exhaustion falls back to stream ───────
-
-
-@pytest.mark.asyncio
-async def test_tool_loop_stream_max_iter():
- agent = _EchoAgent()
-
- always_tool = _make_ai_message(
- tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
- )
-
- llm = MagicMock()
- llm_with_tools = MagicMock()
- llm.bind_tools = MagicMock(return_value=llm_with_tools)
- llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
-
- async def fake_astream(msgs):
- c = MagicMock()
- c.content = "fallback"
- yield c
-
- llm.astream = fake_astream
- mock_tool = _make_tool("t", "ok")
-
- tokens = await _collect_stream(
- agent, llm, [HumanMessage(content="x")], [mock_tool],
- )
- assert tokens == ["fallback"]
- assert llm_with_tools.ainvoke.call_count == 5 # exhausted default max_iter
-
-
-# ── _tool_loop_stream: multiple tool results captured ────────────────
-
-
-@pytest.mark.asyncio
-async def test_tool_loop_stream_multiple_tool_results():
- agent = _EchoAgent()
-
- call_results = [
- {"rows": [{"id": "t-1"}]},
- {"rows": [{"id": "t-2"}]},
- ]
- call_iter = iter(call_results)
-
- async def fake_executor(payload: dict) -> dict:
- return next(call_iter)
-
- # Two tool calls in one iteration
- tool_call_msg = _make_ai_message(
- tool_calls=[
- {"name": "tool_a", "args": {}, "id": "c1", "type": "tool_call"},
- {"name": "tool_b", "args": {}, "id": "c2", "type": "tool_call"},
- ]
- )
- no_more_tools_msg = _make_ai_message("Done.")
-
- llm = MagicMock()
- llm_with_tools = MagicMock()
- llm.bind_tools = MagicMock(return_value=llm_with_tools)
- llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
-
- async def fake_astream(msgs):
- c = MagicMock()
- c.content = "Done."
- yield c
-
- llm.astream = fake_astream
-
- async def tool_side_effect(args: dict) -> str:
- from app.core.ws_context import execute_on_client
- res = await execute_on_client(action="select", table="tasks")
- return str(res)
-
- tool_a = _make_tool("tool_a", "")
- tool_a.ainvoke = AsyncMock(side_effect=tool_side_effect)
- tool_b = _make_tool("tool_b", "")
- tool_b.ainvoke = AsyncMock(side_effect=tool_side_effect)
-
- from app.core.ws_context import set_client_executor, clear_client_executor
- set_client_executor(fake_executor)
- try:
- tokens = await _collect_stream(
- agent, llm, [HumanMessage(content="x")], [tool_a, tool_b]
- )
- finally:
- clear_client_executor()
-
- assert tokens == ["Done."]
- assert len(agent.tool_results) == 2
- assert agent.tool_results[0] == {"rows": [{"id": "t-1"}]}
- assert agent.tool_results[1] == {"rows": [{"id": "t-2"}]}
diff --git a/tests/test_agents.py b/tests/test_agents.py
deleted file mode 100644
index 4023232..0000000
--- a/tests/test_agents.py
+++ /dev/null
@@ -1,761 +0,0 @@
-"""Unit tests for the four domain-specific chat agents with mocked LLM."""
-
-from __future__ import annotations
-
-import json
-from typing import Any
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-
-import app.agents # noqa: F401 — triggers @registry.register decorators
-from app.agents.timeline_agent import TimelineAgent
-from app.agents.note_agent import NoteAgent
-from app.agents.project_agent import ProjectAgent
-from app.agents.task_agent import TaskAgent
-from app.core.agent_registry import registry
-from app.core.ws_context import clear_client_executor, set_client_executor
-
-
-# ── WS executor mock ──────────────────────────────────────────────────
-#
-# Tools call execute_on_client() which reads a ContextVar set by the WS
-# handler. In unit tests there is no WS session, so we install a fake
-# executor that returns plausible data for each action type.
-
-_FAKE_ROW: dict[str, Any] = {
- "id": "fake-id",
- "title": "Fake Title",
- "name": "Fake Name",
- "status": "todo",
- "priority": "medium",
- "content": "Fake content",
- "date": 1700000000000,
- "taskId": "fake-task-id",
- "author": "Alice",
- "projectId": None,
-}
-
-
-async def _fake_executor(payload: dict) -> dict:
- action = payload.get("action", "")
- if action == "select":
- return {"rows": []}
- if action == "insert":
- data = payload.get("data", {})
- return {"row": {**_FAKE_ROW, **data}}
- if action == "update":
- data = payload.get("data", {})
- row = {**_FAKE_ROW, "id": data.get("id", "fake-id"), **data.get("updates", {})}
- return {"row": row}
- if action == "delete":
- return {"deleted": True}
- if action == "get":
- data = payload.get("data", {})
- return {"row": {**_FAKE_ROW, "id": data.get("id", "fake-id")}}
- if action == "vector_upsert":
- return {"ok": True}
- return {}
-
-
-@pytest.fixture(autouse=True)
-def ws_executor():
- """Install a fake WS executor for every test so tools can run without a real WS."""
- set_client_executor(_fake_executor)
- yield
- clear_client_executor()
-
-
-# ── Helpers ──────────────────────────────────────────────────────────
-
-
-def _mock_llm(response_text: str) -> MagicMock:
- """Return a mock LLM that responds with *response_text* (no tool calls)."""
- msg = MagicMock()
- msg.content = response_text
- msg.tool_calls = []
- llm = MagicMock()
- bound = MagicMock()
- bound.ainvoke = AsyncMock(return_value=msg)
- llm.bind_tools = MagicMock(return_value=bound)
- llm.ainvoke = AsyncMock(return_value=msg)
- return llm
-
-
-def _mock_llm_with_tool_call(
- tool_name: str, tool_args: dict[str, Any], final_text: str
-) -> MagicMock:
- """Mock LLM that fires one tool call then returns *final_text*."""
- tool_msg = MagicMock()
- tool_msg.content = ""
- tool_msg.tool_calls = [{"id": "call_1", "name": tool_name, "args": tool_args}]
-
- final_msg = MagicMock()
- final_msg.content = final_text
- final_msg.tool_calls = []
-
- bound = MagicMock()
- bound.ainvoke = AsyncMock(side_effect=[tool_msg, final_msg])
-
- llm = MagicMock()
- llm.bind_tools = MagicMock(return_value=bound)
- llm.ainvoke = AsyncMock(return_value=final_msg)
- return llm
-
-
-# ── Registration ──────────────────────────────────────────────────────
-
-
-class TestAgentRegistration:
- def test_all_agents_registered(self) -> None:
- names = {a["name"] for a in registry.list_agents()}
- assert {
- "task_agent", "timeline_agent", "project_agent", "note_agent"
- }.issubset(names)
-
- def test_registry_returns_correct_types(self) -> None:
- assert isinstance(registry.get("task_agent"), TaskAgent)
- assert isinstance(registry.get("timeline_agent"), TimelineAgent)
- assert isinstance(registry.get("project_agent"), ProjectAgent)
- assert isinstance(registry.get("note_agent"), NoteAgent)
-
- def test_descriptions_present(self) -> None:
- for agent_info in registry.list_agents():
- assert agent_info["description"], f"Empty description: {agent_info['name']}"
-
-
-# ── TaskAgent ─────────────────────────────────────────────────────────
-
-
-class TestTaskAgent:
- def test_name(self) -> None:
- assert TaskAgent().get_name() == "task_agent"
-
- def test_description(self) -> None:
- assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments"
-
- def test_get_tools_count(self) -> None:
- assert len(TaskAgent().get_tools()) == 8
-
- def test_tool_names(self) -> None:
- names = {t.name for t in TaskAgent().get_tools()}
- assert names == {
- "list_tasks",
- "create_task",
- "update_task",
- "delete_task",
- "list_tasks_due_today",
- "list_task_comments",
- "add_task_comment",
- "delete_task_comment",
- }
-
- @pytest.mark.asyncio
- async def test_handle_returns_string(self) -> None:
- with patch("app.agents.task_agent.get_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("Task created.")
- result = await TaskAgent().handle("create a task", {})
- assert isinstance(result, str)
-
- @pytest.mark.asyncio
- async def test_handle_no_tool_calls(self) -> None:
- with patch("app.agents.task_agent.get_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("Here are your tasks.")
- result = await TaskAgent().handle("list my tasks", {})
- assert result == "Here are your tasks."
-
- @pytest.mark.asyncio
- async def test_handle_with_create_task_tool_call(self) -> None:
- with patch("app.agents.task_agent.get_llm") as mock_cls:
- mock_cls.return_value = _mock_llm_with_tool_call(
- "create_task",
- {"title": "Buy groceries", "priority": "low"},
- "Task 'Buy groceries' created.",
- )
- result = await TaskAgent().handle("add a grocery task", {})
- assert result == "Task 'Buy groceries' created."
-
- @pytest.mark.asyncio
- async def test_handle_accepts_empty_context(self) -> None:
- with patch("app.agents.task_agent.get_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("Done.")
- result = await TaskAgent().handle("help", {})
- assert isinstance(result, str)
-
- @pytest.mark.asyncio
- async def test_handle_accepts_rich_context(self) -> None:
- context = {
- "user_profile": {"id": "u1", "tier": "pro"},
- "recent_tasks": [{"id": "t1", "title": "Old task"}],
- }
- with patch("app.agents.task_agent.get_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("Tasks listed.")
- result = await TaskAgent().handle("show tasks", context)
- assert isinstance(result, str)
-
-
-class TestTaskAgentTools:
- @pytest.mark.asyncio
- async def test_list_tasks_defaults(self) -> None:
- from app.agents.task_agent import list_tasks
- with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"rows": []}
- result = await list_tasks.ainvoke({})
- m.assert_called_once_with(
- action="select", table="tasks",
- filters={"projectId": None, "status": None, "search": None, "orderBy": None},
- )
- assert result == "No tasks found matching the given filters."
-
- @pytest.mark.asyncio
- async def test_list_tasks_with_status_filter(self) -> None:
- from app.agents.task_agent import list_tasks
- with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"rows": []}
- await list_tasks.ainvoke({"status": "done"})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["filters"]["status"] == "done"
-
- @pytest.mark.asyncio
- async def test_create_task_defaults(self) -> None:
- from app.agents.task_agent import create_task
- fake_row = {"id": "t1", "title": "Test task", "status": "todo", "priority": "medium"}
- with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- result = await create_task.ainvoke({"title": "Test task"})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "insert"
- assert call_kwargs["table"] == "tasks"
- assert call_kwargs["data"]["title"] == "Test task"
- assert call_kwargs["data"]["status"] == "todo"
- assert call_kwargs["data"]["priority"] == "medium"
- assert "Test task" in result
-
- @pytest.mark.asyncio
- async def test_create_task_with_all_fields(self) -> None:
- from app.agents.task_agent import create_task
- fake_row = {"id": "t1", "title": "Deploy", "status": "in_progress", "priority": "high"}
- with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- await create_task.ainvoke({
- "title": "Deploy", "priority": "high", "status": "in_progress",
- "project_id": "p1", "is_ai_suggested": 1,
- })
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["data"]["priority"] == "high"
- assert call_kwargs["data"]["status"] == "in_progress"
- assert call_kwargs["data"]["projectId"] == "p1"
- assert call_kwargs["data"]["isAiSuggested"] == 1
-
- @pytest.mark.asyncio
- async def test_update_task_with_status(self) -> None:
- from app.agents.task_agent import update_task
- fake_row = {"id": "t1", "title": "Buy groceries", "status": "done"}
- with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "update"
- assert call_kwargs["data"]["id"] == "t1"
- assert call_kwargs["data"]["updates"]["status"] == "done"
- assert "t1" in result
-
- @pytest.mark.asyncio
- async def test_update_task_empty_updates(self) -> None:
- from app.agents.task_agent import update_task
- fake_row = {"id": "t1", "title": "Task", "status": "todo"}
- with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- await update_task.ainvoke({"task_id": "t1"})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["data"]["updates"] == {}
-
- @pytest.mark.asyncio
- async def test_delete_task(self) -> None:
- from app.agents.task_agent import delete_task
- with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"deleted": True}
- result = await delete_task.ainvoke({"task_id": "t1"})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "delete"
- assert call_kwargs["table"] == "tasks"
- assert call_kwargs["data"]["id"] == "t1"
- assert "t1" in result
-
- @pytest.mark.asyncio
- async def test_list_tasks_due_today(self) -> None:
- from app.agents.task_agent import list_tasks_due_today
- with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"rows": []}
- result = await list_tasks_due_today.ainvoke({})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "select"
- assert call_kwargs["table"] == "tasks"
- assert "dueDateFrom" in call_kwargs["filters"]
- assert result == "No tasks are due today."
-
- @pytest.mark.asyncio
- async def test_list_task_comments(self) -> None:
- from app.agents.task_agent import list_task_comments
- with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"rows": []}
- result = await list_task_comments.ainvoke({"task_id": "t1"})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "select"
- assert call_kwargs["table"] == "taskComments"
- assert call_kwargs["filters"]["taskId"] == "t1"
- assert "t1" in result
-
- @pytest.mark.asyncio
- async def test_add_task_comment(self) -> None:
- from app.agents.task_agent import add_task_comment
- fake_row = {"id": "c1", "taskId": "t1", "author": "Alice", "content": "Looks good!"}
- with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- result = await add_task_comment.ainvoke({
- "task_id": "t1", "author": "Alice", "content": "Looks good!",
- })
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "insert"
- assert call_kwargs["table"] == "taskComments"
- assert call_kwargs["data"]["taskId"] == "t1"
- assert call_kwargs["data"]["author"] == "Alice"
- assert call_kwargs["data"]["content"] == "Looks good!"
- assert "Alice" in result
-
- @pytest.mark.asyncio
- async def test_delete_task_comment(self) -> None:
- from app.agents.task_agent import delete_task_comment
- with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"deleted": True}
- result = await delete_task_comment.ainvoke({"comment_id": "c1"})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "delete"
- assert call_kwargs["table"] == "taskComments"
- assert call_kwargs["data"]["id"] == "c1"
- assert "c1" in result
-
-
-# ── TimelineAgent ───────────────────────────────────────────────────
-
-
-class TestTimelineAgent:
- def test_name(self) -> None:
- assert TimelineAgent().get_name() == "timeline_agent"
-
- def test_description(self) -> None:
- assert TimelineAgent().get_description() == "Manages project timelines (milestones): list, create, update, delete"
-
- def test_get_tools_count(self) -> None:
- assert len(TimelineAgent().get_tools()) == 4
-
- def test_tool_names(self) -> None:
- names = {t.name for t in TimelineAgent().get_tools()}
- assert names == {"list_timelines", "create_timeline", "update_timeline", "delete_timeline"}
-
- @pytest.mark.asyncio
- async def test_handle_no_tool_calls(self) -> None:
- with patch("app.agents.timeline_agent.get_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("No timelines found.")
- result = await TimelineAgent().handle("list timelines", {})
- assert result == "No timelines found."
-
- @pytest.mark.asyncio
- async def test_handle_with_create_tool_call(self) -> None:
- with patch("app.agents.timeline_agent.get_llm") as mock_cls:
- mock_cls.return_value = _mock_llm_with_tool_call(
- "create_timeline",
- {"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
- "Timeline 'MVP Launch' created.",
- )
- result = await TimelineAgent().handle("add MVP timeline", {})
- assert result == "Timeline 'MVP Launch' created."
-
- @pytest.mark.asyncio
- async def test_handle_accepts_empty_context(self) -> None:
- with patch("app.agents.timeline_agent.get_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("Done.")
- result = await TimelineAgent().handle("show milestones", {})
- assert isinstance(result, str)
-
-
-class TestTimelineAgentTools:
- @pytest.mark.asyncio
- async def test_list_timelines_no_project(self) -> None:
- from app.agents.timeline_agent import list_timelines
- with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"rows": []}
- result = await list_timelines.ainvoke({})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "select"
- assert call_kwargs["table"] == "timelines"
- assert call_kwargs["filters"]["projectId"] is None
- assert result == "No timelines found."
-
- @pytest.mark.asyncio
- async def test_list_timelines_with_project(self) -> None:
- from app.agents.timeline_agent import list_timelines
- with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"rows": []}
- await list_timelines.ainvoke({"project_id": "p1"})
- assert m.call_args.kwargs["filters"]["projectId"] == "p1"
-
- @pytest.mark.asyncio
- async def test_create_timeline(self) -> None:
- from app.agents.timeline_agent import create_timeline
- fake_row = {"id": "cp1", "title": "Beta release", "date": 1700000000000}
- with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- result = await create_timeline.ainvoke({
- "project_id": "p1", "title": "Beta release", "date": 1700000000000,
- })
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "insert"
- assert call_kwargs["table"] == "timelines"
- assert call_kwargs["data"]["projectId"] == "p1"
- assert call_kwargs["data"]["title"] == "Beta release"
- assert call_kwargs["data"]["date"] == 1700000000000
- assert "Beta release" in result
-
- @pytest.mark.asyncio
- async def test_create_timeline_ai_suggested(self) -> None:
- from app.agents.timeline_agent import create_timeline
- fake_row = {"id": "cp1", "title": "Review", "date": 1700000000000}
- with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- await create_timeline.ainvoke({
- "project_id": "p1", "title": "Review", "date": 1700000000000, "is_ai_suggested": 1,
- })
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["data"]["isAiSuggested"] == 1
- assert call_kwargs["data"]["isApproved"] == 0
-
- @pytest.mark.asyncio
- async def test_update_timeline_approve(self) -> None:
- from app.agents.timeline_agent import update_timeline
- fake_row = {"id": "c1", "title": "MVP", "isApproved": 1}
- with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- result = await update_timeline.ainvoke({"timeline_id": "c1", "is_approved": 1})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "update"
- assert call_kwargs["data"]["id"] == "c1"
- assert call_kwargs["data"]["updates"]["isApproved"] == 1
- assert "c1" in result
-
- @pytest.mark.asyncio
- async def test_update_timeline_empty_updates(self) -> None:
- from app.agents.timeline_agent import update_timeline
- fake_row = {"id": "c1", "title": "MVP"}
- with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- await update_timeline.ainvoke({"timeline_id": "c1"})
- assert m.call_args.kwargs["data"]["updates"] == {}
-
- @pytest.mark.asyncio
- async def test_delete_timeline(self) -> None:
- from app.agents.timeline_agent import delete_timeline
- with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"deleted": True}
- result = await delete_timeline.ainvoke({"timeline_id": "c1"})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "delete"
- assert call_kwargs["table"] == "timelines"
- assert call_kwargs["data"]["id"] == "c1"
- assert "c1" in result
-
-
-# ── ProjectAgent ──────────────────────────────────────────────────────
-
-
-class TestProjectAgent:
- def test_name(self) -> None:
- assert ProjectAgent().get_name() == "project_agent"
-
- def test_description(self) -> None:
- assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete"
-
- def test_get_tools_count(self) -> None:
- assert len(ProjectAgent().get_tools()) == 6
-
- def test_tool_names(self) -> None:
- names = {t.name for t in ProjectAgent().get_tools()}
- assert names == {
- "list_projects",
- "list_all_projects",
- "get_project",
- "create_project",
- "update_project",
- "delete_project",
- }
-
- @pytest.mark.asyncio
- async def test_handle_no_tool_calls(self) -> None:
- with patch("app.agents.project_agent.get_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("Project Alpha is active.")
- result = await ProjectAgent().handle("show my projects", {})
- assert result == "Project Alpha is active."
-
- @pytest.mark.asyncio
- async def test_handle_with_create_project_tool_call(self) -> None:
- with patch("app.agents.project_agent.get_llm") as mock_cls:
- mock_cls.return_value = _mock_llm_with_tool_call(
- "create_project",
- {"name": "Pippo"},
- "Project 'Pippo' created.",
- )
- result = await ProjectAgent().handle("create project Pippo", {})
- assert result == "Project 'Pippo' created."
-
- @pytest.mark.asyncio
- async def test_handle_accepts_empty_context(self) -> None:
- with patch("app.agents.project_agent.get_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("Done.")
- result = await ProjectAgent().handle("archive old project", {})
- assert isinstance(result, str)
-
-
-class TestProjectAgentTools:
- @pytest.mark.asyncio
- async def test_list_projects_defaults(self) -> None:
- from app.agents.project_agent import list_projects
- with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"rows": []}
- result = await list_projects.ainvoke({})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "select"
- assert call_kwargs["table"] == "projects"
- assert call_kwargs["filters"]["includeArchived"] is False
- assert result == "No projects found."
-
- @pytest.mark.asyncio
- async def test_list_projects_include_archived(self) -> None:
- from app.agents.project_agent import list_projects
- with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"rows": []}
- await list_projects.ainvoke({"include_archived": 1})
- assert m.call_args.kwargs["filters"]["includeArchived"] is True
-
- @pytest.mark.asyncio
- async def test_list_all_projects(self) -> None:
- from app.agents.project_agent import list_all_projects
- with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"rows": []}
- result = await list_all_projects.ainvoke({})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "select"
- assert call_kwargs["table"] == "projects"
- assert result == "No projects found."
-
- @pytest.mark.asyncio
- async def test_get_project(self) -> None:
- from app.agents.project_agent import get_project
- fake_row = {"id": "p1", "name": "Alpha", "status": "active", "clientId": None}
- with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- result = await get_project.ainvoke({"project_id": "p1"})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "get"
- assert call_kwargs["table"] == "projects"
- assert call_kwargs["data"]["id"] == "p1"
- assert "Alpha" in result
-
- @pytest.mark.asyncio
- async def test_create_project_name_only(self) -> None:
- from app.agents.project_agent import create_project
- fake_row = {"id": "p1", "name": "Alpha"}
- with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- result = await create_project.ainvoke({"name": "Alpha"})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "insert"
- assert call_kwargs["data"]["name"] == "Alpha"
- assert call_kwargs["data"]["clientId"] is None
- assert "Alpha" in result
-
- @pytest.mark.asyncio
- async def test_create_project_with_client(self) -> None:
- from app.agents.project_agent import create_project
- fake_row = {"id": "p1", "name": "Beta"}
- with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
- assert m.call_args.kwargs["data"]["clientId"] == "cl1"
-
- @pytest.mark.asyncio
- async def test_update_project_archive(self) -> None:
- from app.agents.project_agent import update_project
- fake_row = {"id": "p1", "name": "Alpha", "status": "archived"}
- with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "update"
- assert call_kwargs["data"]["id"] == "p1"
- assert call_kwargs["data"]["updates"]["status"] == "archived"
- assert "p1" in result
-
- @pytest.mark.asyncio
- async def test_update_project_empty_updates(self) -> None:
- from app.agents.project_agent import update_project
- fake_row = {"id": "p1", "name": "Alpha", "status": "active"}
- with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- await update_project.ainvoke({"project_id": "p1"})
- assert m.call_args.kwargs["data"]["updates"] == {}
-
- @pytest.mark.asyncio
- async def test_delete_project(self) -> None:
- from app.agents.project_agent import delete_project
- with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"deleted": True}
- result = await delete_project.ainvoke({"project_id": "p1"})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "delete"
- assert call_kwargs["data"]["id"] == "p1"
- assert "p1" in result
-
-
-# ── NoteAgent ─────────────────────────────────────────────────────────
-
-
-class TestNoteAgent:
- def test_name(self) -> None:
- assert NoteAgent().get_name() == "note_agent"
-
- def test_description(self) -> None:
- assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete"
-
- def test_get_tools_count(self) -> None:
- assert len(NoteAgent().get_tools()) == 5
-
- def test_tool_names(self) -> None:
- names = {t.name for t in NoteAgent().get_tools()}
- assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"}
-
- @pytest.mark.asyncio
- async def test_handle_no_tool_calls(self) -> None:
- with patch("app.agents.note_agent.get_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("Note created.")
- result = await NoteAgent().handle("create a note", {})
- assert result == "Note created."
-
- @pytest.mark.asyncio
- async def test_handle_with_create_note_tool_call(self) -> None:
- with patch("app.agents.note_agent.get_llm") as mock_cls:
- mock_cls.return_value = _mock_llm_with_tool_call(
- "create_note",
- {"title": "Daily log", "content": "# Today\nAll good."},
- "Note 'Daily log' created.",
- )
- result = await NoteAgent().handle("log today's progress", {})
- assert result == "Note 'Daily log' created."
-
- @pytest.mark.asyncio
- async def test_handle_accepts_empty_context(self) -> None:
- with patch("app.agents.note_agent.get_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("Done.")
- result = await NoteAgent().handle("show notes", {})
- assert isinstance(result, str)
-
-
-class TestNoteAgentTools:
- @pytest.mark.asyncio
- async def test_list_notes_no_project(self) -> None:
- from app.agents.note_agent import list_notes
- with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"rows": []}
- result = await list_notes.ainvoke({})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "select"
- assert call_kwargs["table"] == "notes"
- assert call_kwargs["filters"]["projectId"] is None
- assert result == "No notes found."
-
- @pytest.mark.asyncio
- async def test_list_notes_with_project(self) -> None:
- from app.agents.note_agent import list_notes
- with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"rows": []}
- await list_notes.ainvoke({"project_id": "p1"})
- assert m.call_args.kwargs["filters"]["projectId"] == "p1"
-
- @pytest.mark.asyncio
- async def test_get_note(self) -> None:
- from app.agents.note_agent import get_note
- fake_row = {"id": "n1", "title": "Daily log", "content": "# Today\nAll good."}
- with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- result = await get_note.ainvoke({"note_id": "n1"})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "get"
- assert call_kwargs["table"] == "notes"
- assert call_kwargs["data"]["id"] == "n1"
- assert "Daily log" in result
-
- @pytest.mark.asyncio
- async def test_create_note_minimal(self) -> None:
- from app.agents.note_agent import create_note
- fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
- with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
- patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
- m.return_value = {"row": fake_row}
- me.return_value = [0.0] * 1536
- result = await create_note.ainvoke({"title": "Daily log", "content": "# Today\nAll good."})
- # First call: insert; second call: vector_upsert
- first_call = m.call_args_list[0].kwargs
- assert first_call["action"] == "insert"
- assert first_call["table"] == "notes"
- assert first_call["data"]["title"] == "Daily log"
- assert first_call["data"]["content"] == "# Today\nAll good."
- assert first_call["data"]["projectId"] is None
- assert "Daily log" in result
-
- @pytest.mark.asyncio
- async def test_create_note_with_project(self) -> None:
- from app.agents.note_agent import create_note
- fake_row = {"id": "n1", "title": "Sprint notes", "projectId": "p1"}
- with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
- patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
- m.return_value = {"row": fake_row}
- me.return_value = [0.0] * 1536
- await create_note.ainvoke({"title": "Sprint notes", "content": "## Sprint 1", "project_id": "p1"})
- first_call = m.call_args_list[0].kwargs
- assert first_call["data"]["projectId"] == "p1"
-
- @pytest.mark.asyncio
- async def test_update_note_content_only(self) -> None:
- from app.agents.note_agent import update_note
- fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
- with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
- patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
- m.return_value = {"row": fake_row}
- me.return_value = [0.0] * 1536
- result = await update_note.ainvoke({"note_id": "n1", "content": "# Updated content"})
- first_call = m.call_args_list[0].kwargs
- assert first_call["action"] == "update"
- assert first_call["data"]["id"] == "n1"
- assert first_call["data"]["updates"]["content"] == "# Updated content"
- assert "title" not in first_call["data"]["updates"]
- assert "n1" in result
-
- @pytest.mark.asyncio
- async def test_update_note_empty_updates(self) -> None:
- from app.agents.note_agent import update_note
- fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
- with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"row": fake_row}
- await update_note.ainvoke({"note_id": "n1"})
- assert m.call_args.kwargs["data"]["updates"] == {}
-
- @pytest.mark.asyncio
- async def test_delete_note(self) -> None:
- from app.agents.note_agent import delete_note
- with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
- m.return_value = {"deleted": True}
- result = await delete_note.ainvoke({"note_id": "n1"})
- call_kwargs = m.call_args.kwargs
- assert call_kwargs["action"] == "delete"
- assert call_kwargs["table"] == "notes"
- assert call_kwargs["data"]["id"] == "n1"
- assert "n1" in result
diff --git a/tests/test_execution_plan.py b/tests/test_execution_plan.py
deleted file mode 100644
index 06a2bfa..0000000
--- a/tests/test_execution_plan.py
+++ /dev/null
@@ -1,286 +0,0 @@
-"""Tests for execution_plan: PromptTemplateRegistry, ExecutionPlanBuilder, PlanCache."""
-
-from __future__ import annotations
-
-import pytest
-
-from app.core.execution_plan import (
- ExecutionPlanBuilder,
- PlanCache,
- PromptTemplateRegistry,
- plan_cache,
- template_registry,
-)
-from app.schemas import ExecutionPlan
-
-
-# ── PromptTemplateRegistry ────────────────────────────────────────────
-
-
-class TestPromptTemplateRegistry:
- def test_register_and_get(self) -> None:
- reg = PromptTemplateRegistry()
- reg.register("tpl_foo", "You are a foo agent.")
- assert reg.get("tpl_foo") == "You are a foo agent."
-
- def test_get_unknown_raises_key_error(self) -> None:
- reg = PromptTemplateRegistry()
- with pytest.raises(KeyError, match="tpl_missing"):
- reg.get("tpl_missing")
-
- def test_has_returns_true_for_registered(self) -> None:
- reg = PromptTemplateRegistry()
- reg.register("tpl_x", "prompt text")
- assert reg.has("tpl_x") is True
-
- def test_has_returns_false_for_unregistered(self) -> None:
- reg = PromptTemplateRegistry()
- assert reg.has("tpl_missing") is False
-
- def test_list_ids_returns_all_registered_ids(self) -> None:
- reg = PromptTemplateRegistry()
- reg.register("tpl_a", "a")
- reg.register("tpl_b", "b")
- assert set(reg.list_ids()) == {"tpl_a", "tpl_b"}
-
- def test_list_ids_does_not_return_prompt_text(self) -> None:
- reg = PromptTemplateRegistry()
- reg.register("tpl_secret", "top secret prompt")
- ids = reg.list_ids()
- assert "top secret prompt" not in ids
-
- def test_overwrite_existing_template(self) -> None:
- reg = PromptTemplateRegistry()
- reg.register("tpl_x", "v1")
- reg.register("tpl_x", "v2")
- assert reg.get("tpl_x") == "v2"
-
- def test_empty_registry_has_no_ids(self) -> None:
- reg = PromptTemplateRegistry()
- assert reg.list_ids() == []
-
-
-# ── ExecutionPlanBuilder ──────────────────────────────────────────────
-
-
-class TestExecutionPlanBuilder:
- def test_builds_empty_plan(self) -> None:
- plan = ExecutionPlanBuilder("task_agent").build()
- assert plan.agent == "task_agent"
- assert plan.steps == []
-
- def test_add_step_basic(self) -> None:
- plan = (
- ExecutionPlanBuilder("task_agent")
- .add_step("create_task", {"priority": "high"})
- .build()
- )
- assert len(plan.steps) == 1
- assert plan.steps[0].action == "create_task"
- assert plan.steps[0].variables == {"priority": "high"}
- assert plan.steps[0].prompt_template is None
- assert plan.steps[0].data_from_step is None
-
- def test_add_step_no_params(self) -> None:
- plan = ExecutionPlanBuilder("task_agent").add_step("fetch").build()
- assert plan.steps[0].variables is None
-
- def test_add_llm_step(self) -> None:
- plan = (
- ExecutionPlanBuilder("task_agent")
- .add_llm_step("tpl_task_default", {"message": "hi"})
- .build()
- )
- assert plan.steps[0].action == "llm"
- assert plan.steps[0].prompt_template == "tpl_task_default"
- assert plan.steps[0].variables == {"message": "hi"}
-
- def test_add_llm_step_no_variables(self) -> None:
- plan = ExecutionPlanBuilder("task_agent").add_llm_step("tpl_x").build()
- assert plan.steps[0].variables is None
-
- def test_add_data_step(self) -> None:
- plan = (
- ExecutionPlanBuilder("task_agent")
- .add_step("fetch_data")
- .add_data_step("transform", data_from_step=0)
- .build()
- )
- assert plan.steps[1].action == "transform"
- assert plan.steps[1].data_from_step == 0
-
- def test_fluent_chaining_returns_builder(self) -> None:
- builder = ExecutionPlanBuilder("analytics_agent")
- result = builder.add_step("a")
- assert result is builder
-
- def test_fluent_chain_multiple_steps(self) -> None:
- plan = (
- ExecutionPlanBuilder("analytics_agent")
- .add_llm_step("tpl_analytics_default")
- .add_step("format_output")
- .add_data_step("store", data_from_step=0)
- .build()
- )
- assert len(plan.steps) == 3
-
- def test_build_validates_data_from_step_out_of_range(self) -> None:
- with pytest.raises(ValueError, match="data_from_step"):
- ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=5).build()
-
- def test_build_validates_data_from_step_self_reference(self) -> None:
- """data_from_step=0 on the first step (index 0) is invalid."""
- with pytest.raises(ValueError, match="data_from_step"):
- ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=0).build()
-
- def test_build_validates_data_from_step_negative(self) -> None:
- with pytest.raises(ValueError, match="data_from_step"):
- ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=-1).build()
-
- def test_valid_data_from_step_at_index_two(self) -> None:
- plan = (
- ExecutionPlanBuilder("task_agent")
- .add_step("step0")
- .add_step("step1")
- .add_data_step("step2", data_from_step=1)
- .build()
- )
- assert plan.steps[2].data_from_step == 1
-
- def test_data_from_step_zero_valid_at_index_one(self) -> None:
- plan = (
- ExecutionPlanBuilder("task_agent")
- .add_step("step0")
- .add_data_step("step1", data_from_step=0)
- .build()
- )
- assert plan.steps[1].data_from_step == 0
-
- def test_build_returns_new_plan_each_call(self) -> None:
- builder = ExecutionPlanBuilder("task_agent").add_step("do_thing")
- plan1 = builder.build()
- plan2 = builder.build()
- assert plan1 is not plan2
- assert plan1.steps == plan2.steps
-
- def test_plan_is_execution_plan_instance(self) -> None:
- plan = ExecutionPlanBuilder("task_agent").build()
- assert isinstance(plan, ExecutionPlan)
-
-
-# ── PlanCache ─────────────────────────────────────────────────────────
-
-
-class TestPlanCache:
- def _plan(self, agent: str = "a") -> ExecutionPlan:
- return ExecutionPlanBuilder(agent).build()
-
- def test_cache_and_get(self) -> None:
- cache = PlanCache()
- plan = self._plan()
- cache.cache_plan("key1", plan)
- assert cache.get_plan("key1") is plan
-
- def test_get_missing_returns_none(self) -> None:
- cache = PlanCache()
- assert cache.get_plan("nonexistent") is None
-
- def test_get_all_playbooks_empty(self) -> None:
- cache = PlanCache()
- assert cache.get_all_playbooks() == []
-
- def test_get_all_playbooks_returns_all_stored(self) -> None:
- cache = PlanCache()
- p1, p2 = self._plan("a"), self._plan("b")
- cache.cache_plan("k1", p1)
- cache.cache_plan("k2", p2)
- playbooks = cache.get_all_playbooks()
- assert len(playbooks) == 2
- assert p1 in playbooks
- assert p2 in playbooks
-
- def test_lru_evicts_oldest_entry(self) -> None:
- cache = PlanCache(maxsize=2)
- p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
- cache.cache_plan("k1", p1)
- cache.cache_plan("k2", p2)
- cache.cache_plan("k3", p3) # k1 should be evicted
- assert cache.get_plan("k1") is None
- assert cache.get_plan("k2") is p2
- assert cache.get_plan("k3") is p3
-
- def test_lru_access_updates_recency(self) -> None:
- cache = PlanCache(maxsize=2)
- p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
- cache.cache_plan("k1", p1)
- cache.cache_plan("k2", p2)
- cache.get_plan("k1") # k1 is now most-recently used
- cache.cache_plan("k3", p3) # k2 should be evicted (LRU)
- assert cache.get_plan("k1") is p1
- assert cache.get_plan("k2") is None
- assert cache.get_plan("k3") is p3
-
- def test_overwrite_existing_key(self) -> None:
- cache = PlanCache()
- p1, p2 = self._plan("a"), self._plan("b")
- cache.cache_plan("same_key", p1)
- cache.cache_plan("same_key", p2)
- assert cache.get_plan("same_key") is p2
- assert len(cache.get_all_playbooks()) == 1
-
- def test_overwrite_does_not_consume_capacity(self) -> None:
- cache = PlanCache(maxsize=2)
- p1, p2 = self._plan("a"), self._plan("b")
- cache.cache_plan("k1", p1)
- cache.cache_plan("k1", p2) # overwrite, not a new slot
- cache.cache_plan("k2", p1) # should fit without eviction
- assert cache.get_plan("k1") is p2
- assert cache.get_plan("k2") is p1
-
-
-# ── Module-level singletons ───────────────────────────────────────────
-
-
-class TestModuleSingletons:
- def test_template_registry_has_all_agent_defaults(self) -> None:
- for agent in ("task_agent", "timeline_agent", "project_agent", "note_agent"):
- assert template_registry.has(f"tpl_{agent}_default"), (
- f"Missing template: tpl_{agent}_default"
- )
-
- def test_template_registry_has_operation_templates(self) -> None:
- assert template_registry.has("tpl_task_extract_from_project")
- assert template_registry.has("tpl_note_weekly_summary")
-
- def test_template_registry_get_returns_non_empty_string(self) -> None:
- text = template_registry.get("tpl_task_agent_default")
- assert isinstance(text, str)
- assert len(text) > 0
-
- def test_plan_cache_has_prebuilt_playbooks(self) -> None:
- assert len(plan_cache.get_all_playbooks()) >= 2
-
- def test_playbook_create_tasks_from_project(self) -> None:
- plan = plan_cache.get_plan("create_tasks_from_project")
- assert plan is not None
- assert plan.agent == "project_agent"
- assert len(plan.steps) == 2
- assert plan.steps[0].prompt_template == "tpl_task_extract_from_project"
- assert plan.steps[1].data_from_step == 0
-
- def test_playbook_generate_weekly_note(self) -> None:
- plan = plan_cache.get_plan("generate_weekly_note")
- assert plan is not None
- assert plan.agent == "note_agent"
- assert len(plan.steps) == 2
- assert plan.steps[0].prompt_template == "tpl_note_weekly_summary"
- assert plan.steps[1].data_from_step == 0
-
- def test_playbook_steps_have_no_raw_prompt_text(self) -> None:
- """Plans must not embed prompt text — only template IDs."""
- for plan in plan_cache.get_all_playbooks():
- for step in plan.steps:
- if step.prompt_template is not None:
- assert step.prompt_template.startswith("tpl_"), (
- f"prompt_template looks like raw text: {step.prompt_template!r}"
- )
diff --git a/tests/test_memory_middleware.py b/tests/test_memory_middleware.py
index ea5f558..e1b53cd 100644
--- a/tests/test_memory_middleware.py
+++ b/tests/test_memory_middleware.py
@@ -250,15 +250,14 @@ def test_home_request_calls_memory_middleware(client):
token = make_jwt("power", user_id=USER_ID)
session_id = str(uuid.uuid4())
- async def _mock_stream(user_id, message, context, reg=None):
+ async def _mock_stream(user_id, message, context):
# Verify memory context was injected
assert context.get("core_memory") == {"tz": "UTC"}
- yield "task_agent", ""
- yield "task_agent", '{"type": "text", "content": "Done"}'
+ yield "token", "Done"
with (
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
- patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_stream),
+ patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_stream),
):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
diff --git a/tests/test_middleware.py b/tests/test_middleware.py
index 8721bbc..576a145 100644
--- a/tests/test_middleware.py
+++ b/tests/test_middleware.py
@@ -20,7 +20,6 @@ from jose import jwt
from app.config.settings import settings
from app.db import get_session
from app.main import app
-from app.schemas import ChatResponse
from tests.conftest import TEST_USER_IDS
# ---------------------------------------------------------------------------
@@ -50,7 +49,6 @@ _CHAT_BODY = {
"recent_tasks": [],
"conversation_history": [],
},
- "execution_mode": "direct",
}
@@ -240,7 +238,7 @@ class TestRateLimitMiddleware:
class TestSanitizerMiddleware:
- """Mock ``orchestrate`` to inject controlled strings into chat responses."""
+ """Mock ``run_home`` to inject controlled strings into chat responses."""
_CHAT_PATH = "/api/v1/chat"
@@ -248,11 +246,10 @@ class TestSanitizerMiddleware:
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
def _post_chat(self, client: TestClient, response_text: str) -> dict:
- mock_response = ChatResponse(response=response_text, actions=[])
with patch(
- "app.api.routes.chat.orchestrate",
+ "app.api.routes.chat.run_home",
new_callable=AsyncMock,
- return_value=mock_response,
+ return_value=response_text,
):
resp = client.post(
self._CHAT_PATH,
diff --git a/tests/test_orchestrator.py b/tests/test_orchestrator.py
deleted file mode 100644
index 07576d4..0000000
--- a/tests/test_orchestrator.py
+++ /dev/null
@@ -1,347 +0,0 @@
-"""Integration tests for the orchestrator module."""
-
-from __future__ import annotations
-
-import json
-from typing import Any
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-
-from app.core.agent_registry import AgentRegistry, ChatAgent
-from app.core.orchestrator import (
- classify_intent,
- orchestrate,
- orchestrate_stream,
- route_pipeline,
- route_single,
-)
-from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
-
-
-# ── Stub agents ──────────────────────────────────────────────────────
-
-
-class _TaskAgent(ChatAgent):
- def get_name(self) -> str:
- return "task_agent"
-
- def get_description(self) -> str:
- return "Manages tasks: create, update, list, suggest"
-
- def get_tools(self) -> list[Any]:
- return []
-
- async def handle(self, query: str, context: dict[str, Any]) -> str:
- return f"task: {query}"
-
-
-class _CalendarAgent(ChatAgent):
- def get_name(self) -> str:
- return "calendar_agent"
-
- def get_description(self) -> str:
- return "Calendar management: events, conflicts, scheduling"
-
- def get_tools(self) -> list[Any]:
- return []
-
- async def handle(self, query: str, context: dict[str, Any]) -> str:
- return f"calendar: {query}"
-
-
-# ── Helpers ──────────────────────────────────────────────────────────
-
-
-def _mock_llm(response_text: str) -> MagicMock:
- """Return a mock LLM that always produces *response_text*."""
- msg = MagicMock()
- msg.content = response_text
- llm = MagicMock()
- llm.ainvoke = AsyncMock(return_value=msg)
- return llm
-
-
-# ── Fixtures ─────────────────────────────────────────────────────────
-
-
-@pytest.fixture(autouse=True)
-def _fresh_registry():
- """Reset the AgentRegistry singleton between tests."""
- AgentRegistry._instance = None
- yield
- AgentRegistry._instance = None
-
-
-@pytest.fixture()
-def reg() -> AgentRegistry:
- r = AgentRegistry()
- r.register(_TaskAgent)
- r.register(_CalendarAgent)
- return r
-
-
-# ── classify_intent ───────────────────────────────────────────────────
-
-
-class TestClassifyIntent:
- @pytest.mark.asyncio
- async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("task_agent")
- result = await classify_intent("add a task", {}, reg)
- assert result == "task_agent"
-
- @pytest.mark.asyncio
- async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("calendar_agent")
- result = await classify_intent("schedule a meeting", {}, reg)
- assert result == "calendar_agent"
-
- @pytest.mark.asyncio
- async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("nonexistent_agent")
- result = await classify_intent("do something", {}, reg)
- assert result == "task_agent"
-
- @pytest.mark.asyncio
- async def test_empty_registry_returns_fallback_without_llm_call(self) -> None:
- empty_reg = AgentRegistry()
- # No LLM should be instantiated — early return path
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- result = await classify_intent("anything", {}, empty_reg)
- mock_cls.assert_not_called()
- assert result == "task_agent"
-
- @pytest.mark.asyncio
- async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm(" task_agent \n")
- result = await classify_intent("create task", {}, reg)
- assert result == "task_agent"
-
-
-# ── route_single ─────────────────────────────────────────────────────
-
-
-class TestRouteSingle:
- @pytest.mark.asyncio
- async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
- result = await route_single("task_agent", "create a task", {}, reg)
- assert isinstance(result, ChatResponse)
-
- @pytest.mark.asyncio
- async def test_response_contains_agent_output(self, reg: AgentRegistry) -> None:
- result = await route_single("task_agent", "create a task", {}, reg)
- assert result.response == "task: create a task"
-
- @pytest.mark.asyncio
- async def test_unknown_agent_raises_key_error(self, reg: AgentRegistry) -> None:
- with pytest.raises(KeyError):
- await route_single("nonexistent", "hello", {}, reg)
-
- @pytest.mark.asyncio
- async def test_actions_default_empty(self, reg: AgentRegistry) -> None:
- result = await route_single("task_agent", "hi", {}, reg)
- assert result.actions == []
-
-
-# ── route_pipeline ────────────────────────────────────────────────────
-
-
-class TestRoutePipeline:
- @pytest.mark.asyncio
- async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("synthesized result")
- result = await route_pipeline(
- ["task_agent", "calendar_agent"], "plan my week", {}, reg
- )
- assert isinstance(result, ChatResponse)
-
- @pytest.mark.asyncio
- async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("synthesized result")
- result = await route_pipeline(
- ["task_agent", "calendar_agent"], "plan my week", {}, reg
- )
- assert result.response == "synthesized result"
-
- @pytest.mark.asyncio
- async def test_passes_previous_results_to_subsequent_agents(
- self, reg: AgentRegistry
- ) -> None:
- """Each agent after the first should receive prior outputs in context."""
- received_contexts: list[dict[str, Any]] = []
-
- class _CapturingAgent(ChatAgent):
- def get_name(self) -> str:
- return "capture"
-
- def get_description(self) -> str:
- return "captures context for testing"
-
- def get_tools(self) -> list[Any]:
- return []
-
- async def handle(self, query: str, context: dict[str, Any]) -> str:
- received_contexts.append(dict(context))
- return "captured"
-
- reg.register(_CapturingAgent)
-
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("done")
- await route_pipeline(["task_agent", "capture"], "hi", {}, reg)
-
- # The second agent (capture) must have received previous results
- assert len(received_contexts) == 1
- assert "previous_results" in received_contexts[0]
- assert received_contexts[0]["previous_results"] == ["task: hi"]
-
- @pytest.mark.asyncio
- async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("single result")
- result = await route_pipeline(["task_agent"], "one agent", {}, reg)
- assert result.response == "single result"
-
-
-# ── orchestrate ───────────────────────────────────────────────────────
-
-
-class TestOrchestrate:
- @pytest.mark.asyncio
- async def test_direct_mode_returns_chat_response(
- self, reg: AgentRegistry
- ) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("task_agent")
- request = ChatRequest(message="add a task", execution_mode="direct")
- result = await orchestrate(request, reg)
- assert isinstance(result, ChatResponse)
-
- @pytest.mark.asyncio
- async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("task_agent")
- request = ChatRequest(message="add a task", execution_mode="direct")
- result = await orchestrate(request, reg)
- assert isinstance(result, ChatResponse)
- assert result.response == "task: add a task"
-
- @pytest.mark.asyncio
- async def test_plan_mode_returns_execution_plan(
- self, reg: AgentRegistry
- ) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("task_agent")
- request = ChatRequest(message="plan my tasks", execution_mode="plan")
- result = await orchestrate(request, reg)
- assert isinstance(result, ExecutionPlan)
-
- @pytest.mark.asyncio
- async def test_plan_mode_agent_matches_classified(
- self, reg: AgentRegistry
- ) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("calendar_agent")
- request = ChatRequest(
- message="schedule something", execution_mode="plan"
- )
- result = await orchestrate(request, reg)
- assert isinstance(result, ExecutionPlan)
- assert result.agent == "calendar_agent"
-
- @pytest.mark.asyncio
- async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("task_agent")
- request = ChatRequest(message="plan tasks", execution_mode="plan")
- result = await orchestrate(request, reg)
- assert isinstance(result, ExecutionPlan)
- assert len(result.steps) >= 1
-
- @pytest.mark.asyncio
- async def test_plan_mode_template_id_contains_agent_name(
- self, reg: AgentRegistry
- ) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("task_agent")
- request = ChatRequest(message="plan tasks", execution_mode="plan")
- result = await orchestrate(request, reg)
- assert isinstance(result, ExecutionPlan)
- assert result.steps[0].prompt_template is not None
- assert "task_agent" in result.steps[0].prompt_template
-
- @pytest.mark.asyncio
- async def test_default_execution_mode_is_direct(
- self, reg: AgentRegistry
- ) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("task_agent")
- # execution_mode defaults to "direct"
- request = ChatRequest(message="help me")
- result = await orchestrate(request, reg)
- assert isinstance(result, ChatResponse)
-
-
-# ── orchestrate_stream ────────────────────────────────────────────────
-
-
-class TestOrchestrateStream:
- @pytest.mark.asyncio
- async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("task_agent")
- request = ChatRequest(message="add a task", execution_mode="direct")
- chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
- assert len(chunks) >= 1
-
- @pytest.mark.asyncio
- async def test_all_chunks_are_plain_text(
- self, reg: AgentRegistry
- ) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("task_agent")
- request = ChatRequest(message="add a task", execution_mode="direct")
- chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
-
- # orchestrate_stream yields plain text chunks only — no JSON final frame
- for chunk in chunks:
- assert isinstance(chunk, str)
-
- @pytest.mark.asyncio
- async def test_concatenated_chunks_equal_full_response(
- self, reg: AgentRegistry
- ) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("task_agent")
- request = ChatRequest(message="create a task", execution_mode="direct")
- chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
-
- full_text = "".join(chunks)
- assert full_text == "task: create a task"
-
- @pytest.mark.asyncio
- async def test_text_chunks_before_final_frame(
- self, reg: AgentRegistry
- ) -> None:
- with patch("app.core.orchestrator._make_llm") as mock_cls:
- mock_cls.return_value = _mock_llm("task_agent")
- request = ChatRequest(
- message="x" * 200, execution_mode="direct"
- ) # long enough to produce multiple chunks
- chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
-
- # All but the last chunk should be plain text (not valid final JSON)
- non_final = chunks[:-1]
- for chunk in non_final:
- try:
- parsed = json.loads(chunk)
- assert parsed.get("done") is not True
- except json.JSONDecodeError:
- pass # plain text chunk — expected
diff --git a/tests/test_orchestrator_v3.py b/tests/test_orchestrator_v3.py
deleted file mode 100644
index fccb8ab..0000000
--- a/tests/test_orchestrator_v3.py
+++ /dev/null
@@ -1,236 +0,0 @@
-"""Tests for v3 orchestrator functions (Step 3)."""
-
-from __future__ import annotations
-
-import pytest
-from unittest.mock import AsyncMock, MagicMock, patch
-from typing import Any
-
-from app.core.agent_registry import ChatAgent, AgentRegistry
-from app.core.orchestrator import orchestrate_v3, orchestrate_v3_stream
-
-
-# ── Minimal agent for testing ─────────────────────────────────────────
-
-
-class _FixedAgent(ChatAgent):
- def __init__(self, name: str = "_fixed", tokens: list[str] | None = None, **kwargs: Any) -> None:
- super().__init__(**kwargs)
- self._name = name
- self._tokens = tokens or ["Hello", " world"]
-
- def get_name(self) -> str:
- return self._name
-
- def get_description(self) -> str:
- return "Fixed agent for tests"
-
- def get_tools(self) -> list[Any]:
- return []
-
- async def handle(self, query: str, context: dict[str, Any]) -> str:
- return "".join(self._tokens)
-
- async def handle_stream(self, query: str, context: dict[str, Any]):
- for tok in self._tokens:
- yield tok
-
-
-# ── Mock registry factory ─────────────────────────────────────────────
-
-
-def _make_registry(agent_name: str, agent: ChatAgent) -> MagicMock:
- reg = MagicMock(spec=AgentRegistry)
- reg.list_agents.return_value = [{"name": agent_name, "description": "test"}]
- reg.get.return_value = agent
- return reg
-
-
-# ── orchestrate_v3 ────────────────────────────────────────────────────
-
-
-@pytest.mark.asyncio
-async def test_orchestrate_v3_returns_agent_name_and_instance():
- agent = _FixedAgent("task_agent")
- reg = _make_registry("task_agent", agent)
-
- with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
- name, inst = await orchestrate_v3(
- user_id="u-1", message="fix a bug", context={}, reg=reg
- )
-
- assert name == "task_agent"
- assert inst is agent
-
-
-@pytest.mark.asyncio
-async def test_orchestrate_v3_classify_called_with_message_and_context():
- agent = _FixedAgent("note_agent")
- reg = _make_registry("note_agent", agent)
- ctx = {"some": "context"}
-
- with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")) as mock_classify:
- await orchestrate_v3(user_id="u-1", message="take a note", context=ctx, reg=reg)
-
- mock_classify.assert_awaited_once()
- call_args = mock_classify.call_args
- assert call_args[0][0] == "take a note"
- assert call_args[0][1] == ctx
-
-
-@pytest.mark.asyncio
-async def test_orchestrate_v3_uses_default_registry_when_none():
- agent = _FixedAgent("task_agent")
-
- with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
- patch("app.core.orchestrator._default_registry") as mock_reg:
- mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
- mock_reg.get.return_value = agent
- name, inst = await orchestrate_v3(user_id="u-1", message="hi", context={})
-
- assert name == "task_agent"
- assert inst is agent
-
-
-@pytest.mark.asyncio
-async def test_orchestrate_v3_get_called_with_agent_name():
- agent = _FixedAgent("timeline_agent")
- reg = _make_registry("timeline_agent", agent)
-
- with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="timeline_agent")):
- await orchestrate_v3(user_id="u-2", message="schedule", context={}, reg=reg)
-
- reg.get.assert_called_once_with("timeline_agent")
-
-
-# ── orchestrate_v3_stream ─────────────────────────────────────────────
-
-
-async def _collect(gen) -> list[tuple[str, str]]:
- results: list[tuple[str, str]] = []
- async for item in gen:
- results.append(item)
- return results
-
-
-@pytest.mark.asyncio
-async def test_orchestrate_v3_stream_first_yield_is_domain_signal():
- agent = _FixedAgent("task_agent", tokens=["token1"])
- reg = _make_registry("task_agent", agent)
-
- with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
- gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
- results = await _collect(gen)
-
- # First item must be (agent_name, "") — domain signal
- assert results[0] == ("task_agent", "")
-
-
-@pytest.mark.asyncio
-async def test_orchestrate_v3_stream_yields_agent_name_with_tokens():
- agent = _FixedAgent("task_agent", tokens=["Hello", " ", "world"])
- reg = _make_registry("task_agent", agent)
-
- with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
- gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
- results = await _collect(gen)
-
- # All items are (agent_name, token) pairs
- assert all(name == "task_agent" for name, _ in results)
- tokens = [tok for _, tok in results]
- assert tokens[0] == "" # domain signal
- assert tokens[1:] == ["Hello", " ", "world"]
-
-
-@pytest.mark.asyncio
-async def test_orchestrate_v3_stream_different_agent():
- agent = _FixedAgent("note_agent", tokens=["note"])
- reg = _make_registry("note_agent", agent)
-
- with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")):
- gen = orchestrate_v3_stream(user_id="u-2", message="take note", context={}, reg=reg)
- results = await _collect(gen)
-
- assert results[0] == ("note_agent", "")
- assert ("note_agent", "note") in results
-
-
-@pytest.mark.asyncio
-async def test_orchestrate_v3_stream_uses_default_registry_when_none():
- agent = _FixedAgent("task_agent", tokens=["x"])
-
- with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \
- patch("app.core.orchestrator._default_registry") as mock_reg:
- mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
- mock_reg.get.return_value = agent
- gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={})
- results = await _collect(gen)
-
- assert results[0][0] == "task_agent"
-
-
-@pytest.mark.asyncio
-async def test_orchestrate_v3_stream_empty_token_list():
- """Agent with no tokens still emits the domain signal."""
-
- class _EmptyAgent(_FixedAgent):
- async def handle_stream(self, query: str, context: dict[str, Any]):
- return
- yield # makes it a generator
-
- agent = _EmptyAgent("task_agent", tokens=[])
- reg = _make_registry("task_agent", agent)
-
- with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
- gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
- results = await _collect(gen)
-
- assert results == [("task_agent", "")] # only domain signal
-
-
-@pytest.mark.asyncio
-async def test_orchestrate_v3_stream_full_text_correct():
- """Concatenating all non-domain tokens reconstructs the full response."""
- tokens = ["The", " ", "task", " ", "is", " ", "done."]
- agent = _FixedAgent("task_agent", tokens=tokens)
- reg = _make_registry("task_agent", agent)
-
- with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")):
- gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
- results = await _collect(gen)
-
- text = "".join(tok for _, tok in results[1:]) # skip domain signal
- assert text == "The task is done."
-
-
-# ── handle_stream default implementation ─────────────────────────────
-
-
-@pytest.mark.asyncio
-async def test_handle_stream_default_yields_full_response():
- """Default handle_stream yields handle() result as a single chunk."""
-
- class _SimpleAgent(ChatAgent):
- def get_name(self) -> str:
- return "_simple"
-
- def get_description(self) -> str:
- return ""
-
- def get_tools(self) -> list[Any]:
- return []
-
- async def handle(self, query: str, context: dict[str, Any]) -> str:
- return "simple response"
-
- agent = _SimpleAgent()
- tokens = [tok async for tok in agent.handle_stream("q", {})]
- assert tokens == ["simple response"]
-
-
-@pytest.mark.asyncio
-async def test_handle_stream_override_used_by_stream():
- """_FixedAgent.handle_stream override yields individual tokens."""
- agent = _FixedAgent("t", tokens=["a", "b", "c"])
- tokens = [tok async for tok in agent.handle_stream("q", {})]
- assert tokens == ["a", "b", "c"]
diff --git a/tests/test_output_formatter.py b/tests/test_output_formatter.py
index bfc5c1c..2f06f79 100644
--- a/tests/test_output_formatter.py
+++ b/tests/test_output_formatter.py
@@ -1,195 +1,75 @@
-"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
+"""Tests for app.core.output_formatter.StreamFormatter."""
from __future__ import annotations
import pytest
-from app.core.output_formatter import HomeFormatter, FloatingFormatter
-from app.schemas import (
- WsFloatingDomain,
- WsStreamBlock,
- WsStreamEnd,
- WsStreamStart,
- WsStreamText,
-)
+from app.core.output_formatter import StreamFormatter
+from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
-# ── helpers ───────────────────────────────────────────────────────────────────
-
-async def _stream(*pairs: tuple[str, str]):
- """Async generator that yields (agent_name, token) pairs."""
- for pair in pairs:
- yield pair
+async def _stream(*events: tuple[str, object]):
+ for event in events:
+ yield event
-async def collect(formatter, token_stream):
+async def _collect(formatter: StreamFormatter, event_stream):
frames = []
- async for frame in formatter.format(token_stream):
+ async for frame in formatter.format(event_stream):
frames.append(frame)
return frames
-# ── HomeFormatter ─────────────────────────────────────────────────────────────
-
@pytest.mark.asyncio
-async def test_home_formatter_text_block():
- req_id = "req-1"
- tokens = [
- ("task_agent", '{"type": "text", "content": "Hello world"}'),
- ]
- formatter = HomeFormatter(request_id=req_id, tool_results=[])
- frames = await collect(formatter, _stream(*tokens))
-
- assert isinstance(frames[0], WsStreamStart)
- assert frames[0].request_id == req_id
- text_frames = [f for f in frames if isinstance(f, WsStreamText)]
- assert any("Hello world" in f.chunk for f in text_frames)
- assert isinstance(frames[-1], WsStreamEnd)
-
-
-@pytest.mark.asyncio
-async def test_home_formatter_chart_block():
- req_id = "req-2"
- chart_json = (
- '{"type": "chart", "chartType": "bar", '
- '"title": "Tasks", "data": [{"x": 1}], '
- '"config": {"x": {"label": "X", "color": "#fff"}}}'
+async def test_stream_formatter_text_stream() -> None:
+ formatter = StreamFormatter(request_id="req-1")
+ frames = await _collect(
+ formatter,
+ _stream(("token", "Hello"), ("token", " world")),
)
- formatter = HomeFormatter(request_id=req_id, tool_results=[])
- frames = await collect(formatter, _stream(("task_agent", chart_json)))
- block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
- assert len(block_frames) == 1
- assert block_frames[0].block_type == "chart"
- assert block_frames[0].data["chartType"] == "bar"
-
-
-@pytest.mark.asyncio
-async def test_home_formatter_invalid_chart_skipped():
- req_id = "req-3"
- bad_chart = '{"type": "chart", "chartType": "unknown", "data": []}'
- formatter = HomeFormatter(request_id=req_id, tool_results=[])
- frames = await collect(formatter, _stream(("task_agent", bad_chart)))
-
- block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
- assert len(block_frames) == 0 # invalid chart skipped
-
-
-@pytest.mark.asyncio
-async def test_home_formatter_entity_ref_resolved():
- req_id = "req-4"
- tool_results = [{"entity": "task", "id": "t1", "title": "My Task"}]
- entity_json = '{"type": "entity_ref", "entity": "task"}'
- formatter = HomeFormatter(request_id=req_id, tool_results=tool_results)
- frames = await collect(formatter, _stream(("task_agent", entity_json)))
-
- block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
- assert len(block_frames) == 1
- assert block_frames[0].data["entity"] == "task"
- assert block_frames[0].data["items"][0]["id"] == "t1"
-
-
-@pytest.mark.asyncio
-async def test_home_formatter_entity_ref_missing_skipped():
- req_id = "req-5"
- entity_json = '{"type": "entity_ref", "entity": "task"}'
- formatter = HomeFormatter(request_id=req_id, tool_results=[])
- frames = await collect(formatter, _stream(("task_agent", entity_json)))
-
- block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
- assert len(block_frames) == 0 # no tool results → skipped
-
-
-@pytest.mark.asyncio
-async def test_home_formatter_table_block():
- req_id = "req-6"
- table_json = '{"type": "table", "headers": ["A", "B"], "rows": [["1", "2"]]}'
- formatter = HomeFormatter(request_id=req_id, tool_results=[])
- frames = await collect(formatter, _stream(("task_agent", table_json)))
-
- block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
- assert len(block_frames) == 1
- assert block_frames[0].block_type == "table"
-
-
-@pytest.mark.asyncio
-async def test_home_formatter_timeline_block():
- req_id = "req-7"
- timeline_json = '{"type": "timeline", "timelines": [{"id": "c1", "title": "M1", "date": 123}]}'
- formatter = HomeFormatter(request_id=req_id, tool_results=[])
- frames = await collect(formatter, _stream(("task_agent", timeline_json)))
-
- block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
- assert len(block_frames) == 1
- assert block_frames[0].block_type == "timeline"
-
-
-@pytest.mark.asyncio
-async def test_home_formatter_frame_order():
- """stream_start is first, stream_end is last."""
- req_id = "req-8"
- formatter = HomeFormatter(request_id=req_id, tool_results=[])
- frames = await collect(formatter, _stream(("task_agent", '{"type": "text", "content": "Hi"}')))
assert isinstance(frames[0], WsStreamStart)
+ assert isinstance(frames[1], WsStreamText)
+ assert frames[1].chunk == "Hello"
+ assert isinstance(frames[2], WsStreamText)
+ assert frames[2].chunk == " world"
assert isinstance(frames[-1], WsStreamEnd)
-# ── FloatingFormatter ────────────────────────────────────────────────────────────
-
@pytest.mark.asyncio
-async def test_floating_formatter_domain_emitted_first():
- req_id = "pop-1"
- formatter = FloatingFormatter(request_id=req_id)
- tokens = [
- ("task_agent", ""), # domain signal
- ("task_agent", "Hello"),
- ("task_agent", " there"),
- ]
- frames = await collect(formatter, _stream(*tokens))
+async def test_stream_formatter_floating_domain_first() -> None:
+ formatter = StreamFormatter(request_id="req-2")
+ frames = await _collect(
+ formatter,
+ _stream(("floating_domain", "notes"), ("token", "Summary")),
+ )
assert isinstance(frames[0], WsFloatingDomain)
- assert frames[0].domain == "tasks"
- assert frames[0].request_id == req_id
+ assert frames[0].domain == "notes"
+ assert isinstance(frames[1], WsStreamStart)
+ assert isinstance(frames[2], WsStreamText)
+ assert frames[2].chunk == "Summary"
+ assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio
-async def test_floating_formatter_text_only():
- req_id = "pop-2"
- formatter = FloatingFormatter(request_id=req_id)
- tokens = [("timeline_agent", ""), ("timeline_agent", "Summary")]
- frames = await collect(formatter, _stream(*tokens))
+async def test_stream_formatter_ignores_unknown_events() -> None:
+ formatter = StreamFormatter(request_id="req-3")
+ frames = await _collect(
+ formatter,
+ _stream(("tool_end", {"name": "x"}), ("token", "ok")),
+ )
- assert isinstance(frames[0], WsFloatingDomain)
- assert frames[0].domain == "timelines"
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
assert len(text_frames) == 1
- assert text_frames[0].chunk == "Summary"
+ assert text_frames[0].chunk == "ok"
@pytest.mark.asyncio
-async def test_floating_formatter_no_block_frames():
- """FloatingFormatter must never emit WsStreamBlock."""
- req_id = "pop-3"
- formatter = FloatingFormatter(request_id=req_id)
- tokens = [
- ("note_agent", ""),
- ("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'),
- ]
- frames = await collect(formatter, _stream(*tokens))
- assert not any(isinstance(f, WsStreamBlock) for f in frames)
+async def test_stream_formatter_empty_stream_still_brackets() -> None:
+ formatter = StreamFormatter(request_id="req-4")
+ frames = await _collect(formatter, _stream())
-
-@pytest.mark.asyncio
-async def test_floating_formatter_end_frame():
- req_id = "pop-4"
- formatter = FloatingFormatter(request_id=req_id)
- frames = await collect(formatter, _stream(("project_agent", ""), ("project_agent", "Done")))
- assert isinstance(frames[-1], WsStreamEnd)
-
-
-@pytest.mark.asyncio
-async def test_floating_formatter_unknown_agent_defaults_to_tasks():
- req_id = "pop-5"
- formatter = FloatingFormatter(request_id=req_id)
- frames = await collect(formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi")))
- assert frames[0].domain == "tasks"
+ assert len(frames) == 2
+ assert isinstance(frames[0], WsStreamStart)
+ assert isinstance(frames[1], WsStreamEnd)
diff --git a/tests/test_schemas_v3.py b/tests/test_schemas_v3.py
index 054c9d3..16dc611 100644
--- a/tests/test_schemas_v3.py
+++ b/tests/test_schemas_v3.py
@@ -9,7 +9,6 @@ from app.schemas import (
WsFloatingDomain,
WsFloatingRequest,
WsFloatingScope,
- WsStreamBlock,
WsStreamEnd,
WsStreamStart,
WsStreamText,
@@ -25,7 +24,6 @@ def test_v3_frame_types_exist():
"floating_request",
"stream_start",
"stream_text",
- "stream_block",
"stream_end",
"floating_domain",
"data_request",
@@ -174,89 +172,21 @@ def test_stream_text_deserializes():
assert frame.chunk == "test"
-# ── WsStreamBlock ─────────────────────────────────────────────────────
-
-
-def test_stream_block_chart():
- data = {
- "type": "chart",
- "chartType": "bar",
- "title": "Tasks",
- "data": [{"name": "Done", "count": 5}],
- "config": {"count": {"label": "Count", "color": "#4f46e5"}},
- }
- frame = WsStreamBlock(request_id="r1", block_type="chart", data=data)
- assert frame.type == WsFrameType.stream_block
- assert frame.block_type == "chart"
- assert frame.data["chartType"] == "bar"
-
-
-def test_stream_block_entity_ref():
- frame = WsStreamBlock(
- request_id="r1",
- block_type="entity_ref",
- data={"type": "task", "id": "t-1", "title": "Fix bug"},
- )
- assert frame.block_type == "entity_ref"
-
-
-def test_stream_block_table():
- frame = WsStreamBlock(
- request_id="r1",
- block_type="table",
- data={"headers": ["A", "B"], "rows": [["1", "2"]]},
- )
- assert frame.block_type == "table"
-
-
-def test_stream_block_timeline():
- frame = WsStreamBlock(
- request_id="r1",
- block_type="timeline",
- data={"timelines": [{"id": "c1", "title": "Launch", "date": 1700000000}]},
- )
- assert frame.block_type == "timeline"
-
-
-def test_stream_block_invalid_type():
- with pytest.raises(ValidationError):
- WsStreamBlock(
- request_id="r1",
- block_type="unknown", # type: ignore[arg-type]
- data={},
- )
-
-
-def test_stream_block_serializes():
- frame = WsStreamBlock(request_id="r1", block_type="table", data={"headers": [], "rows": []})
- d = frame.model_dump()
- assert d["type"] == "stream_block"
- assert d["block_type"] == "table"
-
-
# ── WsStreamEnd ───────────────────────────────────────────────────────
def test_stream_end_defaults():
frame = WsStreamEnd(request_id="r1")
assert frame.type == WsFrameType.stream_end
- assert frame.mutations == []
-
-
-def test_stream_end_with_mutations():
- mutations = [{"action": "create", "table": "tasks", "data": {"title": "New task"}}]
- frame = WsStreamEnd(request_id="r1", mutations=mutations)
- assert len(frame.mutations) == 1
- assert frame.mutations[0]["action"] == "create"
def test_stream_end_serializes():
data = WsStreamEnd(request_id="r2").model_dump()
- assert data == {"type": "stream_end", "request_id": "r2", "mutations": []}
+ assert data == {"type": "stream_end", "request_id": "r2"}
def test_stream_end_deserializes():
- raw = {"type": "stream_end", "request_id": "r3", "mutations": []}
+ raw = {"type": "stream_end", "request_id": "r3"}
frame = WsStreamEnd.model_validate(raw)
assert frame.request_id == "r3"
diff --git a/tests/test_ws_unified.py b/tests/test_ws_unified.py
index f4e6387..41fd689 100644
--- a/tests/test_ws_unified.py
+++ b/tests/test_ws_unified.py
@@ -45,14 +45,13 @@ def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
return frames
-async def _mock_home_stream(user_id, message, context, reg=None):
- yield "task_agent", ""
- yield "task_agent", '{"type": "text", "content": "Hello"}'
+async def _mock_home_stream(user_id, message, context):
+ yield "token", "Hello"
-async def _mock_floating_stream(user_id, message, context, reg=None):
- yield "task_agent", ""
- yield "task_agent", "Here is a summary"
+async def _mock_floating_stream(user_id, message, context):
+ yield "floating_domain", "tasks"
+ yield "token", "Here is a summary"
# ── tests ─────────────────────────────────────────────────────────────────────
@@ -61,7 +60,7 @@ def test_home_request_produces_stream_frames(client):
"""home_request → stream_start, stream_text+, stream_end."""
token = make_jwt("power", user_id=USER_ID)
- with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_home_stream):
+ with patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_home_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-1", "agent_ids": []
@@ -84,7 +83,7 @@ def test_floating_request_produces_domain_frame(client):
"""floating_request → floating_domain first, then stream_text*, stream_end."""
token = make_jwt("power", user_id=USER_ID)
- with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_floating_stream):
+ with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-2", "agent_ids": []
@@ -112,11 +111,10 @@ def test_home_request_request_id_propagated(client):
token = make_jwt("power", user_id=USER_ID)
req_id = "my-unique-req-id"
- async def _stream(user_id, message, context, reg=None):
- yield "note_agent", ""
- yield "note_agent", '{"type": "text", "content": "ok"}'
+ async def _stream(user_id, message, context):
+ yield "token", "ok"
- with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_stream):
+ with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-3", "agent_ids": []