From cfc9d7a9421a201710e26f08c37fa779ec77a394 Mon Sep 17 00:00:00 2001 From: roberto Date: Wed, 11 Mar 2026 17:50:22 +0100 Subject: [PATCH] refactor: replace orchestrator with LangGraph deep-agent supervisors - Add app/core/deep_agent.py with Home and Floating supervisor graphs using LangGraph create_react_agent (hierarchical pattern) - Strip ChatAgent classes from all 4 agent files, keep @tool functions - Rewrite output_formatter.py for event-based (token/tool_end/mutations) stream - Update device_ws.py to use run_home_stream/run_floating_stream - Rewrite chat.py REST route to use run_home - Add update_core_memory tool to both supervisors - Add langgraph>=0.3.0 to requirements.txt - Remove orchestrator.py, execution_plan.py, agent_registry.py, plans.py - Remove PlanAction, PlanStep, ExecutionPlan, execution_mode from schemas - Update all affected tests to match new API - Remove 6 deprecated test files for deleted modules - Clean up stale docstrings referencing removed orchestrator --- app/agents/__init__.py | 2 +- app/agents/note_agent.py | 40 +- app/agents/project_agent.py | 48 +- app/agents/task_agent.py | 54 +-- app/agents/timeline_agent.py | 38 +- app/api/routes/chat.py | 27 +- app/api/routes/device_ws.py | 43 +- app/api/routes/plans.py | 37 -- app/core/agent_registry.py | 217 --------- app/core/agent_runner.py | 2 +- app/core/deep_agent.py | 429 ++++++++++++++++++ app/core/execution_plan.py | 222 ---------- app/core/llm.py | 2 +- app/core/memory_middleware.py | 4 +- app/core/orchestrator.py | 210 --------- app/core/output_formatter.py | 293 +++++------- app/core/ws_context.py | 12 +- app/main.py | 8 +- app/schemas.py | 28 -- requirements.txt | 1 + tests/test_agent_registry.py | 214 --------- tests/test_agent_streaming.py | 416 ----------------- tests/test_agents.py | 761 -------------------------------- tests/test_execution_plan.py | 286 ------------ tests/test_memory_middleware.py | 8 +- tests/test_middleware.py | 9 +- tests/test_orchestrator.py | 347 --------------- tests/test_orchestrator_v3.py | 236 ---------- tests/test_output_formatter.py | 198 +++++---- tests/test_plugins.py | 2 +- tests/test_ws_unified.py | 27 +- 31 files changed, 723 insertions(+), 3498 deletions(-) delete mode 100644 app/api/routes/plans.py delete mode 100644 app/core/agent_registry.py create mode 100644 app/core/deep_agent.py delete mode 100644 app/core/execution_plan.py delete mode 100644 app/core/orchestrator.py delete mode 100644 tests/test_agent_registry.py delete mode 100644 tests/test_agent_streaming.py delete mode 100644 tests/test_agents.py delete mode 100644 tests/test_execution_plan.py delete mode 100644 tests/test_orchestrator.py delete mode 100644 tests/test_orchestrator_v3.py diff --git a/app/agents/__init__.py b/app/agents/__init__.py index 6a202c1..a90239c 100644 --- a/app/agents/__init__.py +++ b/app/agents/__init__.py @@ -1,4 +1,4 @@ -"""Import all agent modules to trigger @registry.register decorators.""" +"""Agent tool modules — imported by deep_agent.py to build sub-agent graphs.""" from app.agents import timeline_agent, note_agent, project_agent, task_agent diff --git a/app/agents/note_agent.py b/app/agents/note_agent.py index e5c648a..3cd08d1 100644 --- a/app/agents/note_agent.py +++ b/app/agents/note_agent.py @@ -1,31 +1,14 @@ -"""Note agent — Markdown note management (list, get, create, update, delete).""" +"""Note agent — tool definitions for Markdown note CRUD.""" from __future__ import annotations -import json from typing import Any -from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.tools import tool -from app.core.agent_registry import ChatAgent, registry -from app.core.llm import embed, get_llm +from app.core.llm import embed from app.core.ws_context import execute_on_client -_SYSTEM_PROMPT = ( - "You are a note-taking assistant. You help users create, retrieve, update,\n" - "and delete Markdown notes in their workspace.\n\n" - "Rules:\n" - " - content is always Markdown; preserve formatting when updating\n" - " - project_id is optional; link a note to a project when mentioned\n" - " - When updating, call get_note first if you need to read existing content\n" - " before appending or replacing sections\n" - " - list_notes without project_id returns all notes; scope with project_id\n" - " when the user is working within a specific project\n" - " - Do not fabricate note content — reflect what the user provides or what\n" - " is already in the note (retrieved via get_note)." -) - @tool async def list_notes(project_id: str = "") -> str: @@ -122,23 +105,4 @@ async def delete_note(note_id: str) -> str: return f"Note {note_id} deleted." -@registry.register -class NoteAgent(ChatAgent): - def get_name(self) -> str: - return "note_agent" - def get_description(self) -> str: - return "Manages notes: list, get, create, update, delete" - - def get_tools(self) -> list[Any]: - return [list_notes, get_note, create_note, update_note, delete_note] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - llm = get_llm() - messages = [ - SystemMessage(content=_SYSTEM_PROMPT), - HumanMessage( - content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}" - ), - ] - return await self._tool_loop(llm, messages, self.get_tools()) diff --git a/app/agents/project_agent.py b/app/agents/project_agent.py index ccd2ea6..a30f157 100644 --- a/app/agents/project_agent.py +++ b/app/agents/project_agent.py @@ -1,33 +1,13 @@ -"""Project agent — full lifecycle management (list, get, create, update, archive, delete).""" +"""Project agent — tool definitions for project lifecycle CRUD.""" from __future__ import annotations -import json from typing import Any -from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.tools import tool -from app.core.agent_registry import ChatAgent, registry -from app.core.llm import get_llm from app.core.ws_context import execute_on_client -_SYSTEM_PROMPT = ( - "You are a project management assistant. You help users create, find,\n" - "update, and archive projects in their workspace.\n\n" - "Rules:\n" - " - status must be one of: active, archived\n" - " - client_id is optional; link to a client only when explicitly mentioned\n" - " - ai_summary is populated only when the user asks for a project summary;\n" - " derive it from context data — do not fabricate content\n" - " - Use list_projects for scoped queries; list_all_projects only when the\n" - " user wants a complete cross-client view including archived projects\n" - " - get_project requires a project UUID; resolve the ID first by calling\n" - " list_projects if you only have a project name\n" - " - Prefer archiving (update_project status=archived) over deletion;\n" - " only call delete_project when the user explicitly confirms deletion." -) - @tool async def list_projects( @@ -137,30 +117,4 @@ async def delete_project(project_id: str) -> str: return f"Project {project_id} permanently deleted." -@registry.register -class ProjectAgent(ChatAgent): - def get_name(self) -> str: - return "project_agent" - def get_description(self) -> str: - return "Manages projects: list, get, create, update, archive, delete" - - def get_tools(self) -> list[Any]: - return [ - list_projects, - list_all_projects, - get_project, - create_project, - update_project, - delete_project, - ] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - llm = get_llm() - messages = [ - SystemMessage(content=_SYSTEM_PROMPT), - HumanMessage( - content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}" - ), - ] - return await self._tool_loop(llm, messages, self.get_tools()) diff --git a/app/agents/task_agent.py b/app/agents/task_agent.py index 1d6e32d..5421d49 100644 --- a/app/agents/task_agent.py +++ b/app/agents/task_agent.py @@ -1,35 +1,14 @@ -"""Task agent — full CRUD for tasks and task comments.""" +"""Task agent — tool definitions for task and task comment CRUD.""" from __future__ import annotations -import json from datetime import datetime, timezone from typing import Any -from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.tools import tool -from app.core.agent_registry import ChatAgent, registry -from app.core.llm import get_llm from app.core.ws_context import execute_on_client -_SYSTEM_PROMPT = ( - "You are a task management assistant for a project workspace.\n" - "You create, update, list, and track tasks and their comments.\n\n" - "Rules:\n" - " - status must be one of: todo, in_progress, done\n" - " - priority must be one of: high, medium, low\n" - " - due_date is a Unix timestamp in milliseconds; convert human dates\n" - " - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n" - " - project_id is optional; link to a project when the user mentions one\n" - " - is_ai_suggested: 1 only when proactively proposing a task the user\n" - " did not explicitly request; 0 otherwise\n" - " - is_approved defaults to 0; set to 1 only when the user confirms\n" - " - Use list_tasks_due_today for 'what's due today' queries\n" - " - For update_task, use -1 for integer fields you do not want to change\n" - " - Always confirm the action in plain, user-friendly language." -) - # ── Task tools ──────────────────────────────────────────────────────── @@ -220,35 +199,4 @@ async def delete_task_comment(comment_id: str) -> str: return f"Comment {comment_id} deleted." -# ── Agent ───────────────────────────────────────────────────────────── - -@registry.register -class TaskAgent(ChatAgent): - def get_name(self) -> str: - return "task_agent" - - def get_description(self) -> str: - return "Manages tasks and comments: list, create, update, delete, due-today, comments" - - def get_tools(self) -> list[Any]: - return [ - list_tasks, - create_task, - update_task, - delete_task, - list_tasks_due_today, - list_task_comments, - add_task_comment, - delete_task_comment, - ] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - llm = get_llm() - messages = [ - SystemMessage(content=_SYSTEM_PROMPT), - HumanMessage( - content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}" - ), - ] - return await self._tool_loop(llm, messages, self.get_tools()) diff --git a/app/agents/timeline_agent.py b/app/agents/timeline_agent.py index 6e85357..a790aa7 100644 --- a/app/agents/timeline_agent.py +++ b/app/agents/timeline_agent.py @@ -1,30 +1,13 @@ -"""Timeline agent — project milestone management (list, create, update, delete).""" +"""Timeline agent — tool definitions for project milestone CRUD.""" from __future__ import annotations -import json from typing import Any -from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.tools import tool -from app.core.agent_registry import ChatAgent, registry -from app.core.llm import get_llm from app.core.ws_context import execute_on_client -_SYSTEM_PROMPT = ( - "You are a project timeline assistant. Timelines are milestone dates that\n" - "track progress on a project — they are not calendar events.\n\n" - "Rules:\n" - " - project_id is REQUIRED for every create; confirm with the user if unknown\n" - " - date is a Unix timestamp in milliseconds; convert human-readable dates\n" - " - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n" - " - is_approved: 0 until the user explicitly confirms; then 1\n" - " - For update_timeline, use -1 for integer fields you do not want to change\n" - " - Listing without a project_id returns all timelines across projects\n" - " - Always echo the title and formatted date in your confirmation." -) - @tool async def list_timelines(project_id: str = "") -> str: @@ -106,23 +89,4 @@ async def delete_timeline(timeline_id: str) -> str: return f"Timeline {timeline_id} deleted." -@registry.register -class TimelineAgent(ChatAgent): - def get_name(self) -> str: - return "timeline_agent" - def get_description(self) -> str: - return "Manages project timelines (milestones): list, create, update, delete" - - def get_tools(self) -> list[Any]: - return [list_timelines, create_timeline, update_timeline, delete_timeline] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - llm = get_llm() - messages = [ - SystemMessage(content=_SYSTEM_PROMPT), - HumanMessage( - content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}" - ), - ] - return await self._tool_loop(llm, messages, self.get_tools()) diff --git a/app/api/routes/chat.py b/app/api/routes/chat.py index 1cd0fa4..9400e34 100644 --- a/app/api/routes/chat.py +++ b/app/api/routes/chat.py @@ -9,8 +9,10 @@ from fastapi import APIRouter, Depends from fastapi.responses import JSONResponse from app.api.deps import get_current_user -from app.core.orchestrator import orchestrate -from app.schemas import ChatRequest, UserProfile +from app.core.deep_agent import run_home +from app.core.memory_middleware import MemoryMiddleware +from app.db import async_session +from app.schemas import ChatRequest, ChatResponse, UserProfile router = APIRouter(prefix="/chat", tags=["chat"]) @@ -20,10 +22,21 @@ async def chat( body: ChatRequest, current_user: UserProfile = Depends(get_current_user), ) -> JSONResponse: - """Route a chat message through the orchestrator. + """Route a chat message through the Home deep agent (non-streaming).""" + async with async_session() as db: + memory = MemoryMiddleware(db) + memory_context = await memory.enrich_context(current_user.id, body.message) - Returns ``ChatResponse`` for ``execution_mode='direct'``, - or ``ExecutionPlan`` for ``execution_mode='plan'``. - """ - result = await orchestrate(body) + context = { + **body.context.model_dump(), + **memory_context, + } + + response_text = await run_home( + user_id=current_user.id, + message=body.message, + context=context, + db_session_factory=async_session, + ) + result = ChatResponse(response=response_text) return JSONResponse(content=result.model_dump()) diff --git a/app/api/routes/device_ws.py b/app/api/routes/device_ws.py index 771b696..fa3611c 100644 --- a/app/api/routes/device_ws.py +++ b/app/api/routes/device_ws.py @@ -43,7 +43,7 @@ from app.config.settings import settings from app.core.agent_runner import trigger_pending_runs from app.core.device_manager import device_manager from app.core.memory_middleware import MemoryMiddleware -from app.core.orchestrator import orchestrate_v3_stream +from app.core.deep_agent import run_home_stream, run_floating_stream from app.core.output_formatter import HomeFormatter, FloatingFormatter from app.core.ws_context import clear_client_executor, set_client_executor from app.db import async_session @@ -204,9 +204,17 @@ async def _make_ws_executor(websocket: WebSocket, user_id: str): """Return a callback that sends tool_call frames and awaits tool_result.""" async def _executor(payload: dict) -> dict: payload["type"] = WsFrameType.tool_call + call_id = payload["id"] + logger.info("ws_executor: sending tool_call id=%s action=%s", call_id, payload.get("action")) await websocket.send_text(json.dumps(payload)) - future = device_manager.create_pending_call(user_id, payload["id"]) - return await future + future = device_manager.create_pending_call(user_id, call_id) + result = await future + logger.info("ws_executor: tool_result id=%s result_type=%s result_keys=%s", + call_id, type(result).__name__, + list(result.keys()) if isinstance(result, dict) else "N/A") + if result is None: + logger.error("ws_executor: future resolved to None for call_id=%s user=%s", call_id, user_id) + return result return _executor @@ -233,21 +241,13 @@ async def _handle_home_request( executor = await _make_ws_executor(websocket, user_id) set_client_executor(executor) response_chunks: list[str] = [] - agent_holder: list = [] try: - token_stream = orchestrate_v3_stream( - user_id, message, context, agent_holder=agent_holder + event_stream = run_home_stream( + user_id, message, context, db_session_factory=async_session ) - formatter = HomeFormatter(request_id=request_id, tool_results=[]) - async for ws_frame in formatter.format(token_stream): - # Inject mutations from agent tool_results into stream_end - if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr] - ws_frame.mutations = [ # type: ignore[union-attr] - {"action": r["action"], "table": r["table"], "data": r["data"]} - for r in getattr(agent_holder[0], "tool_results", []) - ] + formatter = HomeFormatter(request_id=request_id) + async for ws_frame in formatter.format(event_stream): await websocket.send_text(ws_frame.model_dump_json()) - # Collect text chunks to build the full response for episode storage if ws_frame.type == "stream_text": # type: ignore[union-attr] response_chunks.append(ws_frame.chunk) # type: ignore[union-attr] except Exception as exc: @@ -287,18 +287,13 @@ async def _handle_floating_request( executor = await _make_ws_executor(websocket, user_id) set_client_executor(executor) response_chunks: list[str] = [] - agent_holder: list = [] try: - token_stream = orchestrate_v3_stream( - user_id, message, context, agent_holder=agent_holder + event_stream = run_floating_stream( + user_id, message, context, scope=scope, + db_session_factory=async_session, ) formatter = FloatingFormatter(request_id=request_id) - async for ws_frame in formatter.format(token_stream): - if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr] - ws_frame.mutations = [ # type: ignore[union-attr] - {"action": r["action"], "table": r["table"], "data": r["data"]} - for r in getattr(agent_holder[0], "tool_results", []) - ] + async for ws_frame in formatter.format(event_stream): await websocket.send_text(ws_frame.model_dump_json()) if ws_frame.type == "stream_text": # type: ignore[union-attr] response_chunks.append(ws_frame.chunk) # type: ignore[union-attr] diff --git a/app/api/routes/plans.py b/app/api/routes/plans.py deleted file mode 100644 index ed27272..0000000 --- a/app/api/routes/plans.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}.""" - -from __future__ import annotations - -from fastapi import APIRouter, Depends, HTTPException, status - -from app.api.deps import get_current_user -from app.core.execution_plan import plan_cache -from app.schemas import ExecutionPlan, UserProfile - -router = APIRouter(prefix="/plans", tags=["plans"]) - - -@router.get("/playbook", response_model=list[ExecutionPlan]) -async def list_playbooks( - current_user: UserProfile = Depends(get_current_user), -) -> list[ExecutionPlan]: - """Return all cached execution plan playbooks for the authenticated user. - - TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature. - """ - return plan_cache.get_all_playbooks() - - -@router.get("/playbook/{plan_id}", response_model=ExecutionPlan) -async def get_playbook( - plan_id: str, - current_user: UserProfile = Depends(get_current_user), -) -> ExecutionPlan: - """Return a specific execution plan playbook by ID.""" - plan = plan_cache.get_plan(plan_id) - if plan is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Plan not found: {plan_id}", - ) - return plan diff --git a/app/core/agent_registry.py b/app/core/agent_registry.py deleted file mode 100644 index 9a4930d..0000000 --- a/app/core/agent_registry.py +++ /dev/null @@ -1,217 +0,0 @@ -"""Agent Registry — base classes and singleton registry for chat agents.""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator -from typing import Any - - -class BaseAgent(ABC): - """Common base for all agents.""" - - def __init__( - self, - user_id: str = "", - shared_memory: dict[str, Any] | None = None, - vector_store_context: list[str] | None = None, - ) -> None: - self.user_id = user_id - self.shared_memory: dict[str, Any] = shared_memory or {} - self.vector_store_context: list[str] = vector_store_context or [] - - @abstractmethod - def get_name(self) -> str: ... - - @abstractmethod - def get_description(self) -> str: ... - - @property - def skills(self) -> list[str]: - """Override in subclasses to advertise capabilities.""" - return [] - - -class ChatAgent(BaseAgent): - """Base class for LLM-powered chat agents.""" - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - # Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results. - self.tool_results: list[dict] = [] - - @abstractmethod - async def handle(self, query: str, context: dict[str, Any]) -> str: - """Process a user query and return a text response.""" - ... - - async def handle_stream( - self, query: str, context: dict[str, Any] - ) -> AsyncGenerator[str, None]: - """Streaming variant of handle(). - - Default: calls handle() and yields the full response as one chunk. - Override in subclasses for true token-level streaming via _tool_loop_stream. - """ - yield await self.handle(query, context) - - @abstractmethod - def get_tools(self) -> list[Any]: - """Return LangChain tool definitions available to this agent.""" - ... - - async def _tool_loop( - self, - llm: Any, - messages: list[Any], - tools: list[Any], - max_iter: int = 5, - ) -> str: - """Shared tool-calling loop. - - Binds *tools* to *llm*, invokes iteratively until the model stops - requesting tool calls or *max_iter* is reached, and returns the - final text response. Captures raw execute_on_client results in - ``self.tool_results``. - """ - from langchain_core.messages import AIMessage, ToolMessage - - from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector - - collector: list[dict] = [] - set_tool_result_collector(collector) - try: - llm_with_tools = llm.bind_tools(tools) if tools else llm - - for _ in range(max_iter): - response: AIMessage = await llm_with_tools.ainvoke(messages) - messages.append(response) - - if not response.tool_calls: - return str(response.content) - - # Execute each requested tool call - tool_map = {t.name: t for t in tools} - for call in response.tool_calls: - tool_fn = tool_map.get(call["name"]) - if tool_fn is None: - result = f"Unknown tool: {call['name']}" - else: - result = await tool_fn.ainvoke(call["args"]) - messages.append( - ToolMessage(content=str(result), tool_call_id=call["id"]) - ) - - # Exhausted iterations — ask model for a final answer without tools - response = await llm.ainvoke(messages) - return str(response.content) - finally: - clear_tool_result_collector() - self.tool_results = collector - - async def _tool_loop_stream( - self, - llm: Any, - messages: list[Any], - tools: list[Any], - max_iter: int = 5, - ) -> AsyncGenerator[str, None]: - """Streaming variant of ``_tool_loop``. - - Behaves identically for tool-calling iterations (uses ainvoke to parse - tool calls). For the final response — when the model produces no further - tool calls — switches to ``llm.astream()`` and yields text tokens. - Captures raw execute_on_client results in ``self.tool_results``. - """ - from langchain_core.messages import AIMessage, ToolMessage - - from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector - - collector: list[dict] = [] - set_tool_result_collector(collector) - try: - llm_with_tools = llm.bind_tools(tools) if tools else llm - - for _ in range(max_iter): - response: AIMessage = await llm_with_tools.ainvoke(messages) - - if not response.tool_calls: - # Stream the final answer — don't keep the ainvoke result. - async for chunk in llm.astream(messages): - if chunk.content: - yield str(chunk.content) - return - - messages.append(response) - - # Execute each requested tool call - tool_map = {t.name: t for t in tools} - for call in response.tool_calls: - tool_fn = tool_map.get(call["name"]) - if tool_fn is None: - result = f"Unknown tool: {call['name']}" - else: - result = await tool_fn.ainvoke(call["args"]) - messages.append( - ToolMessage(content=str(result), tool_call_id=call["id"]) - ) - - # Exhausted iterations — stream a final answer without tools - async for chunk in llm.astream(messages): - if chunk.content: - yield str(chunk.content) - finally: - clear_tool_result_collector() - self.tool_results = collector - - -class AgentRegistry: - """Singleton registry for ChatAgent subclasses.""" - - _instance: AgentRegistry | None = None - - def __init__(self) -> None: - self._agents: dict[str, type[ChatAgent]] = {} - - def __new__(cls) -> AgentRegistry: - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._agents = {} - return cls._instance - - # ── public API ─────────────────────────────────────────────────── - - def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]: - """Class decorator — registers an agent by its name.""" - instance = agent_class() - name = instance.get_name() - self._agents[name] = agent_class - return agent_class - - def get(self, name: str) -> ChatAgent: - """Return a fresh instance of the named agent.""" - cls = self._agents.get(name) - if cls is None: - raise KeyError(f"Agent not found: {name}") - return cls() - - def list_agents(self) -> list[dict[str, str]]: - """Return ``[{name, description}]`` for the orchestrator prompt.""" - result: list[dict[str, str]] = [] - for cls in self._agents.values(): - inst = cls() - result.append( - {"name": inst.get_name(), "description": inst.get_description()} - ) - return result - - async def call_agent( - self, name: str, query: str, context: dict[str, Any] - ) -> str: - """Instantiate the named agent and call its ``handle`` method.""" - agent = self.get(name) - return await agent.handle(query, context) - - -# Module-level singleton -registry = AgentRegistry() diff --git a/app/core/agent_runner.py b/app/core/agent_runner.py index 0d25f65..4d8c976 100644 --- a/app/core/agent_runner.py +++ b/app/core/agent_runner.py @@ -1,4 +1,4 @@ -"""Agent run orchestrator. +"""Agent run manager. Drives two agent types: diff --git a/app/core/deep_agent.py b/app/core/deep_agent.py new file mode 100644 index 0000000..45bdea7 --- /dev/null +++ b/app/core/deep_agent.py @@ -0,0 +1,429 @@ +"""Deep Agent — LangGraph hierarchical supervisors for home and floating modes. + +Two supervisor graphs (both ``create_react_agent``): + * **HomeSupervisor** — gathers data from multiple domains, presents + structured overview with tool-result blocks. + * **FloatingSupervisor** — focused, scoped assistant for a single entity/domain. + +Each supervisor delegates to four sub-agent tools, each a compiled +``create_react_agent`` wrapping the domain CRUD tools (task, project, note, +timeline). The sub-agents talk to Electron via ``execute_on_client``. + +Streaming uses ``astream(stream_mode=["messages", "updates"])`` so that +callers can sniff: + * ``("messages", (token, metadata))`` — text tokens for streaming + * ``("updates", ...)`` — tool call results for mutations + +An ``update_core_memory`` tool is available to both supervisors for +persisting user preferences mid-conversation (MemGPT-style). +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, AsyncGenerator + +from langchain_core.messages import AIMessageChunk, HumanMessage +from langchain_core.tools import tool +from langgraph.prebuilt import create_react_agent + +from app.core.llm import get_llm +from app.core.ws_context import ( + clear_tool_result_collector, + set_tool_result_collector, +) + +logger = logging.getLogger(__name__) + +# ── Sub-agent tool imports ──────────────────────────────────────────── + +from app.agents.task_agent import ( # noqa: E402 + add_task_comment, + create_task, + delete_task, + delete_task_comment, + list_task_comments, + list_tasks, + list_tasks_due_today, + update_task, +) +from app.agents.note_agent import ( # noqa: E402 + create_note, + delete_note, + get_note, + list_notes, + update_note, +) +from app.agents.project_agent import ( # noqa: E402 + create_project, + delete_project, + get_project, + list_all_projects, + list_projects, + update_project, +) +from app.agents.timeline_agent import ( # noqa: E402 + create_timeline, + delete_timeline, + list_timelines, + update_timeline, +) + +# ── Sub-agent definitions ───────────────────────────────────────────── + +_TASK_TOOLS = [ + list_tasks, + create_task, + update_task, + delete_task, + list_tasks_due_today, + list_task_comments, + add_task_comment, + delete_task_comment, +] + +_NOTE_TOOLS = [list_notes, get_note, create_note, update_note, delete_note] + +_PROJECT_TOOLS = [ + list_projects, + list_all_projects, + get_project, + create_project, + update_project, + delete_project, +] + +_TIMELINE_TOOLS = [list_timelines, create_timeline, update_timeline, delete_timeline] + + +def _build_subagent_tool( + name: str, + description: str, + system_prompt: str, + tools: list, +): + """Build a compiled sub-agent graph and wrap it as a LangChain tool.""" + subgraph = create_react_agent( + model=get_llm(), + tools=tools, + prompt=system_prompt, + name=name, + ) + + @tool(name=name, description=description) + async def _run(query: str) -> str: + result = await subgraph.ainvoke( + {"messages": [HumanMessage(content=query)]} + ) + messages = result["messages"] + # Return the last AI message content + for msg in reversed(messages): + if hasattr(msg, "content") and msg.content and not getattr(msg, "tool_calls", None): + return str(msg.content) + return "No response from sub-agent." + + return _run + + +def _make_subagent_tools() -> list: + """Create the four sub-agent tools for the supervisor.""" + return [ + _build_subagent_tool( + name="task_agent", + description=( + "Manages tasks and comments: list, create, update, delete, " + "due-today, comments. Delegate task-related queries here." + ), + system_prompt=( + "You are a task management assistant. You create, update, list, " + "and track tasks and their comments.\n\n" + "Rules:\n" + " - status must be one of: todo, in_progress, done\n" + " - priority must be one of: high, medium, low\n" + " - due_date is a Unix timestamp in milliseconds\n" + " - assignees is a JSON-encoded array of strings\n" + " - is_approved defaults to 0; set to 1 only when the user confirms\n" + " - For update_task, use -1 for integer fields you do not want to change\n" + " - Always confirm the action in plain, user-friendly language." + ), + tools=_TASK_TOOLS, + ), + _build_subagent_tool( + name="note_agent", + description=( + "Manages notes: list, get, create, update, delete. " + "Delegate note-related queries here." + ), + system_prompt=( + "You are a note-taking assistant. You help users create, retrieve, " + "update, and delete Markdown notes in their workspace.\n\n" + "Rules:\n" + " - content is always Markdown; preserve formatting when updating\n" + " - When updating, call get_note first if you need to read existing " + "content before appending or replacing sections\n" + " - Do not fabricate note content." + ), + tools=_NOTE_TOOLS, + ), + _build_subagent_tool( + name="project_agent", + description=( + "Manages projects: list, get, create, update, archive, delete. " + "Delegate project-related queries here." + ), + system_prompt=( + "You are a project management assistant. You help users create, " + "find, update, and archive projects.\n\n" + "Rules:\n" + " - status must be one of: active, archived\n" + " - Prefer archiving over deletion\n" + " - ai_summary is populated only when the user asks for a summary." + ), + tools=_PROJECT_TOOLS, + ), + _build_subagent_tool( + name="timeline_agent", + description=( + "Manages project timelines (milestones): list, create, update, " + "delete. Delegate timeline/milestone queries here." + ), + system_prompt=( + "You are a project timeline assistant. Timelines are milestone " + "dates that track progress on a project.\n\n" + "Rules:\n" + " - project_id is REQUIRED for every create\n" + " - date is a Unix timestamp in milliseconds\n" + " - For update_timeline, use -1 for integer fields you do not " + "want to change." + ), + tools=_TIMELINE_TOOLS, + ), + ] + + +# ── Update core memory tool ────────────────────────────────────────── + +def _make_update_core_memory_tool(user_id: str, db_session_factory): + """Create a tool that persists a key/value preference in core memory.""" + + @tool + async def update_core_memory(key: str, value: str) -> str: + """Save a user preference or fact to long-term core memory. + key: short label for the memory (e.g. 'preferred_language', 'timezone') + value: the value to remember + Use this when the user states a preference or fact worth remembering. + """ + from app.core.memory_middleware import MemoryMiddleware + + async with db_session_factory() as db: + memory = MemoryMiddleware(db) + await memory.update_core(user_id, key, value) + return f"Remembered: {key} = {value}" + + return update_core_memory + + +# ── System prompts ──────────────────────────────────────────────────── + +_HOME_SYSTEM = ( + "You are Adiuva, a smart workspace assistant on the Home dashboard.\n" + "Your job is to help the user by gathering data from their workspace and " + "presenting a comprehensive overview.\n\n" + "You have sub-agent tools (task_agent, note_agent, project_agent, " + "timeline_agent) that can query and mutate workspace data. Delegate to " + "the appropriate sub-agent(s) based on the user's request. You can call " + "multiple sub-agents if needed.\n\n" + "You also have an update_core_memory tool — use it when the user states " + "a preference or important fact worth remembering long-term.\n\n" + "After gathering data, synthesize a clear, helpful response for the user.\n\n" + "Memory context:\n{memory_context}" +) + +_FLOATING_SYSTEM = ( + "You are Adiuva, a focused workspace assistant in the floating panel.\n" + "The user is currently working in the '{scope_type}' section" + "{scope_detail}.\n\n" + "You have sub-agent tools (task_agent, note_agent, project_agent, " + "timeline_agent) that can query and mutate workspace data. Focus your " + "help on the user's current scope, but you can use other sub-agents " + "if the request requires it.\n\n" + "You also have an update_core_memory tool — use it when the user states " + "a preference or important fact worth remembering long-term.\n\n" + "Provide direct, conversational responses.\n\n" + "Memory context:\n{memory_context}" +) + + +def _format_memory_context(memory: dict[str, Any]) -> str: + """Format the memory dict into a readable string for the system prompt.""" + if not memory: + return "(no memory available)" + parts = [] + if memory.get("core_memory"): + parts.append("Preferences: " + json.dumps(memory["core_memory"])) + if memory.get("associative_memory"): + parts.append("Related memories: " + "; ".join(memory["associative_memory"][:3])) + if memory.get("episodic_memory"): + parts.append("Recent sessions: " + "; ".join(memory["episodic_memory"][:3])) + if memory.get("proactive_hints"): + parts.append("Patterns: " + "; ".join(memory["proactive_hints"][:3])) + return "\n".join(parts) if parts else "(no memory available)" + + +# ── Graph builders ──────────────────────────────────────────────────── + +def build_home_graph( + user_id: str, + memory_context: dict[str, Any], + db_session_factory, +): + """Build the Home supervisor graph.""" + subagent_tools = _make_subagent_tools() + memory_tool = _make_update_core_memory_tool(user_id, db_session_factory) + all_tools = subagent_tools + [memory_tool] + + prompt = _HOME_SYSTEM.format( + memory_context=_format_memory_context(memory_context), + ) + + return create_react_agent( + model=get_llm(), + tools=all_tools, + prompt=prompt, + name="home_supervisor", + ) + + +def build_floating_graph( + user_id: str, + memory_context: dict[str, Any], + scope: dict[str, Any], + db_session_factory, +): + """Build the Floating supervisor graph.""" + subagent_tools = _make_subagent_tools() + memory_tool = _make_update_core_memory_tool(user_id, db_session_factory) + all_tools = subagent_tools + [memory_tool] + + scope_type = scope.get("type", "general") + scope_id = scope.get("id") + scope_detail = f" (id: {scope_id})" if scope_id else "" + + prompt = _FLOATING_SYSTEM.format( + scope_type=scope_type, + scope_detail=scope_detail, + memory_context=_format_memory_context(memory_context), + ) + + return create_react_agent( + model=get_llm(), + tools=all_tools, + prompt=prompt, + name="floating_supervisor", + ) + + +# ── Stream event type ──────────────────────────────────────────────── + +# Events yielded by run_*_stream: +# ("token", str) — text token for streaming +# ("tool_start", dict) — {"name": "task_agent", "args": {...}} +# ("tool_end", dict) — {"name": "task_agent", "result": "..."} + + +# ── Stream runners ──────────────────────────────────────────────────── + +async def _run_graph_stream( + graph, + message: str, +) -> AsyncGenerator[tuple[str, Any], None]: + """Run a supervisor graph with streaming, yielding event tuples. + + Uses ``stream_mode=["messages", "updates"]`` to get both token-level + streaming and update events for tool calls. + """ + inputs = {"messages": [HumanMessage(content=message)]} + + collector: list[dict] = [] + set_tool_result_collector(collector) + try: + async for stream_mode, chunk in graph.astream( + inputs, + stream_mode=["messages", "updates"], + ): + if stream_mode == "messages": + msg, metadata = chunk + # Only yield tokens from the supervisor's final response + # (not from sub-agent internal LLM calls) + if ( + isinstance(msg, AIMessageChunk) + and msg.content + and not msg.tool_calls + and metadata.get("langgraph_node") == "agent" + ): + yield ("token", str(msg.content)) + + elif stream_mode == "updates": + # Updates is a dict of {node_name: state_update} + if not isinstance(chunk, dict): + continue + for node_name, state_update in chunk.items(): + if node_name != "tools": + continue + # Tool node executed — extract tool call results + tool_messages = state_update.get("messages", []) + for tool_msg in tool_messages: + if hasattr(tool_msg, "name") and hasattr(tool_msg, "content"): + yield ( + "tool_end", + {"name": tool_msg.name, "result": str(tool_msg.content)}, + ) + finally: + clear_tool_result_collector() + + # Yield the collected mutations so callers can attach them to stream_end + yield ("mutations", collector) + + +async def run_home_stream( + user_id: str, + message: str, + context: dict[str, Any], + db_session_factory, +) -> AsyncGenerator[tuple[str, Any], None]: + """Run the Home supervisor and yield streaming events.""" + graph = build_home_graph(user_id, context, db_session_factory) + async for event in _run_graph_stream(graph, message): + yield event + + +async def run_floating_stream( + user_id: str, + message: str, + context: dict[str, Any], + scope: dict[str, Any], + db_session_factory, +) -> AsyncGenerator[tuple[str, Any], None]: + """Run the Floating supervisor and yield streaming events.""" + graph = build_floating_graph(user_id, context, scope, db_session_factory) + async for event in _run_graph_stream(graph, message): + yield event + + +async def run_home( + user_id: str, + message: str, + context: dict[str, Any], + db_session_factory, +) -> str: + """Run the Home supervisor (non-streaming) and return full response text.""" + graph = build_home_graph(user_id, context, db_session_factory) + result = await graph.ainvoke( + {"messages": [HumanMessage(content=message)]} + ) + messages = result["messages"] + for msg in reversed(messages): + if hasattr(msg, "content") and msg.content and not getattr(msg, "tool_calls", None): + return str(msg.content) + return "" diff --git a/app/core/execution_plan.py b/app/core/execution_plan.py deleted file mode 100644 index a98879f..0000000 --- a/app/core/execution_plan.py +++ /dev/null @@ -1,222 +0,0 @@ -"""Execution Plan generator — builder, template registry, and LRU plan cache.""" - -from __future__ import annotations - -from collections import OrderedDict -from typing import Any - -from app.schemas import ExecutionPlan, PlanStep - - -# ── Prompt Template Registry ────────────────────────────────────────── - - -class PromptTemplateRegistry: - """Server-side store mapping template IDs to prompt text. - - Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``). - The actual prompt text is resolved here on the server, keeping prompt IP - out of API responses. - """ - - def __init__(self) -> None: - self._templates: dict[str, str] = {} - - def register(self, template_id: str, prompt_text: str) -> None: - self._templates[template_id] = prompt_text - - def get(self, template_id: str) -> str: - """Resolve a template ID to its prompt text. - - Raises ``KeyError`` if the template is not registered. - """ - text = self._templates.get(template_id) - if text is None: - raise KeyError(f"Template not found: {template_id!r}") - return text - - def has(self, template_id: str) -> bool: - return template_id in self._templates - - def list_ids(self) -> list[str]: - """Return all registered template IDs (never the text).""" - return list(self._templates.keys()) - - -# ── Execution Plan Builder ──────────────────────────────────────────── - - -class ExecutionPlanBuilder: - """Fluent builder for ``ExecutionPlan`` objects. - - Example:: - - plan = ( - ExecutionPlanBuilder("task_agent") - .add_llm_step("tpl_task_agent_default", {"message": user_msg}) - .add_data_step("create_record", data_from_step=0) - .build() - ) - """ - - def __init__(self, agent: str) -> None: - self._agent = agent - self._steps: list[PlanStep] = [] - - # ── step adders ────────────────────────────────────────────────── - - def add_step( - self, action: str, params: dict[str, Any] | None = None - ) -> ExecutionPlanBuilder: - """Append a generic action step with optional parameters.""" - self._steps.append(PlanStep(action=action, variables=params)) - return self - - def add_llm_step( - self, template_id: str, variables: dict[str, Any] | None = None - ) -> ExecutionPlanBuilder: - """Append an LLM step referencing a server-side template by ID.""" - self._steps.append( - PlanStep(action="llm", prompt_template=template_id, variables=variables) - ) - return self - - def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder: - """Append a step whose input comes from the output of an earlier step.""" - self._steps.append(PlanStep(action=action, data_from_step=data_from_step)) - return self - - # ── build ──────────────────────────────────────────────────────── - - def build(self) -> ExecutionPlan: - """Validate step references and return the ``ExecutionPlan``. - - Raises ``ValueError`` if any ``data_from_step`` references a - non-existent or future step index. - """ - for i, step in enumerate(self._steps): - if step.data_from_step is not None: - if not (0 <= step.data_from_step < i): - raise ValueError( - f"Step {i}: data_from_step={step.data_from_step} must " - f"reference a preceding step index in range 0..{i - 1}" - ) - return ExecutionPlan(agent=self._agent, steps=list(self._steps)) - - -# ── Plan Cache (LRU) ────────────────────────────────────────────────── - - -class PlanCache: - """In-memory LRU cache for ``ExecutionPlan`` objects. - - Plans stored here are accessible as playbooks via ``get_all_playbooks()``. - The cache also serves as a runtime memoisation layer so that repeated - identical intent classifications can skip re-building the plan. - """ - - def __init__(self, maxsize: int = 1000) -> None: - self._maxsize = maxsize - self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict() - - def cache_plan(self, key: str, plan: ExecutionPlan) -> None: - """Store *plan* under *key*, evicting the LRU entry if at capacity.""" - if key in self._cache: - del self._cache[key] # remove so re-insertion places it at the end - elif len(self._cache) >= self._maxsize: - self._cache.popitem(last=False) # evict least-recently-used - self._cache[key] = plan - - def get_plan(self, key: str) -> ExecutionPlan | None: - """Return the cached plan for *key*, or ``None`` if not present. - - Accessing a plan marks it as most-recently used. - """ - if key not in self._cache: - return None - self._cache.move_to_end(key) - return self._cache[key] - - def get_all_playbooks(self) -> list[ExecutionPlan]: - """Return all cached plans (most-recently used last).""" - return list(self._cache.values()) - - -# ── Module-level singletons ─────────────────────────────────────────── - -template_registry = PromptTemplateRegistry() -plan_cache = PlanCache() - - -def _register_builtin_templates() -> None: - """Register the built-in server-side prompt templates. - - These strings never leave the server. Clients only receive the IDs. - """ - _tpls: dict[str, str] = { - "tpl_task_agent_default": ( - "You are a task management assistant. Help the user create, update, " - "list, and track tasks. Use correct status values (todo, in_progress, " - "done) and priority values (high, medium, low) from the workspace model." - ), - "tpl_timeline_agent_default": ( - "You are a project timeline assistant. Help the user create and manage " - "milestone timelines on their projects. Every timeline requires a " - "project_id and a date expressed as a Unix timestamp in milliseconds." - ), - "tpl_project_agent_default": ( - "You are a project management assistant. Help the user create, find, " - "update, and archive projects. Projects have a name, an optional client, " - "and a status of either active or archived." - ), - "tpl_note_agent_default": ( - "You are a note-taking assistant. Help the user create, retrieve, update, " - "and delete Markdown notes. Notes can optionally be linked to a project." - ), - "tpl_task_extract_from_project": ( - "Extract all actionable tasks from the provided project context. " - "Return a structured list of tasks, each with a title, inferred priority " - "(high, medium, or low), suggested status (todo), and a due_date in " - "milliseconds where a deadline can be inferred." - ), - "tpl_note_weekly_summary": ( - "Generate a weekly project summary note from the provided workspace data. " - "Include: tasks completed this week, tasks due soon, active projects, " - "and upcoming timelines. Format the output as clean Markdown." - ), - } - for tid, text in _tpls.items(): - template_registry.register(tid, text) - - -def _load_playbooks() -> None: - """Pre-build and cache the built-in playbooks.""" - playbooks: list[tuple[str, ExecutionPlan]] = [ - ( - "create_tasks_from_project", - ExecutionPlanBuilder("project_agent") - .add_llm_step( - "tpl_task_extract_from_project", - {"source": "project_context"}, - ) - .add_data_step("create_record", data_from_step=0) - .build(), - ), - ( - "generate_weekly_note", - ExecutionPlanBuilder("note_agent") - .add_llm_step( - "tpl_note_weekly_summary", - {"period": "last_7_days"}, - ) - .add_data_step("create_record", data_from_step=0) - .build(), - ), - ] - for key, plan in playbooks: - plan_cache.cache_plan(key, plan) - - -# Initialise on module load -_register_builtin_templates() -_load_playbooks() diff --git a/app/core/llm.py b/app/core/llm.py index 3d985af..9669d4c 100644 --- a/app/core/llm.py +++ b/app/core/llm.py @@ -1,6 +1,6 @@ """LLM factory — centralised model instantiation via LiteLLM. -Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()`` +Every agent and the deep-agent supervisors call ``get_llm()`` or ``get_router_llm()`` instead of directly constructing a provider-specific class. The model string follows the `LiteLLM model naming convention `_: diff --git a/app/core/memory_middleware.py b/app/core/memory_middleware.py index 8053117..accaa37 100644 --- a/app/core/memory_middleware.py +++ b/app/core/memory_middleware.py @@ -43,7 +43,7 @@ _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6 class MemoryMiddleware: - """Enrich orchestrator context with memory and persist interactions after.""" + """Enrich agent context with memory and persist interactions after.""" def __init__(self, db: AsyncSession) -> None: self._db = db @@ -51,7 +51,7 @@ class MemoryMiddleware: # ── Public API ──────────────────────────────────────────────────────────── async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]: - """Build memory context dict to inject into the orchestrator before LLM call. + """Build memory context dict to inject into the agent before LLM call. Returns a dict with keys: core_memory — {key: plaintext_value, ...} diff --git a/app/core/orchestrator.py b/app/core/orchestrator.py deleted file mode 100644 index 7765704..0000000 --- a/app/core/orchestrator.py +++ /dev/null @@ -1,210 +0,0 @@ -"""Orchestrator — LLM-based intent router and agent pipeline.""" - -from __future__ import annotations - -import json -from typing import Any, AsyncGenerator - -from langchain_core.messages import HumanMessage, SystemMessage - -from app.core.agent_registry import AgentRegistry, ChatAgent -from app.core.llm import get_router_llm -from app.core.agent_registry import registry as _default_registry -from app.schemas import ChatRequest, ChatResponse, ExecutionPlan - -_FALLBACK_AGENT = "task_agent" - -_CLASSIFY_SYSTEM = ( - "You are an intent classifier. Given the user message and context, decide " - "which agent to route to.\n" - "Available agents: {agents}\n" - "Respond with just the agent name, nothing else." -) - -_SYNTHESIZE_HUMAN = ( - "Combine the following agent results into one coherent response.\n\n" - "Agent results:\n{results}\n\n" - "Original message: {message}" -) - - -def _make_llm(): - return get_router_llm() - - -async def classify_intent( - message: str, - context: dict[str, Any], - reg: AgentRegistry, -) -> str: - """Use gpt-4o-mini to classify intent and return the matching agent name. - - Falls back to ``task_agent`` when the registry is empty or the model - returns a name that is not registered. - """ - agents = reg.list_agents() - if not agents: - return _FALLBACK_AGENT - - system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents)) - # Truncate context to keep the classification prompt short - human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}" - - llm = _make_llm() - response = await llm.ainvoke( - [SystemMessage(content=system), HumanMessage(content=human)] - ) - - agent_name = str(response.content).strip().lower() - known = {a["name"] for a in agents} - return agent_name if agent_name in known else _FALLBACK_AGENT - - -async def route_single( - agent_name: str, - message: str, - context: dict[str, Any], - reg: AgentRegistry, -) -> ChatResponse: - """Route to a single agent and wrap the result in a ``ChatResponse``.""" - response_text = await reg.call_agent(agent_name, message, context) - return ChatResponse(response=response_text) - - -async def route_pipeline( - agent_names: list[str], - message: str, - context: dict[str, Any], - reg: AgentRegistry, -) -> ChatResponse: - """Execute agents sequentially; each agent receives previous results in context. - - A final LLM synthesis call merges all results into one coherent response. - """ - previous_results: list[str] = [] - - for agent_name in agent_names: - ctx = {**context, "previous_results": list(previous_results)} - result = await reg.call_agent(agent_name, message, ctx) - previous_results.append(result) - - results_str = "\n\n".join( - f"[{name}]: {res}" for name, res in zip(agent_names, previous_results) - ) - human = _SYNTHESIZE_HUMAN.format(results=results_str, message=message) - llm = _make_llm() - synthesis = await llm.ainvoke([HumanMessage(content=human)]) - return ChatResponse(response=str(synthesis.content)) - - -def _build_plan(agent_name: str, message: str) -> ExecutionPlan: - """Build an ``ExecutionPlan`` for the resolved agent. - - Uses ``ExecutionPlanBuilder`` with the server-side template registry. - If a default template exists for the agent, an LLM step is emitted; - otherwise a plain ``handle`` action step is used. - """ - from app.core.execution_plan import ExecutionPlanBuilder, template_registry - - template_id = f"tpl_{agent_name}_default" - builder = ExecutionPlanBuilder(agent_name) - if template_registry.has(template_id): - builder.add_llm_step(template_id, {"message": message}) - else: - builder.add_step("handle", {"message": message}) - return builder.build() - - -async def orchestrate( - request: ChatRequest, - reg: AgentRegistry | None = None, -) -> ChatResponse | ExecutionPlan: - """Main orchestration entry point. - - * Classifies the user's intent to select an agent. - * ``execution_mode == 'direct'``: routes to the agent and returns a - ``ChatResponse``. - * ``execution_mode == 'plan'``: returns an ``ExecutionPlan`` with the - resolved agent and a template-ID-only step (prompt IP stays server-side). - """ - if reg is None: - reg = _default_registry - - context = request.context.model_dump() - agent_name = await classify_intent(request.message, context, reg) - - if request.execution_mode == "direct": - return await route_single(agent_name, request.message, context, reg) - - # plan mode — return plan, do not execute - return _build_plan(agent_name, request.message) - - -async def orchestrate_v3( - user_id: str, - message: str, - context: dict[str, Any], - reg: AgentRegistry | None = None, -) -> tuple[str, ChatAgent]: - """v3 orchestration — returns (agent_name, agent_instance); caller drives execution. - - Classifies intent and instantiates the matching agent. The caller is responsible - for invoking handle(), handle_stream(), or _tool_loop_stream() as needed. - """ - if reg is None: - reg = _default_registry - agent_name = await classify_intent(message, context, reg) - return agent_name, reg.get(agent_name) - - -async def orchestrate_v3_stream( - user_id: str, - message: str, - context: dict[str, Any], - reg: AgentRegistry | None = None, - agent_holder: list | None = None, -) -> AsyncGenerator[tuple[str, str], None]: - """v3 streaming orchestration — yields (agent_name, token) pairs. - - The first yield always carries the agent_name with an empty token so that - callers (e.g. FloatingFormatter) can detect the routing domain before any text - tokens arrive. - - If *agent_holder* is provided (a list), the agent instance is appended so - callers can access ``agent.tool_results`` after the stream completes. - """ - if reg is None: - reg = _default_registry - agent_name = await classify_intent(message, context, reg) - agent = reg.get(agent_name) - if agent_holder is not None: - agent_holder.append(agent) - yield agent_name, "" # domain signal — no token yet - async for token in agent.handle_stream(message, context): - yield agent_name, token - - -async def orchestrate_stream( - request: ChatRequest, - reg: AgentRegistry | None = None, -) -> AsyncGenerator[str, None]: - """Streaming orchestration — yields plain text chunks only. - - The WebSocket handler in ``app/api/routes/chat.py`` is responsible for - wrapping each chunk in a ``text_chunk`` frame and sending the final - ``final`` frame once the generator is exhausted. - - Agents do not yet support token-level streaming; the full response is - fetched first (which may involve multiple WS round-trips for tool calls), - then emitted in fixed-size chunks. - """ - if reg is None: - reg = _default_registry - - context = request.context.model_dump() - agent_name = await classify_intent(request.message, context, reg) - response_text = await reg.call_agent(agent_name, request.message, context) - - chunk_size = 50 - for i in range(0, len(response_text), chunk_size): - yield response_text[i : i + chunk_size] diff --git a/app/core/output_formatter.py b/app/core/output_formatter.py index a8e44fb..a5106e3 100644 --- a/app/core/output_formatter.py +++ b/app/core/output_formatter.py @@ -1,12 +1,23 @@ -"""Output Formatter — transforms orchestrator token streams into WS frame sequences. +"""Output Formatter — transforms deep-agent event streams into WS frame sequences. -HomeFormatter: produces stream_start, stream_text / stream_block, stream_end -FloatingFormatter: produces floating_domain, stream_text, stream_end +Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``: + * ``("token", str)`` — supervisor text token + * ``("tool_end", dict)`` — sub-agent finished: ``{name, result}`` + * ``("mutations", list)`` — collected CRUD mutations for ``stream_end`` + +HomeFormatter: + * Sniffs ``tool_end`` events → emits ``WsStreamBlock`` (entity_ref with raw data) + * Streams text tokens → emits ``WsStreamText`` + * Attaches mutations → injects into ``WsStreamEnd`` + +FloatingFormatter: + * Sniffs first ``tool_end`` name → emits ``WsFloatingDomain`` + * Streams text tokens → emits ``WsStreamText`` + * Attaches mutations → injects into ``WsStreamEnd`` """ from __future__ import annotations -import json import logging from collections.abc import AsyncGenerator from typing import Any @@ -21,10 +32,7 @@ from app.schemas import ( logger = logging.getLogger(__name__) -# Valid chart types (matching shadcn/ui Recharts wrappers in Electron) -_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"} - -# Map agent name → floating domain +# Map sub-agent tool name → floating domain / entity type _AGENT_DOMAIN: dict[str, str] = { "task_agent": "tasks", "timeline_agent": "timelines", @@ -36,180 +44,74 @@ WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatin class HomeFormatter: - """Parses a token stream from orchestrate_v3_stream and yields WS frames. + """Consumes a deep-agent event stream and yields WS frames for the Home view. - The LLM is expected to output a newline-delimited sequence of JSON objects, - each with a ``type`` field: - - ``text`` → yields WsStreamText immediately (word-by-word) - - ``chart`` → buffers full JSON, validates, yields WsStreamBlock - - ``entity_ref`` → resolves from tool_results, yields WsStreamBlock - - ``table`` → buffers full JSON, validates, yields WsStreamBlock - - ``timeline`` → buffers full JSON, validates, yields WsStreamBlock - - Invalid or unknown blocks are logged and skipped — stream never crashes. - """ - - def __init__(self, request_id: str, tool_results: list[dict]) -> None: - self.request_id = request_id - self.tool_results = tool_results - - async def format( - self, - token_stream: AsyncGenerator[tuple[str, str], None], - ) -> AsyncGenerator[WsFrame, None]: - yield WsStreamStart(request_id=self.request_id) - - buffer = "" - async for _agent_name, token in token_stream: - if not token: - continue - buffer += token - # Flush any complete JSON objects from the buffer - async for frame in self._flush_complete_objects(buffer): - buffer = "" # reset after flush - yield frame - break # only one flush per iteration; rest accumulates - - # Flush any remaining content - if buffer.strip(): - async for frame in self._flush_complete_objects(buffer, final=True): - yield frame - - yield WsStreamEnd(request_id=self.request_id) - - async def _flush_complete_objects( - self, text: str, final: bool = False - ) -> AsyncGenerator[WsFrame, None]: - """Try to parse and yield all complete JSON objects from *text*. - - Yields nothing if text is incomplete JSON (unless *final* is True, - in which case remaining text is emitted as plain stream_text). - """ - remaining = text.strip() - while remaining: - # Fast path: plain text (not JSON) - if not remaining.startswith("{"): - # Yield as plain text chunk - newline_idx = remaining.find("\n") - if newline_idx == -1: - if final: - yield WsStreamText(request_id=self.request_id, chunk=remaining) - remaining = "" - else: - return # accumulate more - else: - line = remaining[:newline_idx].strip() - remaining = remaining[newline_idx + 1:].strip() - if line: - yield WsStreamText(request_id=self.request_id, chunk=line) - continue - - # Try to decode a JSON object - try: - obj, end_idx = _try_parse_json(remaining) - except ValueError: - if final: - # Emit as raw text if we can't parse - yield WsStreamText(request_id=self.request_id, chunk=remaining) - remaining = "" - return - - if obj is None: - if final: - yield WsStreamText(request_id=self.request_id, chunk=remaining) - remaining = "" - return # incomplete — need more tokens - - remaining = remaining[end_idx:].strip() - block_type = obj.get("type") - - frame = self._dispatch_block(obj, block_type) - if frame is not None: - yield frame - - def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None: - if block_type == "text": - content = obj.get("content", "") - if content: - return WsStreamText(request_id=self.request_id, chunk=str(content)) - return None - - if block_type == "chart": - chart_type = obj.get("chartType") - if chart_type not in _VALID_CHART_TYPES: - logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type) - return None - if not isinstance(obj.get("data"), list): - logger.warning("HomeFormatter: chart missing data array — skipping") - return None - return WsStreamBlock( - request_id=self.request_id, - block_type="chart", - data=obj, - ) - - if block_type == "entity_ref": - entity = obj.get("entity") - resolved = self._resolve_entity(entity) - if resolved is None: - logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity) - return None - return WsStreamBlock( - request_id=self.request_id, - block_type="entity_ref", - data={"entity": entity, "items": resolved}, - ) - - if block_type == "table": - if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list): - logger.warning("HomeFormatter: table missing headers/rows — skipping") - return None - return WsStreamBlock( - request_id=self.request_id, - block_type="table", - data=obj, - ) - - if block_type == "timeline": - if not isinstance(obj.get("timelines"), list): - logger.warning("HomeFormatter: timeline missing timelines — skipping") - return None - return WsStreamBlock( - request_id=self.request_id, - block_type="timeline", - data=obj, - ) - - logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type) - return None - - def _resolve_entity(self, entity: str | None) -> list[dict] | None: - """Find matching items in tool_results by entity type.""" - if not entity: - return None - matches = [r for r in self.tool_results if r.get("entity") == entity] - return matches if matches else None - - -class FloatingFormatter: - """Parses a token stream from orchestrate_v3_stream and yields WS frames. - - Emits floating_domain immediately (from agent_name), then streams all tokens - as plain stream_text — no block parsing for floating context. + ``tool_end`` events from sub-agents are emitted as ``WsStreamBlock`` + (entity_ref) so the client can render structured data. Text tokens are + forwarded as ``WsStreamText``. Mutations are attached to ``WsStreamEnd``. """ def __init__(self, request_id: str) -> None: self.request_id = request_id + self._mutations: list[dict] = [] async def format( self, - token_stream: AsyncGenerator[tuple[str, str], None], + event_stream: AsyncGenerator[tuple[str, Any], None], + ) -> AsyncGenerator[WsFrame, None]: + yield WsStreamStart(request_id=self.request_id) + + async for event_type, data in event_stream: + if event_type == "token": + if data: + yield WsStreamText(request_id=self.request_id, chunk=data) + + elif event_type == "tool_end": + # Sub-agent finished — emit its result as an entity_ref block + name = data.get("name", "") + entity = _AGENT_DOMAIN.get(name) + if entity: + yield WsStreamBlock( + request_id=self.request_id, + block_type="entity_ref", + data={"entity": entity, "result": data.get("result", "")}, + ) + + elif event_type == "mutations": + self._mutations = data or [] + + yield WsStreamEnd( + request_id=self.request_id, + mutations=[ + {"action": m["action"], "table": m["table"], "data": m["data"]} + for m in self._mutations + ], + ) + + +class FloatingFormatter: + """Consumes a deep-agent event stream and yields WS frames for the Floating view. + + Sniffs the first ``tool_end`` event name to derive the domain (e.g. + ``task_agent`` → ``"tasks"``), then streams text tokens as plain + ``WsStreamText``. No block parsing for floating context. + """ + + def __init__(self, request_id: str) -> None: + self.request_id = request_id + self._mutations: list[dict] = [] + + async def format( + self, + event_stream: AsyncGenerator[tuple[str, Any], None], ) -> AsyncGenerator[WsFrame, None]: domain_sent = False - async for agent_name, token in token_stream: - if not domain_sent: - domain = _AGENT_DOMAIN.get(agent_name, "tasks") + async for event_type, data in event_stream: + if event_type == "tool_end" and not domain_sent: + # Sniff domain from the first sub-agent that completes + name = data.get("name", "") + domain = _AGENT_DOMAIN.get(name, "tasks") yield WsFloatingDomain( request_id=self.request_id, domain=domain, # type: ignore[arg-type] @@ -217,28 +119,33 @@ class FloatingFormatter: yield WsStreamStart(request_id=self.request_id) domain_sent = True - if token: - yield WsStreamText(request_id=self.request_id, chunk=token) + elif event_type == "token": + if not domain_sent: + # First token arrived before any tool_end — default domain + yield WsFloatingDomain( + request_id=self.request_id, + domain="tasks", # type: ignore[arg-type] + ) + yield WsStreamStart(request_id=self.request_id) + domain_sent = True + if data: + yield WsStreamText(request_id=self.request_id, chunk=data) - yield WsStreamEnd(request_id=self.request_id) + elif event_type == "mutations": + self._mutations = data or [] + # If no events triggered domain_sent (edge case), still emit structure + if not domain_sent: + yield WsFloatingDomain( + request_id=self.request_id, + domain="tasks", # type: ignore[arg-type] + ) + yield WsStreamStart(request_id=self.request_id) -# ── helpers ─────────────────────────────────────────────────────────────────── - -def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]: - """Attempt to parse the first complete JSON object from *text*. - - Returns ``(parsed_dict, end_index)`` on success, ``(None, 0)`` when the - object is incomplete, and raises ``ValueError`` when text is not JSON. - """ - decoder = json.JSONDecoder() - try: - obj, end_idx = decoder.raw_decode(text) - if not isinstance(obj, dict): - raise ValueError("Expected JSON object") - return obj, end_idx - except json.JSONDecodeError as exc: - # Incomplete JSON — need more tokens - if "Unterminated" in str(exc) or exc.pos == len(text): - return None, 0 - raise ValueError(str(exc)) from exc + yield WsStreamEnd( + request_id=self.request_id, + mutations=[ + {"action": m["action"], "table": m["table"], "data": m["data"]} + for m in self._mutations + ], + ) diff --git a/app/core/ws_context.py b/app/core/ws_context.py index 14ac879..1dd6eec 100644 --- a/app/core/ws_context.py +++ b/app/core/ws_context.py @@ -7,18 +7,21 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`. from __future__ import annotations +import logging from contextvars import ContextVar from typing import Any, Callable, Coroutine from uuid import uuid4 +logger = logging.getLogger(__name__) + # Holds the execute callback for the current WS session. -# Set by the chat WS handler before the orchestrator runs; cleared after. +# Set by the chat WS handler before the deep agent runs; cleared after. _client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar( "_client_executor" ) # Optional collector that captures raw execute_on_client results. -# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results. +# Set by the deep agent tool loop to capture CRUD mutations. _tool_result_collector: ContextVar[list[dict] | None] = ContextVar( "_tool_result_collector", default=None ) @@ -81,7 +84,12 @@ async def execute_on_client( if limit is not None: payload["limit"] = limit + logger.info("execute_on_client: sending payload action=%s table=%s id=%s", action, table, payload["id"]) result = await callback(payload) + if result is None: + logger.error("execute_on_client: callback returned None for action=%s table=%s id=%s", action, table, payload["id"]) + else: + logger.info("execute_on_client: got result type=%s keys=%s", type(result).__name__, list(result.keys()) if isinstance(result, dict) else "N/A") collector = _tool_result_collector.get(None) if collector is not None: collector.append({ diff --git a/app/main.py b/app/main.py index 74c25ee..50aebf6 100644 --- a/app/main.py +++ b/app/main.py @@ -18,10 +18,7 @@ from app.config.settings import settings @asynccontextmanager async def lifespan(app: FastAPI): - # Startup: initialise DB connection pool and agent registry - from app.core.agent_registry import registry # noqa: F401 — triggers module load - import app.agents # noqa: F401 — triggers @registry.register decorators - + # Startup: initialise DB connection pool yield # Shutdown: dispose SQLAlchemy connection pool @@ -51,11 +48,10 @@ def create_app() -> FastAPI: app.add_middleware(SanitizerMiddleware) app.add_middleware(TierRateLimitMiddleware) - from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plans, plugins, storage, vectors + from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors app.include_router(auth.router, prefix="/api/v1") app.include_router(chat.router, prefix="/api/v1") - app.include_router(plans.router, prefix="/api/v1") app.include_router(storage.router, prefix="/api/v1") app.include_router(vectors.router, prefix="/api/v1") app.include_router(backup.router, prefix="/api/v1") diff --git a/app/schemas.py b/app/schemas.py index f3a281b..69a8117 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -41,41 +41,13 @@ class ChatContext(BaseModel): conversation_history: list[dict[str, Any]] = Field(default_factory=list) -class PlanAction(BaseModel): - type: Literal[ - "create_record", - "update_record", - "delete_record", - "index_document", - "send_notification", - ] - table: str | None = None - data: dict[str, Any] | None = None - - class ChatRequest(BaseModel): message: str context: ChatContext = Field(default_factory=ChatContext) - execution_mode: Literal["direct", "plan"] = "direct" class ChatResponse(BaseModel): response: str - actions: list[PlanAction] = Field(default_factory=list) - - -# ── Execution Plans ────────────────────────────────────────────────── - -class PlanStep(BaseModel): - action: str - prompt_template: str | None = None - variables: dict[str, Any] | None = None - data_from_step: int | None = None - - -class ExecutionPlan(BaseModel): - agent: str - steps: list[PlanStep] = Field(default_factory=list) # ── Backup ─────────────────────────────────────────────────────────── diff --git a/requirements.txt b/requirements.txt index ea10f59..a3b88ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ gunicorn>=22.0.0 langchain>=0.3.0 langchain-openai>=0.3.0 langchain-litellm>=0.1.0 +langgraph>=0.3.0 litellm>=1.50.0 pydantic>=2.10.0 pydantic-settings>=2.7.0 diff --git a/tests/test_agent_registry.py b/tests/test_agent_registry.py deleted file mode 100644 index 9fd9381..0000000 --- a/tests/test_agent_registry.py +++ /dev/null @@ -1,214 +0,0 @@ -"""Unit tests for the agent registry, base classes, and tool loop.""" - -from __future__ import annotations - -from typing import Any -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from app.core.agent_registry import AgentRegistry, ChatAgent - - -# ── Helpers ────────────────────────────────────────────────────────── - -class _StubAgent(ChatAgent): - """Minimal concrete agent for testing.""" - - def get_name(self) -> str: - return "stub" - - def get_description(self) -> str: - return "A stub agent for tests" - - def get_tools(self) -> list[Any]: - return [] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - return f"echo: {query}" - - -class _AnotherAgent(ChatAgent): - def get_name(self) -> str: - return "another" - - def get_description(self) -> str: - return "Another stub" - - def get_tools(self) -> list[Any]: - return [] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - return "another" - - -# ── Fixtures ───────────────────────────────────────────────────────── - -@pytest.fixture(autouse=True) -def _fresh_registry(): - """Reset the singleton between tests.""" - AgentRegistry._instance = None - yield - AgentRegistry._instance = None - - -@pytest.fixture() -def reg() -> AgentRegistry: - return AgentRegistry() - - -# ── Tests ──────────────────────────────────────────────────────────── - -class TestRegisterAndGet: - def test_register_decorator(self, reg: AgentRegistry) -> None: - reg.register(_StubAgent) - agent = reg.get("stub") - assert isinstance(agent, _StubAgent) - - def test_get_unknown_raises(self, reg: AgentRegistry) -> None: - with pytest.raises(KeyError, match="not found"): - reg.get("nonexistent") - - def test_register_multiple(self, reg: AgentRegistry) -> None: - reg.register(_StubAgent) - reg.register(_AnotherAgent) - assert reg.get("stub").get_name() == "stub" - assert reg.get("another").get_name() == "another" - - -class TestListAgents: - def test_empty(self, reg: AgentRegistry) -> None: - assert reg.list_agents() == [] - - def test_list_after_register(self, reg: AgentRegistry) -> None: - reg.register(_StubAgent) - agents = reg.list_agents() - assert len(agents) == 1 - assert agents[0] == {"name": "stub", "description": "A stub agent for tests"} - - def test_list_multiple(self, reg: AgentRegistry) -> None: - reg.register(_StubAgent) - reg.register(_AnotherAgent) - names = {a["name"] for a in reg.list_agents()} - assert names == {"stub", "another"} - - -class TestCallAgent: - @pytest.mark.asyncio - async def test_call_agent(self, reg: AgentRegistry) -> None: - reg.register(_StubAgent) - result = await reg.call_agent("stub", "hello", {}) - assert result == "echo: hello" - - @pytest.mark.asyncio - async def test_call_unknown_raises(self, reg: AgentRegistry) -> None: - with pytest.raises(KeyError): - await reg.call_agent("nope", "hi", {}) - - -class TestSingleton: - def test_singleton_identity(self) -> None: - a = AgentRegistry() - b = AgentRegistry() - assert a is b - - -class TestToolLoop: - @pytest.mark.asyncio - async def test_no_tool_calls(self) -> None: - """When the LLM responds without tool calls, return content directly.""" - agent = _StubAgent() - - ai_msg = MagicMock() - ai_msg.content = "final answer" - ai_msg.tool_calls = [] - - llm = AsyncMock() - llm.bind_tools = MagicMock(return_value=llm) - llm.ainvoke = AsyncMock(return_value=ai_msg) - - result = await agent._tool_loop(llm, [], []) - assert result == "final answer" - - @pytest.mark.asyncio - async def test_tool_call_then_answer(self) -> None: - """LLM requests one tool call, gets result, then answers.""" - agent = _StubAgent() - - # First response: tool call - tool_call_msg = MagicMock() - tool_call_msg.content = "" - tool_call_msg.tool_calls = [ - {"id": "call_1", "name": "my_tool", "args": {"x": 1}} - ] - - # Second response: final answer - final_msg = MagicMock() - final_msg.content = "done" - final_msg.tool_calls = [] - - llm = AsyncMock() - llm.bind_tools = MagicMock(return_value=llm) - llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg]) - - # Mock tool - tool = AsyncMock() - tool.name = "my_tool" - tool.ainvoke = AsyncMock(return_value="tool_result") - - result = await agent._tool_loop(llm, [], [tool]) - assert result == "done" - tool.ainvoke.assert_called_once_with({"x": 1}) - - @pytest.mark.asyncio - async def test_unknown_tool_handled(self) -> None: - """Unknown tool names produce an error message instead of crashing.""" - agent = _StubAgent() - - tool_call_msg = MagicMock() - tool_call_msg.content = "" - tool_call_msg.tool_calls = [ - {"id": "call_1", "name": "missing", "args": {}} - ] - - final_msg = MagicMock() - final_msg.content = "recovered" - final_msg.tool_calls = [] - - llm = AsyncMock() - llm.bind_tools = MagicMock(return_value=llm) - llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg]) - - result = await agent._tool_loop(llm, [], []) - assert result == "recovered" - - @pytest.mark.asyncio - async def test_max_iter_reached(self) -> None: - """When max iterations are exhausted, a final no-tools call is made.""" - agent = _StubAgent() - - # Every response requests a tool call - loop_msg = MagicMock() - loop_msg.content = "" - loop_msg.tool_calls = [ - {"id": "call_x", "name": "t", "args": {}} - ] - - final_msg = MagicMock() - final_msg.content = "gave up" - final_msg.tool_calls = [] - - tool = AsyncMock() - tool.name = "t" - tool.ainvoke = AsyncMock(return_value="ok") - - llm_with_tools = AsyncMock() - llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg) - - llm = AsyncMock() - llm.bind_tools = MagicMock(return_value=llm_with_tools) - llm.ainvoke = AsyncMock(return_value=final_msg) - - result = await agent._tool_loop(llm, [], [tool], max_iter=2) - assert result == "gave up" - assert llm_with_tools.ainvoke.call_count == 2 diff --git a/tests/test_agent_streaming.py b/tests/test_agent_streaming.py deleted file mode 100644 index 59a8232..0000000 --- a/tests/test_agent_streaming.py +++ /dev/null @@ -1,416 +0,0 @@ -"""Tests for ChatAgent streaming and tool result capture (Step 2).""" - -from __future__ import annotations - -import pytest -from unittest.mock import AsyncMock, MagicMock, patch -from typing import Any - -from langchain_core.messages import AIMessage, HumanMessage, ToolMessage - -from app.core.agent_registry import ChatAgent, registry - - -# ── Minimal concrete agent for testing ─────────────────────────────── - - -class _EchoAgent(ChatAgent): - def get_name(self) -> str: - return "_echo" - - def get_description(self) -> str: - return "Echo agent for tests" - - def get_tools(self) -> list[Any]: - return [] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - return query - - -# ── Helpers ─────────────────────────────────────────────────────────── - - -def _make_ai_message(content: str = "", tool_calls: list | None = None) -> AIMessage: - msg = AIMessage(content=content) - if tool_calls: - msg.tool_calls = tool_calls - else: - msg.tool_calls = [] - return msg - - -def _make_tool(name: str, return_value: Any) -> MagicMock: - t = MagicMock() - t.name = name - t.ainvoke = AsyncMock(return_value=return_value) - return t - - -def _make_stream_chunks(tokens: list[str]) -> list[MagicMock]: - chunks = [] - for tok in tokens: - c = MagicMock() - c.content = tok - chunks.append(c) - return chunks - - -async def _collect_stream(agent: ChatAgent, llm: Any, messages: list, tools: list) -> list[str]: - tokens: list[str] = [] - async for tok in agent._tool_loop_stream(llm, messages, tools): - tokens.append(tok) - return tokens - - -# ── tool_results initialised ───────────────────────────────────────── - - -def test_tool_results_init(): - agent = _EchoAgent() - assert agent.tool_results == [] - - -# ── _tool_loop: no tool calls ──────────────────────────────────────── - - -@pytest.mark.asyncio -async def test_tool_loop_no_tools(): - agent = _EchoAgent() - llm = AsyncMock() - llm.ainvoke = AsyncMock(return_value=_make_ai_message("Hello!")) - - result = await agent._tool_loop(llm, [HumanMessage(content="hi")], []) - assert result == "Hello!" - assert agent.tool_results == [] - - -# ── _tool_loop: with one tool call + result capture ────────────────── - - -@pytest.mark.asyncio -async def test_tool_loop_captures_tool_results(): - agent = _EchoAgent() - - # Mock execute_on_client to return structured data via the tool - raw_result = {"rows": [{"id": "t-1", "title": "Fix bug", "status": "todo"}]} - - async def fake_executor(payload: dict) -> dict: - return raw_result - - # AIMessage with a tool call, then a final answer - tool_call_msg = _make_ai_message( - tool_calls=[{"name": "list_tasks", "args": {}, "id": "call-1", "type": "tool_call"}] - ) - final_msg = _make_ai_message("Here are your tasks.") - - llm = MagicMock() - llm_with_tools = MagicMock() - llm.bind_tools = MagicMock(return_value=llm_with_tools) - llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg]) - llm.ainvoke = AsyncMock(return_value=final_msg) - - mock_tool = _make_tool("list_tasks", "- Fix bug (todo)") - - from app.core.ws_context import set_client_executor, clear_client_executor - set_client_executor(fake_executor) - try: - # Patch the tool to actually call execute_on_client - async def tool_side_effect(args: dict) -> str: - from app.core.ws_context import execute_on_client - res = await execute_on_client(action="select", table="tasks") - rows = res.get("rows", []) - return "\n".join(r["title"] for r in rows) - - mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect) - - result = await agent._tool_loop( - llm, [HumanMessage(content="list my tasks")], [mock_tool] - ) - finally: - clear_client_executor() - - assert result == "Here are your tasks." - assert len(agent.tool_results) == 1 - assert agent.tool_results[0] == raw_result - - -# ── _tool_loop: tool_results reset on each call ────────────────────── - - -@pytest.mark.asyncio -async def test_tool_loop_resets_tool_results(): - agent = _EchoAgent() - agent.tool_results = [{"stale": True}] # pre-populated from a previous call - - llm = AsyncMock() - llm.ainvoke = AsyncMock(return_value=_make_ai_message("Done.")) - - await agent._tool_loop(llm, [HumanMessage(content="hi")], []) - assert agent.tool_results == [] - - -# ── _tool_loop: unknown tool name ──────────────────────────────────── - - -@pytest.mark.asyncio -async def test_tool_loop_unknown_tool(): - agent = _EchoAgent() - - # No known tools — model still calls a non-existent one; loop handles gracefully - tool_call_msg = _make_ai_message( - tool_calls=[{"name": "nonexistent", "args": {}, "id": "c1", "type": "tool_call"}] - ) - final_msg = _make_ai_message("Handled.") - - mock_tool = _make_tool("known", "ok") # a different tool, not "nonexistent" - llm = MagicMock() - llm_with_tools = MagicMock() - llm.bind_tools = MagicMock(return_value=llm_with_tools) - llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg]) - - result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool]) - assert result == "Handled." - - -# ── _tool_loop: max_iter exhaustion ────────────────────────────────── - - -@pytest.mark.asyncio -async def test_tool_loop_max_iter(): - agent = _EchoAgent() - - always_tool = _make_ai_message( - tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}] - ) - fallback = _make_ai_message("Fallback.") - - llm = MagicMock() - llm_with_tools = MagicMock() - llm.bind_tools = MagicMock(return_value=llm_with_tools) - # Returns tool_call_msg on every iteration - llm_with_tools.ainvoke = AsyncMock(return_value=always_tool) - llm.ainvoke = AsyncMock(return_value=fallback) - - mock_tool = _make_tool("t", "ok") - - result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool], max_iter=2) - assert result == "Fallback." - assert llm_with_tools.ainvoke.call_count == 2 - - -# ── _tool_loop_stream: no tool calls — yields tokens ───────────────── - - -@pytest.mark.asyncio -async def test_tool_loop_stream_no_tools_yields_tokens(): - agent = _EchoAgent() - - # No tools → llm used directly; ainvoke returns no tool calls → stream is used - no_tool_msg = _make_ai_message("irrelevant") - llm = AsyncMock() - llm.ainvoke = AsyncMock(return_value=no_tool_msg) - - async def fake_astream(msgs): - for tok in ["Hello", " ", "world"]: - c = MagicMock() - c.content = tok - yield c - - llm.astream = fake_astream - - tokens = await _collect_stream(agent, llm, [HumanMessage(content="hi")], []) - assert tokens == ["Hello", " ", "world"] - assert agent.tool_results == [] - - -# ── _tool_loop_stream: one tool call then streaming final ───────────── - - -@pytest.mark.asyncio -async def test_tool_loop_stream_with_tool_call(): - agent = _EchoAgent() - - raw_result = {"row": {"id": "t-2", "title": "Deploy", "status": "in_progress"}} - - async def fake_executor(payload: dict) -> dict: - return raw_result - - tool_call_msg = _make_ai_message( - tool_calls=[{"name": "get_task", "args": {"id": "t-2"}, "id": "c1", "type": "tool_call"}] - ) - # After tools run, ainvoke returns no more tool calls - no_more_tools_msg = _make_ai_message("Task found.") - - llm = MagicMock() - llm_with_tools = MagicMock() - llm.bind_tools = MagicMock(return_value=llm_with_tools) - llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg]) - - async def fake_astream(msgs): - for tok in ["Task", " ", "found."]: - c = MagicMock() - c.content = tok - yield c - - llm.astream = fake_astream - - async def tool_side_effect(args: dict) -> str: - from app.core.ws_context import execute_on_client - res = await execute_on_client(action="select", table="tasks", filters={"id": args.get("id")}) - return res.get("row", {}).get("title", "") - - mock_tool = _make_tool("get_task", "Deploy") - mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect) - - from app.core.ws_context import set_client_executor, clear_client_executor - set_client_executor(fake_executor) - try: - tokens = await _collect_stream( - agent, llm, [HumanMessage(content="get task t-2")], [mock_tool] - ) - finally: - clear_client_executor() - - assert tokens == ["Task", " ", "found."] - assert len(agent.tool_results) == 1 - assert agent.tool_results[0] == raw_result - - -# ── _tool_loop_stream: tool_results reset on each call ─────────────── - - -@pytest.mark.asyncio -async def test_tool_loop_stream_resets_tool_results(): - agent = _EchoAgent() - agent.tool_results = [{"old": True}] - - no_tool_msg = _make_ai_message("") - llm = AsyncMock() - llm.ainvoke = AsyncMock(return_value=no_tool_msg) - - async def fake_astream(msgs): - c = MagicMock() - c.content = "ok" - yield c - - llm.astream = fake_astream - - await _collect_stream(agent, llm, [HumanMessage(content="x")], []) - assert agent.tool_results == [] - - -# ── _tool_loop_stream: empty chunk content is skipped ──────────────── - - -@pytest.mark.asyncio -async def test_tool_loop_stream_skips_empty_chunks(): - agent = _EchoAgent() - no_tool_msg = _make_ai_message("") - - llm = AsyncMock() - llm.ainvoke = AsyncMock(return_value=no_tool_msg) - - async def fake_astream(msgs): - for tok in ["", "hello", "", " world", ""]: - c = MagicMock() - c.content = tok - yield c - - llm.astream = fake_astream - - tokens = await _collect_stream(agent, llm, [HumanMessage(content="x")], []) - assert tokens == ["hello", " world"] - - -# ── _tool_loop_stream: max_iter exhaustion falls back to stream ─────── - - -@pytest.mark.asyncio -async def test_tool_loop_stream_max_iter(): - agent = _EchoAgent() - - always_tool = _make_ai_message( - tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}] - ) - - llm = MagicMock() - llm_with_tools = MagicMock() - llm.bind_tools = MagicMock(return_value=llm_with_tools) - llm_with_tools.ainvoke = AsyncMock(return_value=always_tool) - - async def fake_astream(msgs): - c = MagicMock() - c.content = "fallback" - yield c - - llm.astream = fake_astream - mock_tool = _make_tool("t", "ok") - - tokens = await _collect_stream( - agent, llm, [HumanMessage(content="x")], [mock_tool], - ) - assert tokens == ["fallback"] - assert llm_with_tools.ainvoke.call_count == 5 # exhausted default max_iter - - -# ── _tool_loop_stream: multiple tool results captured ──────────────── - - -@pytest.mark.asyncio -async def test_tool_loop_stream_multiple_tool_results(): - agent = _EchoAgent() - - call_results = [ - {"rows": [{"id": "t-1"}]}, - {"rows": [{"id": "t-2"}]}, - ] - call_iter = iter(call_results) - - async def fake_executor(payload: dict) -> dict: - return next(call_iter) - - # Two tool calls in one iteration - tool_call_msg = _make_ai_message( - tool_calls=[ - {"name": "tool_a", "args": {}, "id": "c1", "type": "tool_call"}, - {"name": "tool_b", "args": {}, "id": "c2", "type": "tool_call"}, - ] - ) - no_more_tools_msg = _make_ai_message("Done.") - - llm = MagicMock() - llm_with_tools = MagicMock() - llm.bind_tools = MagicMock(return_value=llm_with_tools) - llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg]) - - async def fake_astream(msgs): - c = MagicMock() - c.content = "Done." - yield c - - llm.astream = fake_astream - - async def tool_side_effect(args: dict) -> str: - from app.core.ws_context import execute_on_client - res = await execute_on_client(action="select", table="tasks") - return str(res) - - tool_a = _make_tool("tool_a", "") - tool_a.ainvoke = AsyncMock(side_effect=tool_side_effect) - tool_b = _make_tool("tool_b", "") - tool_b.ainvoke = AsyncMock(side_effect=tool_side_effect) - - from app.core.ws_context import set_client_executor, clear_client_executor - set_client_executor(fake_executor) - try: - tokens = await _collect_stream( - agent, llm, [HumanMessage(content="x")], [tool_a, tool_b] - ) - finally: - clear_client_executor() - - assert tokens == ["Done."] - assert len(agent.tool_results) == 2 - assert agent.tool_results[0] == {"rows": [{"id": "t-1"}]} - assert agent.tool_results[1] == {"rows": [{"id": "t-2"}]} diff --git a/tests/test_agents.py b/tests/test_agents.py deleted file mode 100644 index 4023232..0000000 --- a/tests/test_agents.py +++ /dev/null @@ -1,761 +0,0 @@ -"""Unit tests for the four domain-specific chat agents with mocked LLM.""" - -from __future__ import annotations - -import json -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -import app.agents # noqa: F401 — triggers @registry.register decorators -from app.agents.timeline_agent import TimelineAgent -from app.agents.note_agent import NoteAgent -from app.agents.project_agent import ProjectAgent -from app.agents.task_agent import TaskAgent -from app.core.agent_registry import registry -from app.core.ws_context import clear_client_executor, set_client_executor - - -# ── WS executor mock ────────────────────────────────────────────────── -# -# Tools call execute_on_client() which reads a ContextVar set by the WS -# handler. In unit tests there is no WS session, so we install a fake -# executor that returns plausible data for each action type. - -_FAKE_ROW: dict[str, Any] = { - "id": "fake-id", - "title": "Fake Title", - "name": "Fake Name", - "status": "todo", - "priority": "medium", - "content": "Fake content", - "date": 1700000000000, - "taskId": "fake-task-id", - "author": "Alice", - "projectId": None, -} - - -async def _fake_executor(payload: dict) -> dict: - action = payload.get("action", "") - if action == "select": - return {"rows": []} - if action == "insert": - data = payload.get("data", {}) - return {"row": {**_FAKE_ROW, **data}} - if action == "update": - data = payload.get("data", {}) - row = {**_FAKE_ROW, "id": data.get("id", "fake-id"), **data.get("updates", {})} - return {"row": row} - if action == "delete": - return {"deleted": True} - if action == "get": - data = payload.get("data", {}) - return {"row": {**_FAKE_ROW, "id": data.get("id", "fake-id")}} - if action == "vector_upsert": - return {"ok": True} - return {} - - -@pytest.fixture(autouse=True) -def ws_executor(): - """Install a fake WS executor for every test so tools can run without a real WS.""" - set_client_executor(_fake_executor) - yield - clear_client_executor() - - -# ── Helpers ────────────────────────────────────────────────────────── - - -def _mock_llm(response_text: str) -> MagicMock: - """Return a mock LLM that responds with *response_text* (no tool calls).""" - msg = MagicMock() - msg.content = response_text - msg.tool_calls = [] - llm = MagicMock() - bound = MagicMock() - bound.ainvoke = AsyncMock(return_value=msg) - llm.bind_tools = MagicMock(return_value=bound) - llm.ainvoke = AsyncMock(return_value=msg) - return llm - - -def _mock_llm_with_tool_call( - tool_name: str, tool_args: dict[str, Any], final_text: str -) -> MagicMock: - """Mock LLM that fires one tool call then returns *final_text*.""" - tool_msg = MagicMock() - tool_msg.content = "" - tool_msg.tool_calls = [{"id": "call_1", "name": tool_name, "args": tool_args}] - - final_msg = MagicMock() - final_msg.content = final_text - final_msg.tool_calls = [] - - bound = MagicMock() - bound.ainvoke = AsyncMock(side_effect=[tool_msg, final_msg]) - - llm = MagicMock() - llm.bind_tools = MagicMock(return_value=bound) - llm.ainvoke = AsyncMock(return_value=final_msg) - return llm - - -# ── Registration ────────────────────────────────────────────────────── - - -class TestAgentRegistration: - def test_all_agents_registered(self) -> None: - names = {a["name"] for a in registry.list_agents()} - assert { - "task_agent", "timeline_agent", "project_agent", "note_agent" - }.issubset(names) - - def test_registry_returns_correct_types(self) -> None: - assert isinstance(registry.get("task_agent"), TaskAgent) - assert isinstance(registry.get("timeline_agent"), TimelineAgent) - assert isinstance(registry.get("project_agent"), ProjectAgent) - assert isinstance(registry.get("note_agent"), NoteAgent) - - def test_descriptions_present(self) -> None: - for agent_info in registry.list_agents(): - assert agent_info["description"], f"Empty description: {agent_info['name']}" - - -# ── TaskAgent ───────────────────────────────────────────────────────── - - -class TestTaskAgent: - def test_name(self) -> None: - assert TaskAgent().get_name() == "task_agent" - - def test_description(self) -> None: - assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments" - - def test_get_tools_count(self) -> None: - assert len(TaskAgent().get_tools()) == 8 - - def test_tool_names(self) -> None: - names = {t.name for t in TaskAgent().get_tools()} - assert names == { - "list_tasks", - "create_task", - "update_task", - "delete_task", - "list_tasks_due_today", - "list_task_comments", - "add_task_comment", - "delete_task_comment", - } - - @pytest.mark.asyncio - async def test_handle_returns_string(self) -> None: - with patch("app.agents.task_agent.get_llm") as mock_cls: - mock_cls.return_value = _mock_llm("Task created.") - result = await TaskAgent().handle("create a task", {}) - assert isinstance(result, str) - - @pytest.mark.asyncio - async def test_handle_no_tool_calls(self) -> None: - with patch("app.agents.task_agent.get_llm") as mock_cls: - mock_cls.return_value = _mock_llm("Here are your tasks.") - result = await TaskAgent().handle("list my tasks", {}) - assert result == "Here are your tasks." - - @pytest.mark.asyncio - async def test_handle_with_create_task_tool_call(self) -> None: - with patch("app.agents.task_agent.get_llm") as mock_cls: - mock_cls.return_value = _mock_llm_with_tool_call( - "create_task", - {"title": "Buy groceries", "priority": "low"}, - "Task 'Buy groceries' created.", - ) - result = await TaskAgent().handle("add a grocery task", {}) - assert result == "Task 'Buy groceries' created." - - @pytest.mark.asyncio - async def test_handle_accepts_empty_context(self) -> None: - with patch("app.agents.task_agent.get_llm") as mock_cls: - mock_cls.return_value = _mock_llm("Done.") - result = await TaskAgent().handle("help", {}) - assert isinstance(result, str) - - @pytest.mark.asyncio - async def test_handle_accepts_rich_context(self) -> None: - context = { - "user_profile": {"id": "u1", "tier": "pro"}, - "recent_tasks": [{"id": "t1", "title": "Old task"}], - } - with patch("app.agents.task_agent.get_llm") as mock_cls: - mock_cls.return_value = _mock_llm("Tasks listed.") - result = await TaskAgent().handle("show tasks", context) - assert isinstance(result, str) - - -class TestTaskAgentTools: - @pytest.mark.asyncio - async def test_list_tasks_defaults(self) -> None: - from app.agents.task_agent import list_tasks - with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"rows": []} - result = await list_tasks.ainvoke({}) - m.assert_called_once_with( - action="select", table="tasks", - filters={"projectId": None, "status": None, "search": None, "orderBy": None}, - ) - assert result == "No tasks found matching the given filters." - - @pytest.mark.asyncio - async def test_list_tasks_with_status_filter(self) -> None: - from app.agents.task_agent import list_tasks - with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"rows": []} - await list_tasks.ainvoke({"status": "done"}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["filters"]["status"] == "done" - - @pytest.mark.asyncio - async def test_create_task_defaults(self) -> None: - from app.agents.task_agent import create_task - fake_row = {"id": "t1", "title": "Test task", "status": "todo", "priority": "medium"} - with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - result = await create_task.ainvoke({"title": "Test task"}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "insert" - assert call_kwargs["table"] == "tasks" - assert call_kwargs["data"]["title"] == "Test task" - assert call_kwargs["data"]["status"] == "todo" - assert call_kwargs["data"]["priority"] == "medium" - assert "Test task" in result - - @pytest.mark.asyncio - async def test_create_task_with_all_fields(self) -> None: - from app.agents.task_agent import create_task - fake_row = {"id": "t1", "title": "Deploy", "status": "in_progress", "priority": "high"} - with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - await create_task.ainvoke({ - "title": "Deploy", "priority": "high", "status": "in_progress", - "project_id": "p1", "is_ai_suggested": 1, - }) - call_kwargs = m.call_args.kwargs - assert call_kwargs["data"]["priority"] == "high" - assert call_kwargs["data"]["status"] == "in_progress" - assert call_kwargs["data"]["projectId"] == "p1" - assert call_kwargs["data"]["isAiSuggested"] == 1 - - @pytest.mark.asyncio - async def test_update_task_with_status(self) -> None: - from app.agents.task_agent import update_task - fake_row = {"id": "t1", "title": "Buy groceries", "status": "done"} - with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - result = await update_task.ainvoke({"task_id": "t1", "status": "done"}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "update" - assert call_kwargs["data"]["id"] == "t1" - assert call_kwargs["data"]["updates"]["status"] == "done" - assert "t1" in result - - @pytest.mark.asyncio - async def test_update_task_empty_updates(self) -> None: - from app.agents.task_agent import update_task - fake_row = {"id": "t1", "title": "Task", "status": "todo"} - with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - await update_task.ainvoke({"task_id": "t1"}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["data"]["updates"] == {} - - @pytest.mark.asyncio - async def test_delete_task(self) -> None: - from app.agents.task_agent import delete_task - with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"deleted": True} - result = await delete_task.ainvoke({"task_id": "t1"}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "delete" - assert call_kwargs["table"] == "tasks" - assert call_kwargs["data"]["id"] == "t1" - assert "t1" in result - - @pytest.mark.asyncio - async def test_list_tasks_due_today(self) -> None: - from app.agents.task_agent import list_tasks_due_today - with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"rows": []} - result = await list_tasks_due_today.ainvoke({}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "select" - assert call_kwargs["table"] == "tasks" - assert "dueDateFrom" in call_kwargs["filters"] - assert result == "No tasks are due today." - - @pytest.mark.asyncio - async def test_list_task_comments(self) -> None: - from app.agents.task_agent import list_task_comments - with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"rows": []} - result = await list_task_comments.ainvoke({"task_id": "t1"}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "select" - assert call_kwargs["table"] == "taskComments" - assert call_kwargs["filters"]["taskId"] == "t1" - assert "t1" in result - - @pytest.mark.asyncio - async def test_add_task_comment(self) -> None: - from app.agents.task_agent import add_task_comment - fake_row = {"id": "c1", "taskId": "t1", "author": "Alice", "content": "Looks good!"} - with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - result = await add_task_comment.ainvoke({ - "task_id": "t1", "author": "Alice", "content": "Looks good!", - }) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "insert" - assert call_kwargs["table"] == "taskComments" - assert call_kwargs["data"]["taskId"] == "t1" - assert call_kwargs["data"]["author"] == "Alice" - assert call_kwargs["data"]["content"] == "Looks good!" - assert "Alice" in result - - @pytest.mark.asyncio - async def test_delete_task_comment(self) -> None: - from app.agents.task_agent import delete_task_comment - with patch("app.agents.task_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"deleted": True} - result = await delete_task_comment.ainvoke({"comment_id": "c1"}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "delete" - assert call_kwargs["table"] == "taskComments" - assert call_kwargs["data"]["id"] == "c1" - assert "c1" in result - - -# ── TimelineAgent ─────────────────────────────────────────────────── - - -class TestTimelineAgent: - def test_name(self) -> None: - assert TimelineAgent().get_name() == "timeline_agent" - - def test_description(self) -> None: - assert TimelineAgent().get_description() == "Manages project timelines (milestones): list, create, update, delete" - - def test_get_tools_count(self) -> None: - assert len(TimelineAgent().get_tools()) == 4 - - def test_tool_names(self) -> None: - names = {t.name for t in TimelineAgent().get_tools()} - assert names == {"list_timelines", "create_timeline", "update_timeline", "delete_timeline"} - - @pytest.mark.asyncio - async def test_handle_no_tool_calls(self) -> None: - with patch("app.agents.timeline_agent.get_llm") as mock_cls: - mock_cls.return_value = _mock_llm("No timelines found.") - result = await TimelineAgent().handle("list timelines", {}) - assert result == "No timelines found." - - @pytest.mark.asyncio - async def test_handle_with_create_tool_call(self) -> None: - with patch("app.agents.timeline_agent.get_llm") as mock_cls: - mock_cls.return_value = _mock_llm_with_tool_call( - "create_timeline", - {"project_id": "p1", "title": "MVP Launch", "date": 1700000000000}, - "Timeline 'MVP Launch' created.", - ) - result = await TimelineAgent().handle("add MVP timeline", {}) - assert result == "Timeline 'MVP Launch' created." - - @pytest.mark.asyncio - async def test_handle_accepts_empty_context(self) -> None: - with patch("app.agents.timeline_agent.get_llm") as mock_cls: - mock_cls.return_value = _mock_llm("Done.") - result = await TimelineAgent().handle("show milestones", {}) - assert isinstance(result, str) - - -class TestTimelineAgentTools: - @pytest.mark.asyncio - async def test_list_timelines_no_project(self) -> None: - from app.agents.timeline_agent import list_timelines - with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"rows": []} - result = await list_timelines.ainvoke({}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "select" - assert call_kwargs["table"] == "timelines" - assert call_kwargs["filters"]["projectId"] is None - assert result == "No timelines found." - - @pytest.mark.asyncio - async def test_list_timelines_with_project(self) -> None: - from app.agents.timeline_agent import list_timelines - with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"rows": []} - await list_timelines.ainvoke({"project_id": "p1"}) - assert m.call_args.kwargs["filters"]["projectId"] == "p1" - - @pytest.mark.asyncio - async def test_create_timeline(self) -> None: - from app.agents.timeline_agent import create_timeline - fake_row = {"id": "cp1", "title": "Beta release", "date": 1700000000000} - with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - result = await create_timeline.ainvoke({ - "project_id": "p1", "title": "Beta release", "date": 1700000000000, - }) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "insert" - assert call_kwargs["table"] == "timelines" - assert call_kwargs["data"]["projectId"] == "p1" - assert call_kwargs["data"]["title"] == "Beta release" - assert call_kwargs["data"]["date"] == 1700000000000 - assert "Beta release" in result - - @pytest.mark.asyncio - async def test_create_timeline_ai_suggested(self) -> None: - from app.agents.timeline_agent import create_timeline - fake_row = {"id": "cp1", "title": "Review", "date": 1700000000000} - with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - await create_timeline.ainvoke({ - "project_id": "p1", "title": "Review", "date": 1700000000000, "is_ai_suggested": 1, - }) - call_kwargs = m.call_args.kwargs - assert call_kwargs["data"]["isAiSuggested"] == 1 - assert call_kwargs["data"]["isApproved"] == 0 - - @pytest.mark.asyncio - async def test_update_timeline_approve(self) -> None: - from app.agents.timeline_agent import update_timeline - fake_row = {"id": "c1", "title": "MVP", "isApproved": 1} - with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - result = await update_timeline.ainvoke({"timeline_id": "c1", "is_approved": 1}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "update" - assert call_kwargs["data"]["id"] == "c1" - assert call_kwargs["data"]["updates"]["isApproved"] == 1 - assert "c1" in result - - @pytest.mark.asyncio - async def test_update_timeline_empty_updates(self) -> None: - from app.agents.timeline_agent import update_timeline - fake_row = {"id": "c1", "title": "MVP"} - with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - await update_timeline.ainvoke({"timeline_id": "c1"}) - assert m.call_args.kwargs["data"]["updates"] == {} - - @pytest.mark.asyncio - async def test_delete_timeline(self) -> None: - from app.agents.timeline_agent import delete_timeline - with patch("app.agents.timeline_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"deleted": True} - result = await delete_timeline.ainvoke({"timeline_id": "c1"}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "delete" - assert call_kwargs["table"] == "timelines" - assert call_kwargs["data"]["id"] == "c1" - assert "c1" in result - - -# ── ProjectAgent ────────────────────────────────────────────────────── - - -class TestProjectAgent: - def test_name(self) -> None: - assert ProjectAgent().get_name() == "project_agent" - - def test_description(self) -> None: - assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete" - - def test_get_tools_count(self) -> None: - assert len(ProjectAgent().get_tools()) == 6 - - def test_tool_names(self) -> None: - names = {t.name for t in ProjectAgent().get_tools()} - assert names == { - "list_projects", - "list_all_projects", - "get_project", - "create_project", - "update_project", - "delete_project", - } - - @pytest.mark.asyncio - async def test_handle_no_tool_calls(self) -> None: - with patch("app.agents.project_agent.get_llm") as mock_cls: - mock_cls.return_value = _mock_llm("Project Alpha is active.") - result = await ProjectAgent().handle("show my projects", {}) - assert result == "Project Alpha is active." - - @pytest.mark.asyncio - async def test_handle_with_create_project_tool_call(self) -> None: - with patch("app.agents.project_agent.get_llm") as mock_cls: - mock_cls.return_value = _mock_llm_with_tool_call( - "create_project", - {"name": "Pippo"}, - "Project 'Pippo' created.", - ) - result = await ProjectAgent().handle("create project Pippo", {}) - assert result == "Project 'Pippo' created." - - @pytest.mark.asyncio - async def test_handle_accepts_empty_context(self) -> None: - with patch("app.agents.project_agent.get_llm") as mock_cls: - mock_cls.return_value = _mock_llm("Done.") - result = await ProjectAgent().handle("archive old project", {}) - assert isinstance(result, str) - - -class TestProjectAgentTools: - @pytest.mark.asyncio - async def test_list_projects_defaults(self) -> None: - from app.agents.project_agent import list_projects - with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"rows": []} - result = await list_projects.ainvoke({}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "select" - assert call_kwargs["table"] == "projects" - assert call_kwargs["filters"]["includeArchived"] is False - assert result == "No projects found." - - @pytest.mark.asyncio - async def test_list_projects_include_archived(self) -> None: - from app.agents.project_agent import list_projects - with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"rows": []} - await list_projects.ainvoke({"include_archived": 1}) - assert m.call_args.kwargs["filters"]["includeArchived"] is True - - @pytest.mark.asyncio - async def test_list_all_projects(self) -> None: - from app.agents.project_agent import list_all_projects - with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"rows": []} - result = await list_all_projects.ainvoke({}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "select" - assert call_kwargs["table"] == "projects" - assert result == "No projects found." - - @pytest.mark.asyncio - async def test_get_project(self) -> None: - from app.agents.project_agent import get_project - fake_row = {"id": "p1", "name": "Alpha", "status": "active", "clientId": None} - with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - result = await get_project.ainvoke({"project_id": "p1"}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "get" - assert call_kwargs["table"] == "projects" - assert call_kwargs["data"]["id"] == "p1" - assert "Alpha" in result - - @pytest.mark.asyncio - async def test_create_project_name_only(self) -> None: - from app.agents.project_agent import create_project - fake_row = {"id": "p1", "name": "Alpha"} - with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - result = await create_project.ainvoke({"name": "Alpha"}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "insert" - assert call_kwargs["data"]["name"] == "Alpha" - assert call_kwargs["data"]["clientId"] is None - assert "Alpha" in result - - @pytest.mark.asyncio - async def test_create_project_with_client(self) -> None: - from app.agents.project_agent import create_project - fake_row = {"id": "p1", "name": "Beta"} - with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - await create_project.ainvoke({"name": "Beta", "client_id": "cl1"}) - assert m.call_args.kwargs["data"]["clientId"] == "cl1" - - @pytest.mark.asyncio - async def test_update_project_archive(self) -> None: - from app.agents.project_agent import update_project - fake_row = {"id": "p1", "name": "Alpha", "status": "archived"} - with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - result = await update_project.ainvoke({"project_id": "p1", "status": "archived"}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "update" - assert call_kwargs["data"]["id"] == "p1" - assert call_kwargs["data"]["updates"]["status"] == "archived" - assert "p1" in result - - @pytest.mark.asyncio - async def test_update_project_empty_updates(self) -> None: - from app.agents.project_agent import update_project - fake_row = {"id": "p1", "name": "Alpha", "status": "active"} - with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - await update_project.ainvoke({"project_id": "p1"}) - assert m.call_args.kwargs["data"]["updates"] == {} - - @pytest.mark.asyncio - async def test_delete_project(self) -> None: - from app.agents.project_agent import delete_project - with patch("app.agents.project_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"deleted": True} - result = await delete_project.ainvoke({"project_id": "p1"}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "delete" - assert call_kwargs["data"]["id"] == "p1" - assert "p1" in result - - -# ── NoteAgent ───────────────────────────────────────────────────────── - - -class TestNoteAgent: - def test_name(self) -> None: - assert NoteAgent().get_name() == "note_agent" - - def test_description(self) -> None: - assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete" - - def test_get_tools_count(self) -> None: - assert len(NoteAgent().get_tools()) == 5 - - def test_tool_names(self) -> None: - names = {t.name for t in NoteAgent().get_tools()} - assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"} - - @pytest.mark.asyncio - async def test_handle_no_tool_calls(self) -> None: - with patch("app.agents.note_agent.get_llm") as mock_cls: - mock_cls.return_value = _mock_llm("Note created.") - result = await NoteAgent().handle("create a note", {}) - assert result == "Note created." - - @pytest.mark.asyncio - async def test_handle_with_create_note_tool_call(self) -> None: - with patch("app.agents.note_agent.get_llm") as mock_cls: - mock_cls.return_value = _mock_llm_with_tool_call( - "create_note", - {"title": "Daily log", "content": "# Today\nAll good."}, - "Note 'Daily log' created.", - ) - result = await NoteAgent().handle("log today's progress", {}) - assert result == "Note 'Daily log' created." - - @pytest.mark.asyncio - async def test_handle_accepts_empty_context(self) -> None: - with patch("app.agents.note_agent.get_llm") as mock_cls: - mock_cls.return_value = _mock_llm("Done.") - result = await NoteAgent().handle("show notes", {}) - assert isinstance(result, str) - - -class TestNoteAgentTools: - @pytest.mark.asyncio - async def test_list_notes_no_project(self) -> None: - from app.agents.note_agent import list_notes - with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"rows": []} - result = await list_notes.ainvoke({}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "select" - assert call_kwargs["table"] == "notes" - assert call_kwargs["filters"]["projectId"] is None - assert result == "No notes found." - - @pytest.mark.asyncio - async def test_list_notes_with_project(self) -> None: - from app.agents.note_agent import list_notes - with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"rows": []} - await list_notes.ainvoke({"project_id": "p1"}) - assert m.call_args.kwargs["filters"]["projectId"] == "p1" - - @pytest.mark.asyncio - async def test_get_note(self) -> None: - from app.agents.note_agent import get_note - fake_row = {"id": "n1", "title": "Daily log", "content": "# Today\nAll good."} - with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - result = await get_note.ainvoke({"note_id": "n1"}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "get" - assert call_kwargs["table"] == "notes" - assert call_kwargs["data"]["id"] == "n1" - assert "Daily log" in result - - @pytest.mark.asyncio - async def test_create_note_minimal(self) -> None: - from app.agents.note_agent import create_note - fake_row = {"id": "n1", "title": "Daily log", "projectId": None} - with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \ - patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me: - m.return_value = {"row": fake_row} - me.return_value = [0.0] * 1536 - result = await create_note.ainvoke({"title": "Daily log", "content": "# Today\nAll good."}) - # First call: insert; second call: vector_upsert - first_call = m.call_args_list[0].kwargs - assert first_call["action"] == "insert" - assert first_call["table"] == "notes" - assert first_call["data"]["title"] == "Daily log" - assert first_call["data"]["content"] == "# Today\nAll good." - assert first_call["data"]["projectId"] is None - assert "Daily log" in result - - @pytest.mark.asyncio - async def test_create_note_with_project(self) -> None: - from app.agents.note_agent import create_note - fake_row = {"id": "n1", "title": "Sprint notes", "projectId": "p1"} - with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \ - patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me: - m.return_value = {"row": fake_row} - me.return_value = [0.0] * 1536 - await create_note.ainvoke({"title": "Sprint notes", "content": "## Sprint 1", "project_id": "p1"}) - first_call = m.call_args_list[0].kwargs - assert first_call["data"]["projectId"] == "p1" - - @pytest.mark.asyncio - async def test_update_note_content_only(self) -> None: - from app.agents.note_agent import update_note - fake_row = {"id": "n1", "title": "Daily log", "projectId": None} - with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \ - patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me: - m.return_value = {"row": fake_row} - me.return_value = [0.0] * 1536 - result = await update_note.ainvoke({"note_id": "n1", "content": "# Updated content"}) - first_call = m.call_args_list[0].kwargs - assert first_call["action"] == "update" - assert first_call["data"]["id"] == "n1" - assert first_call["data"]["updates"]["content"] == "# Updated content" - assert "title" not in first_call["data"]["updates"] - assert "n1" in result - - @pytest.mark.asyncio - async def test_update_note_empty_updates(self) -> None: - from app.agents.note_agent import update_note - fake_row = {"id": "n1", "title": "Daily log", "projectId": None} - with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"row": fake_row} - await update_note.ainvoke({"note_id": "n1"}) - assert m.call_args.kwargs["data"]["updates"] == {} - - @pytest.mark.asyncio - async def test_delete_note(self) -> None: - from app.agents.note_agent import delete_note - with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m: - m.return_value = {"deleted": True} - result = await delete_note.ainvoke({"note_id": "n1"}) - call_kwargs = m.call_args.kwargs - assert call_kwargs["action"] == "delete" - assert call_kwargs["table"] == "notes" - assert call_kwargs["data"]["id"] == "n1" - assert "n1" in result diff --git a/tests/test_execution_plan.py b/tests/test_execution_plan.py deleted file mode 100644 index 06a2bfa..0000000 --- a/tests/test_execution_plan.py +++ /dev/null @@ -1,286 +0,0 @@ -"""Tests for execution_plan: PromptTemplateRegistry, ExecutionPlanBuilder, PlanCache.""" - -from __future__ import annotations - -import pytest - -from app.core.execution_plan import ( - ExecutionPlanBuilder, - PlanCache, - PromptTemplateRegistry, - plan_cache, - template_registry, -) -from app.schemas import ExecutionPlan - - -# ── PromptTemplateRegistry ──────────────────────────────────────────── - - -class TestPromptTemplateRegistry: - def test_register_and_get(self) -> None: - reg = PromptTemplateRegistry() - reg.register("tpl_foo", "You are a foo agent.") - assert reg.get("tpl_foo") == "You are a foo agent." - - def test_get_unknown_raises_key_error(self) -> None: - reg = PromptTemplateRegistry() - with pytest.raises(KeyError, match="tpl_missing"): - reg.get("tpl_missing") - - def test_has_returns_true_for_registered(self) -> None: - reg = PromptTemplateRegistry() - reg.register("tpl_x", "prompt text") - assert reg.has("tpl_x") is True - - def test_has_returns_false_for_unregistered(self) -> None: - reg = PromptTemplateRegistry() - assert reg.has("tpl_missing") is False - - def test_list_ids_returns_all_registered_ids(self) -> None: - reg = PromptTemplateRegistry() - reg.register("tpl_a", "a") - reg.register("tpl_b", "b") - assert set(reg.list_ids()) == {"tpl_a", "tpl_b"} - - def test_list_ids_does_not_return_prompt_text(self) -> None: - reg = PromptTemplateRegistry() - reg.register("tpl_secret", "top secret prompt") - ids = reg.list_ids() - assert "top secret prompt" not in ids - - def test_overwrite_existing_template(self) -> None: - reg = PromptTemplateRegistry() - reg.register("tpl_x", "v1") - reg.register("tpl_x", "v2") - assert reg.get("tpl_x") == "v2" - - def test_empty_registry_has_no_ids(self) -> None: - reg = PromptTemplateRegistry() - assert reg.list_ids() == [] - - -# ── ExecutionPlanBuilder ────────────────────────────────────────────── - - -class TestExecutionPlanBuilder: - def test_builds_empty_plan(self) -> None: - plan = ExecutionPlanBuilder("task_agent").build() - assert plan.agent == "task_agent" - assert plan.steps == [] - - def test_add_step_basic(self) -> None: - plan = ( - ExecutionPlanBuilder("task_agent") - .add_step("create_task", {"priority": "high"}) - .build() - ) - assert len(plan.steps) == 1 - assert plan.steps[0].action == "create_task" - assert plan.steps[0].variables == {"priority": "high"} - assert plan.steps[0].prompt_template is None - assert plan.steps[0].data_from_step is None - - def test_add_step_no_params(self) -> None: - plan = ExecutionPlanBuilder("task_agent").add_step("fetch").build() - assert plan.steps[0].variables is None - - def test_add_llm_step(self) -> None: - plan = ( - ExecutionPlanBuilder("task_agent") - .add_llm_step("tpl_task_default", {"message": "hi"}) - .build() - ) - assert plan.steps[0].action == "llm" - assert plan.steps[0].prompt_template == "tpl_task_default" - assert plan.steps[0].variables == {"message": "hi"} - - def test_add_llm_step_no_variables(self) -> None: - plan = ExecutionPlanBuilder("task_agent").add_llm_step("tpl_x").build() - assert plan.steps[0].variables is None - - def test_add_data_step(self) -> None: - plan = ( - ExecutionPlanBuilder("task_agent") - .add_step("fetch_data") - .add_data_step("transform", data_from_step=0) - .build() - ) - assert plan.steps[1].action == "transform" - assert plan.steps[1].data_from_step == 0 - - def test_fluent_chaining_returns_builder(self) -> None: - builder = ExecutionPlanBuilder("analytics_agent") - result = builder.add_step("a") - assert result is builder - - def test_fluent_chain_multiple_steps(self) -> None: - plan = ( - ExecutionPlanBuilder("analytics_agent") - .add_llm_step("tpl_analytics_default") - .add_step("format_output") - .add_data_step("store", data_from_step=0) - .build() - ) - assert len(plan.steps) == 3 - - def test_build_validates_data_from_step_out_of_range(self) -> None: - with pytest.raises(ValueError, match="data_from_step"): - ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=5).build() - - def test_build_validates_data_from_step_self_reference(self) -> None: - """data_from_step=0 on the first step (index 0) is invalid.""" - with pytest.raises(ValueError, match="data_from_step"): - ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=0).build() - - def test_build_validates_data_from_step_negative(self) -> None: - with pytest.raises(ValueError, match="data_from_step"): - ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=-1).build() - - def test_valid_data_from_step_at_index_two(self) -> None: - plan = ( - ExecutionPlanBuilder("task_agent") - .add_step("step0") - .add_step("step1") - .add_data_step("step2", data_from_step=1) - .build() - ) - assert plan.steps[2].data_from_step == 1 - - def test_data_from_step_zero_valid_at_index_one(self) -> None: - plan = ( - ExecutionPlanBuilder("task_agent") - .add_step("step0") - .add_data_step("step1", data_from_step=0) - .build() - ) - assert plan.steps[1].data_from_step == 0 - - def test_build_returns_new_plan_each_call(self) -> None: - builder = ExecutionPlanBuilder("task_agent").add_step("do_thing") - plan1 = builder.build() - plan2 = builder.build() - assert plan1 is not plan2 - assert plan1.steps == plan2.steps - - def test_plan_is_execution_plan_instance(self) -> None: - plan = ExecutionPlanBuilder("task_agent").build() - assert isinstance(plan, ExecutionPlan) - - -# ── PlanCache ───────────────────────────────────────────────────────── - - -class TestPlanCache: - def _plan(self, agent: str = "a") -> ExecutionPlan: - return ExecutionPlanBuilder(agent).build() - - def test_cache_and_get(self) -> None: - cache = PlanCache() - plan = self._plan() - cache.cache_plan("key1", plan) - assert cache.get_plan("key1") is plan - - def test_get_missing_returns_none(self) -> None: - cache = PlanCache() - assert cache.get_plan("nonexistent") is None - - def test_get_all_playbooks_empty(self) -> None: - cache = PlanCache() - assert cache.get_all_playbooks() == [] - - def test_get_all_playbooks_returns_all_stored(self) -> None: - cache = PlanCache() - p1, p2 = self._plan("a"), self._plan("b") - cache.cache_plan("k1", p1) - cache.cache_plan("k2", p2) - playbooks = cache.get_all_playbooks() - assert len(playbooks) == 2 - assert p1 in playbooks - assert p2 in playbooks - - def test_lru_evicts_oldest_entry(self) -> None: - cache = PlanCache(maxsize=2) - p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c") - cache.cache_plan("k1", p1) - cache.cache_plan("k2", p2) - cache.cache_plan("k3", p3) # k1 should be evicted - assert cache.get_plan("k1") is None - assert cache.get_plan("k2") is p2 - assert cache.get_plan("k3") is p3 - - def test_lru_access_updates_recency(self) -> None: - cache = PlanCache(maxsize=2) - p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c") - cache.cache_plan("k1", p1) - cache.cache_plan("k2", p2) - cache.get_plan("k1") # k1 is now most-recently used - cache.cache_plan("k3", p3) # k2 should be evicted (LRU) - assert cache.get_plan("k1") is p1 - assert cache.get_plan("k2") is None - assert cache.get_plan("k3") is p3 - - def test_overwrite_existing_key(self) -> None: - cache = PlanCache() - p1, p2 = self._plan("a"), self._plan("b") - cache.cache_plan("same_key", p1) - cache.cache_plan("same_key", p2) - assert cache.get_plan("same_key") is p2 - assert len(cache.get_all_playbooks()) == 1 - - def test_overwrite_does_not_consume_capacity(self) -> None: - cache = PlanCache(maxsize=2) - p1, p2 = self._plan("a"), self._plan("b") - cache.cache_plan("k1", p1) - cache.cache_plan("k1", p2) # overwrite, not a new slot - cache.cache_plan("k2", p1) # should fit without eviction - assert cache.get_plan("k1") is p2 - assert cache.get_plan("k2") is p1 - - -# ── Module-level singletons ─────────────────────────────────────────── - - -class TestModuleSingletons: - def test_template_registry_has_all_agent_defaults(self) -> None: - for agent in ("task_agent", "timeline_agent", "project_agent", "note_agent"): - assert template_registry.has(f"tpl_{agent}_default"), ( - f"Missing template: tpl_{agent}_default" - ) - - def test_template_registry_has_operation_templates(self) -> None: - assert template_registry.has("tpl_task_extract_from_project") - assert template_registry.has("tpl_note_weekly_summary") - - def test_template_registry_get_returns_non_empty_string(self) -> None: - text = template_registry.get("tpl_task_agent_default") - assert isinstance(text, str) - assert len(text) > 0 - - def test_plan_cache_has_prebuilt_playbooks(self) -> None: - assert len(plan_cache.get_all_playbooks()) >= 2 - - def test_playbook_create_tasks_from_project(self) -> None: - plan = plan_cache.get_plan("create_tasks_from_project") - assert plan is not None - assert plan.agent == "project_agent" - assert len(plan.steps) == 2 - assert plan.steps[0].prompt_template == "tpl_task_extract_from_project" - assert plan.steps[1].data_from_step == 0 - - def test_playbook_generate_weekly_note(self) -> None: - plan = plan_cache.get_plan("generate_weekly_note") - assert plan is not None - assert plan.agent == "note_agent" - assert len(plan.steps) == 2 - assert plan.steps[0].prompt_template == "tpl_note_weekly_summary" - assert plan.steps[1].data_from_step == 0 - - def test_playbook_steps_have_no_raw_prompt_text(self) -> None: - """Plans must not embed prompt text — only template IDs.""" - for plan in plan_cache.get_all_playbooks(): - for step in plan.steps: - if step.prompt_template is not None: - assert step.prompt_template.startswith("tpl_"), ( - f"prompt_template looks like raw text: {step.prompt_template!r}" - ) diff --git a/tests/test_memory_middleware.py b/tests/test_memory_middleware.py index ea5f558..0e5ff09 100644 --- a/tests/test_memory_middleware.py +++ b/tests/test_memory_middleware.py @@ -250,15 +250,15 @@ def test_home_request_calls_memory_middleware(client): token = make_jwt("power", user_id=USER_ID) session_id = str(uuid.uuid4()) - async def _mock_stream(user_id, message, context, reg=None): + async def _mock_stream(user_id, message, context, db_session_factory=None): # Verify memory context was injected assert context.get("core_memory") == {"tz": "UTC"} - yield "task_agent", "" - yield "task_agent", '{"type": "text", "content": "Done"}' + yield ("token", "Done") + yield ("mutations", []) with ( patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware), - patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_stream), + patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_stream), ): with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: ws.send_text(json.dumps({ diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 8721bbc..576a145 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -20,7 +20,6 @@ from jose import jwt from app.config.settings import settings from app.db import get_session from app.main import app -from app.schemas import ChatResponse from tests.conftest import TEST_USER_IDS # --------------------------------------------------------------------------- @@ -50,7 +49,6 @@ _CHAT_BODY = { "recent_tasks": [], "conversation_history": [], }, - "execution_mode": "direct", } @@ -240,7 +238,7 @@ class TestRateLimitMiddleware: class TestSanitizerMiddleware: - """Mock ``orchestrate`` to inject controlled strings into chat responses.""" + """Mock ``run_home`` to inject controlled strings into chat responses.""" _CHAT_PATH = "/api/v1/chat" @@ -248,11 +246,10 @@ class TestSanitizerMiddleware: return _make_jwt(user_id=str(uuid.uuid4()), tier="pro") def _post_chat(self, client: TestClient, response_text: str) -> dict: - mock_response = ChatResponse(response=response_text, actions=[]) with patch( - "app.api.routes.chat.orchestrate", + "app.api.routes.chat.run_home", new_callable=AsyncMock, - return_value=mock_response, + return_value=response_text, ): resp = client.post( self._CHAT_PATH, diff --git a/tests/test_orchestrator.py b/tests/test_orchestrator.py deleted file mode 100644 index 07576d4..0000000 --- a/tests/test_orchestrator.py +++ /dev/null @@ -1,347 +0,0 @@ -"""Integration tests for the orchestrator module.""" - -from __future__ import annotations - -import json -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from app.core.agent_registry import AgentRegistry, ChatAgent -from app.core.orchestrator import ( - classify_intent, - orchestrate, - orchestrate_stream, - route_pipeline, - route_single, -) -from app.schemas import ChatRequest, ChatResponse, ExecutionPlan - - -# ── Stub agents ────────────────────────────────────────────────────── - - -class _TaskAgent(ChatAgent): - def get_name(self) -> str: - return "task_agent" - - def get_description(self) -> str: - return "Manages tasks: create, update, list, suggest" - - def get_tools(self) -> list[Any]: - return [] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - return f"task: {query}" - - -class _CalendarAgent(ChatAgent): - def get_name(self) -> str: - return "calendar_agent" - - def get_description(self) -> str: - return "Calendar management: events, conflicts, scheduling" - - def get_tools(self) -> list[Any]: - return [] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - return f"calendar: {query}" - - -# ── Helpers ────────────────────────────────────────────────────────── - - -def _mock_llm(response_text: str) -> MagicMock: - """Return a mock LLM that always produces *response_text*.""" - msg = MagicMock() - msg.content = response_text - llm = MagicMock() - llm.ainvoke = AsyncMock(return_value=msg) - return llm - - -# ── Fixtures ───────────────────────────────────────────────────────── - - -@pytest.fixture(autouse=True) -def _fresh_registry(): - """Reset the AgentRegistry singleton between tests.""" - AgentRegistry._instance = None - yield - AgentRegistry._instance = None - - -@pytest.fixture() -def reg() -> AgentRegistry: - r = AgentRegistry() - r.register(_TaskAgent) - r.register(_CalendarAgent) - return r - - -# ── classify_intent ─────────────────────────────────────────────────── - - -class TestClassifyIntent: - @pytest.mark.asyncio - async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("task_agent") - result = await classify_intent("add a task", {}, reg) - assert result == "task_agent" - - @pytest.mark.asyncio - async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("calendar_agent") - result = await classify_intent("schedule a meeting", {}, reg) - assert result == "calendar_agent" - - @pytest.mark.asyncio - async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("nonexistent_agent") - result = await classify_intent("do something", {}, reg) - assert result == "task_agent" - - @pytest.mark.asyncio - async def test_empty_registry_returns_fallback_without_llm_call(self) -> None: - empty_reg = AgentRegistry() - # No LLM should be instantiated — early return path - with patch("app.core.orchestrator._make_llm") as mock_cls: - result = await classify_intent("anything", {}, empty_reg) - mock_cls.assert_not_called() - assert result == "task_agent" - - @pytest.mark.asyncio - async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm(" task_agent \n") - result = await classify_intent("create task", {}, reg) - assert result == "task_agent" - - -# ── route_single ───────────────────────────────────────────────────── - - -class TestRouteSingle: - @pytest.mark.asyncio - async def test_returns_chat_response(self, reg: AgentRegistry) -> None: - result = await route_single("task_agent", "create a task", {}, reg) - assert isinstance(result, ChatResponse) - - @pytest.mark.asyncio - async def test_response_contains_agent_output(self, reg: AgentRegistry) -> None: - result = await route_single("task_agent", "create a task", {}, reg) - assert result.response == "task: create a task" - - @pytest.mark.asyncio - async def test_unknown_agent_raises_key_error(self, reg: AgentRegistry) -> None: - with pytest.raises(KeyError): - await route_single("nonexistent", "hello", {}, reg) - - @pytest.mark.asyncio - async def test_actions_default_empty(self, reg: AgentRegistry) -> None: - result = await route_single("task_agent", "hi", {}, reg) - assert result.actions == [] - - -# ── route_pipeline ──────────────────────────────────────────────────── - - -class TestRoutePipeline: - @pytest.mark.asyncio - async def test_returns_chat_response(self, reg: AgentRegistry) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("synthesized result") - result = await route_pipeline( - ["task_agent", "calendar_agent"], "plan my week", {}, reg - ) - assert isinstance(result, ChatResponse) - - @pytest.mark.asyncio - async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("synthesized result") - result = await route_pipeline( - ["task_agent", "calendar_agent"], "plan my week", {}, reg - ) - assert result.response == "synthesized result" - - @pytest.mark.asyncio - async def test_passes_previous_results_to_subsequent_agents( - self, reg: AgentRegistry - ) -> None: - """Each agent after the first should receive prior outputs in context.""" - received_contexts: list[dict[str, Any]] = [] - - class _CapturingAgent(ChatAgent): - def get_name(self) -> str: - return "capture" - - def get_description(self) -> str: - return "captures context for testing" - - def get_tools(self) -> list[Any]: - return [] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - received_contexts.append(dict(context)) - return "captured" - - reg.register(_CapturingAgent) - - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("done") - await route_pipeline(["task_agent", "capture"], "hi", {}, reg) - - # The second agent (capture) must have received previous results - assert len(received_contexts) == 1 - assert "previous_results" in received_contexts[0] - assert received_contexts[0]["previous_results"] == ["task: hi"] - - @pytest.mark.asyncio - async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("single result") - result = await route_pipeline(["task_agent"], "one agent", {}, reg) - assert result.response == "single result" - - -# ── orchestrate ─────────────────────────────────────────────────────── - - -class TestOrchestrate: - @pytest.mark.asyncio - async def test_direct_mode_returns_chat_response( - self, reg: AgentRegistry - ) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("task_agent") - request = ChatRequest(message="add a task", execution_mode="direct") - result = await orchestrate(request, reg) - assert isinstance(result, ChatResponse) - - @pytest.mark.asyncio - async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("task_agent") - request = ChatRequest(message="add a task", execution_mode="direct") - result = await orchestrate(request, reg) - assert isinstance(result, ChatResponse) - assert result.response == "task: add a task" - - @pytest.mark.asyncio - async def test_plan_mode_returns_execution_plan( - self, reg: AgentRegistry - ) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("task_agent") - request = ChatRequest(message="plan my tasks", execution_mode="plan") - result = await orchestrate(request, reg) - assert isinstance(result, ExecutionPlan) - - @pytest.mark.asyncio - async def test_plan_mode_agent_matches_classified( - self, reg: AgentRegistry - ) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("calendar_agent") - request = ChatRequest( - message="schedule something", execution_mode="plan" - ) - result = await orchestrate(request, reg) - assert isinstance(result, ExecutionPlan) - assert result.agent == "calendar_agent" - - @pytest.mark.asyncio - async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("task_agent") - request = ChatRequest(message="plan tasks", execution_mode="plan") - result = await orchestrate(request, reg) - assert isinstance(result, ExecutionPlan) - assert len(result.steps) >= 1 - - @pytest.mark.asyncio - async def test_plan_mode_template_id_contains_agent_name( - self, reg: AgentRegistry - ) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("task_agent") - request = ChatRequest(message="plan tasks", execution_mode="plan") - result = await orchestrate(request, reg) - assert isinstance(result, ExecutionPlan) - assert result.steps[0].prompt_template is not None - assert "task_agent" in result.steps[0].prompt_template - - @pytest.mark.asyncio - async def test_default_execution_mode_is_direct( - self, reg: AgentRegistry - ) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("task_agent") - # execution_mode defaults to "direct" - request = ChatRequest(message="help me") - result = await orchestrate(request, reg) - assert isinstance(result, ChatResponse) - - -# ── orchestrate_stream ──────────────────────────────────────────────── - - -class TestOrchestrateStream: - @pytest.mark.asyncio - async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("task_agent") - request = ChatRequest(message="add a task", execution_mode="direct") - chunks = [chunk async for chunk in orchestrate_stream(request, reg)] - assert len(chunks) >= 1 - - @pytest.mark.asyncio - async def test_all_chunks_are_plain_text( - self, reg: AgentRegistry - ) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("task_agent") - request = ChatRequest(message="add a task", execution_mode="direct") - chunks = [chunk async for chunk in orchestrate_stream(request, reg)] - - # orchestrate_stream yields plain text chunks only — no JSON final frame - for chunk in chunks: - assert isinstance(chunk, str) - - @pytest.mark.asyncio - async def test_concatenated_chunks_equal_full_response( - self, reg: AgentRegistry - ) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("task_agent") - request = ChatRequest(message="create a task", execution_mode="direct") - chunks = [chunk async for chunk in orchestrate_stream(request, reg)] - - full_text = "".join(chunks) - assert full_text == "task: create a task" - - @pytest.mark.asyncio - async def test_text_chunks_before_final_frame( - self, reg: AgentRegistry - ) -> None: - with patch("app.core.orchestrator._make_llm") as mock_cls: - mock_cls.return_value = _mock_llm("task_agent") - request = ChatRequest( - message="x" * 200, execution_mode="direct" - ) # long enough to produce multiple chunks - chunks = [chunk async for chunk in orchestrate_stream(request, reg)] - - # All but the last chunk should be plain text (not valid final JSON) - non_final = chunks[:-1] - for chunk in non_final: - try: - parsed = json.loads(chunk) - assert parsed.get("done") is not True - except json.JSONDecodeError: - pass # plain text chunk — expected diff --git a/tests/test_orchestrator_v3.py b/tests/test_orchestrator_v3.py deleted file mode 100644 index fccb8ab..0000000 --- a/tests/test_orchestrator_v3.py +++ /dev/null @@ -1,236 +0,0 @@ -"""Tests for v3 orchestrator functions (Step 3).""" - -from __future__ import annotations - -import pytest -from unittest.mock import AsyncMock, MagicMock, patch -from typing import Any - -from app.core.agent_registry import ChatAgent, AgentRegistry -from app.core.orchestrator import orchestrate_v3, orchestrate_v3_stream - - -# ── Minimal agent for testing ───────────────────────────────────────── - - -class _FixedAgent(ChatAgent): - def __init__(self, name: str = "_fixed", tokens: list[str] | None = None, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._name = name - self._tokens = tokens or ["Hello", " world"] - - def get_name(self) -> str: - return self._name - - def get_description(self) -> str: - return "Fixed agent for tests" - - def get_tools(self) -> list[Any]: - return [] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - return "".join(self._tokens) - - async def handle_stream(self, query: str, context: dict[str, Any]): - for tok in self._tokens: - yield tok - - -# ── Mock registry factory ───────────────────────────────────────────── - - -def _make_registry(agent_name: str, agent: ChatAgent) -> MagicMock: - reg = MagicMock(spec=AgentRegistry) - reg.list_agents.return_value = [{"name": agent_name, "description": "test"}] - reg.get.return_value = agent - return reg - - -# ── orchestrate_v3 ──────────────────────────────────────────────────── - - -@pytest.mark.asyncio -async def test_orchestrate_v3_returns_agent_name_and_instance(): - agent = _FixedAgent("task_agent") - reg = _make_registry("task_agent", agent) - - with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")): - name, inst = await orchestrate_v3( - user_id="u-1", message="fix a bug", context={}, reg=reg - ) - - assert name == "task_agent" - assert inst is agent - - -@pytest.mark.asyncio -async def test_orchestrate_v3_classify_called_with_message_and_context(): - agent = _FixedAgent("note_agent") - reg = _make_registry("note_agent", agent) - ctx = {"some": "context"} - - with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")) as mock_classify: - await orchestrate_v3(user_id="u-1", message="take a note", context=ctx, reg=reg) - - mock_classify.assert_awaited_once() - call_args = mock_classify.call_args - assert call_args[0][0] == "take a note" - assert call_args[0][1] == ctx - - -@pytest.mark.asyncio -async def test_orchestrate_v3_uses_default_registry_when_none(): - agent = _FixedAgent("task_agent") - - with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \ - patch("app.core.orchestrator._default_registry") as mock_reg: - mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}] - mock_reg.get.return_value = agent - name, inst = await orchestrate_v3(user_id="u-1", message="hi", context={}) - - assert name == "task_agent" - assert inst is agent - - -@pytest.mark.asyncio -async def test_orchestrate_v3_get_called_with_agent_name(): - agent = _FixedAgent("timeline_agent") - reg = _make_registry("timeline_agent", agent) - - with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="timeline_agent")): - await orchestrate_v3(user_id="u-2", message="schedule", context={}, reg=reg) - - reg.get.assert_called_once_with("timeline_agent") - - -# ── orchestrate_v3_stream ───────────────────────────────────────────── - - -async def _collect(gen) -> list[tuple[str, str]]: - results: list[tuple[str, str]] = [] - async for item in gen: - results.append(item) - return results - - -@pytest.mark.asyncio -async def test_orchestrate_v3_stream_first_yield_is_domain_signal(): - agent = _FixedAgent("task_agent", tokens=["token1"]) - reg = _make_registry("task_agent", agent) - - with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")): - gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg) - results = await _collect(gen) - - # First item must be (agent_name, "") — domain signal - assert results[0] == ("task_agent", "") - - -@pytest.mark.asyncio -async def test_orchestrate_v3_stream_yields_agent_name_with_tokens(): - agent = _FixedAgent("task_agent", tokens=["Hello", " ", "world"]) - reg = _make_registry("task_agent", agent) - - with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")): - gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg) - results = await _collect(gen) - - # All items are (agent_name, token) pairs - assert all(name == "task_agent" for name, _ in results) - tokens = [tok for _, tok in results] - assert tokens[0] == "" # domain signal - assert tokens[1:] == ["Hello", " ", "world"] - - -@pytest.mark.asyncio -async def test_orchestrate_v3_stream_different_agent(): - agent = _FixedAgent("note_agent", tokens=["note"]) - reg = _make_registry("note_agent", agent) - - with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="note_agent")): - gen = orchestrate_v3_stream(user_id="u-2", message="take note", context={}, reg=reg) - results = await _collect(gen) - - assert results[0] == ("note_agent", "") - assert ("note_agent", "note") in results - - -@pytest.mark.asyncio -async def test_orchestrate_v3_stream_uses_default_registry_when_none(): - agent = _FixedAgent("task_agent", tokens=["x"]) - - with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")), \ - patch("app.core.orchestrator._default_registry") as mock_reg: - mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}] - mock_reg.get.return_value = agent - gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}) - results = await _collect(gen) - - assert results[0][0] == "task_agent" - - -@pytest.mark.asyncio -async def test_orchestrate_v3_stream_empty_token_list(): - """Agent with no tokens still emits the domain signal.""" - - class _EmptyAgent(_FixedAgent): - async def handle_stream(self, query: str, context: dict[str, Any]): - return - yield # makes it a generator - - agent = _EmptyAgent("task_agent", tokens=[]) - reg = _make_registry("task_agent", agent) - - with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")): - gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg) - results = await _collect(gen) - - assert results == [("task_agent", "")] # only domain signal - - -@pytest.mark.asyncio -async def test_orchestrate_v3_stream_full_text_correct(): - """Concatenating all non-domain tokens reconstructs the full response.""" - tokens = ["The", " ", "task", " ", "is", " ", "done."] - agent = _FixedAgent("task_agent", tokens=tokens) - reg = _make_registry("task_agent", agent) - - with patch("app.core.orchestrator.classify_intent", AsyncMock(return_value="task_agent")): - gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg) - results = await _collect(gen) - - text = "".join(tok for _, tok in results[1:]) # skip domain signal - assert text == "The task is done." - - -# ── handle_stream default implementation ───────────────────────────── - - -@pytest.mark.asyncio -async def test_handle_stream_default_yields_full_response(): - """Default handle_stream yields handle() result as a single chunk.""" - - class _SimpleAgent(ChatAgent): - def get_name(self) -> str: - return "_simple" - - def get_description(self) -> str: - return "" - - def get_tools(self) -> list[Any]: - return [] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - return "simple response" - - agent = _SimpleAgent() - tokens = [tok async for tok in agent.handle_stream("q", {})] - assert tokens == ["simple response"] - - -@pytest.mark.asyncio -async def test_handle_stream_override_used_by_stream(): - """_FixedAgent.handle_stream override yields individual tokens.""" - agent = _FixedAgent("t", tokens=["a", "b", "c"]) - tokens = [tok async for tok in agent.handle_stream("q", {})] - assert tokens == ["a", "b", "c"] diff --git a/tests/test_output_formatter.py b/tests/test_output_formatter.py index bfc5c1c..817f887 100644 --- a/tests/test_output_formatter.py +++ b/tests/test_output_formatter.py @@ -16,15 +16,15 @@ from app.schemas import ( # ── helpers ─────────────────────────────────────────────────────────────────── -async def _stream(*pairs: tuple[str, str]): - """Async generator that yields (agent_name, token) pairs.""" - for pair in pairs: - yield pair +async def _stream(*events: tuple[str, object]): + """Async generator that yields (event_type, data) tuples.""" + for event in events: + yield event -async def collect(formatter, token_stream): +async def collect(formatter, event_stream): frames = [] - async for frame in formatter.format(token_stream): + async for frame in formatter.format(event_stream): frames.append(frame) return frames @@ -32,13 +32,14 @@ async def collect(formatter, token_stream): # ── HomeFormatter ───────────────────────────────────────────────────────────── @pytest.mark.asyncio -async def test_home_formatter_text_block(): +async def test_home_formatter_text_token(): req_id = "req-1" - tokens = [ - ("task_agent", '{"type": "text", "content": "Hello world"}'), + events = [ + ("token", "Hello world"), + ("mutations", []), ] - formatter = HomeFormatter(request_id=req_id, tool_results=[]) - frames = await collect(formatter, _stream(*tokens)) + formatter = HomeFormatter(request_id=req_id) + frames = await collect(formatter, _stream(*events)) assert isinstance(frames[0], WsStreamStart) assert frames[0].request_id == req_id @@ -48,104 +49,94 @@ async def test_home_formatter_text_block(): @pytest.mark.asyncio -async def test_home_formatter_chart_block(): +async def test_home_formatter_entity_ref_from_tool_end(): req_id = "req-2" - chart_json = ( - '{"type": "chart", "chartType": "bar", ' - '"title": "Tasks", "data": [{"x": 1}], ' - '"config": {"x": {"label": "X", "color": "#fff"}}}' - ) - formatter = HomeFormatter(request_id=req_id, tool_results=[]) - frames = await collect(formatter, _stream(("task_agent", chart_json))) + events = [ + ("tool_end", {"name": "task_agent", "result": "Found 3 tasks."}), + ("token", "Here are your tasks."), + ("mutations", []), + ] + formatter = HomeFormatter(request_id=req_id) + frames = await collect(formatter, _stream(*events)) block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] assert len(block_frames) == 1 - assert block_frames[0].block_type == "chart" - assert block_frames[0].data["chartType"] == "bar" + assert block_frames[0].block_type == "entity_ref" + assert block_frames[0].data["entity"] == "tasks" + assert block_frames[0].data["result"] == "Found 3 tasks." @pytest.mark.asyncio -async def test_home_formatter_invalid_chart_skipped(): +async def test_home_formatter_unknown_agent_no_block(): req_id = "req-3" - bad_chart = '{"type": "chart", "chartType": "unknown", "data": []}' - formatter = HomeFormatter(request_id=req_id, tool_results=[]) - frames = await collect(formatter, _stream(("task_agent", bad_chart))) + events = [ + ("tool_end", {"name": "unknown_agent", "result": "stuff"}), + ("mutations", []), + ] + formatter = HomeFormatter(request_id=req_id) + frames = await collect(formatter, _stream(*events)) block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] - assert len(block_frames) == 0 # invalid chart skipped + assert len(block_frames) == 0 # unknown agent → no entity mapping @pytest.mark.asyncio -async def test_home_formatter_entity_ref_resolved(): +async def test_home_formatter_mutations_in_stream_end(): req_id = "req-4" - tool_results = [{"entity": "task", "id": "t1", "title": "My Task"}] - entity_json = '{"type": "entity_ref", "entity": "task"}' - formatter = HomeFormatter(request_id=req_id, tool_results=tool_results) - frames = await collect(formatter, _stream(("task_agent", entity_json))) + muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}] + events = [ + ("token", "Done"), + ("mutations", muts), + ] + formatter = HomeFormatter(request_id=req_id) + frames = await collect(formatter, _stream(*events)) - block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] - assert len(block_frames) == 1 - assert block_frames[0].data["entity"] == "task" - assert block_frames[0].data["items"][0]["id"] == "t1" - - -@pytest.mark.asyncio -async def test_home_formatter_entity_ref_missing_skipped(): - req_id = "req-5" - entity_json = '{"type": "entity_ref", "entity": "task"}' - formatter = HomeFormatter(request_id=req_id, tool_results=[]) - frames = await collect(formatter, _stream(("task_agent", entity_json))) - - block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] - assert len(block_frames) == 0 # no tool results → skipped - - -@pytest.mark.asyncio -async def test_home_formatter_table_block(): - req_id = "req-6" - table_json = '{"type": "table", "headers": ["A", "B"], "rows": [["1", "2"]]}' - formatter = HomeFormatter(request_id=req_id, tool_results=[]) - frames = await collect(formatter, _stream(("task_agent", table_json))) - - block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] - assert len(block_frames) == 1 - assert block_frames[0].block_type == "table" - - -@pytest.mark.asyncio -async def test_home_formatter_timeline_block(): - req_id = "req-7" - timeline_json = '{"type": "timeline", "timelines": [{"id": "c1", "title": "M1", "date": 123}]}' - formatter = HomeFormatter(request_id=req_id, tool_results=[]) - frames = await collect(formatter, _stream(("task_agent", timeline_json))) - - block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] - assert len(block_frames) == 1 - assert block_frames[0].block_type == "timeline" + end_frame = frames[-1] + assert isinstance(end_frame, WsStreamEnd) + assert len(end_frame.mutations) == 1 + assert end_frame.mutations[0]["action"] == "insert" @pytest.mark.asyncio async def test_home_formatter_frame_order(): """stream_start is first, stream_end is last.""" - req_id = "req-8" - formatter = HomeFormatter(request_id=req_id, tool_results=[]) - frames = await collect(formatter, _stream(("task_agent", '{"type": "text", "content": "Hi"}'))) + req_id = "req-5" + formatter = HomeFormatter(request_id=req_id) + frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", []))) assert isinstance(frames[0], WsStreamStart) assert isinstance(frames[-1], WsStreamEnd) -# ── FloatingFormatter ──────────────────────────────────────────────────────────── +@pytest.mark.asyncio +async def test_home_formatter_multiple_tool_ends(): + req_id = "req-6" + events = [ + ("tool_end", {"name": "task_agent", "result": "3 tasks"}), + ("tool_end", {"name": "project_agent", "result": "2 projects"}), + ("token", "Overview done."), + ("mutations", []), + ] + formatter = HomeFormatter(request_id=req_id) + frames = await collect(formatter, _stream(*events)) + + block_frames = [f for f in frames if isinstance(f, WsStreamBlock)] + assert len(block_frames) == 2 + entities = {b.data["entity"] for b in block_frames} + assert entities == {"tasks", "projects"} + + +# ── FloatingFormatter ───────────────────────────────────────────────────────── @pytest.mark.asyncio -async def test_floating_formatter_domain_emitted_first(): +async def test_floating_formatter_domain_from_tool_end(): req_id = "pop-1" formatter = FloatingFormatter(request_id=req_id) - tokens = [ - ("task_agent", ""), # domain signal - ("task_agent", "Hello"), - ("task_agent", " there"), + events = [ + ("tool_end", {"name": "task_agent", "result": "ok"}), + ("token", "Hello"), + ("mutations", []), ] - frames = await collect(formatter, _stream(*tokens)) + frames = await collect(formatter, _stream(*events)) assert isinstance(frames[0], WsFloatingDomain) assert frames[0].domain == "tasks" @@ -156,8 +147,12 @@ async def test_floating_formatter_domain_emitted_first(): async def test_floating_formatter_text_only(): req_id = "pop-2" formatter = FloatingFormatter(request_id=req_id) - tokens = [("timeline_agent", ""), ("timeline_agent", "Summary")] - frames = await collect(formatter, _stream(*tokens)) + events = [ + ("tool_end", {"name": "timeline_agent", "result": "done"}), + ("token", "Summary"), + ("mutations", []), + ] + frames = await collect(formatter, _stream(*events)) assert isinstance(frames[0], WsFloatingDomain) assert frames[0].domain == "timelines" @@ -171,11 +166,12 @@ async def test_floating_formatter_no_block_frames(): """FloatingFormatter must never emit WsStreamBlock.""" req_id = "pop-3" formatter = FloatingFormatter(request_id=req_id) - tokens = [ - ("note_agent", ""), - ("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'), + events = [ + ("tool_end", {"name": "note_agent", "result": "data"}), + ("token", "some text"), + ("mutations", []), ] - frames = await collect(formatter, _stream(*tokens)) + frames = await collect(formatter, _stream(*events)) assert not any(isinstance(f, WsStreamBlock) for f in frames) @@ -183,13 +179,37 @@ async def test_floating_formatter_no_block_frames(): async def test_floating_formatter_end_frame(): req_id = "pop-4" formatter = FloatingFormatter(request_id=req_id) - frames = await collect(formatter, _stream(("project_agent", ""), ("project_agent", "Done"))) + events = [ + ("tool_end", {"name": "project_agent", "result": "ok"}), + ("token", "Done"), + ("mutations", []), + ] + frames = await collect(formatter, _stream(*events)) assert isinstance(frames[-1], WsStreamEnd) @pytest.mark.asyncio -async def test_floating_formatter_unknown_agent_defaults_to_tasks(): +async def test_floating_formatter_default_domain_on_early_token(): + """When the first event is a token (no tool_end yet), default to 'tasks'.""" req_id = "pop-5" formatter = FloatingFormatter(request_id=req_id) - frames = await collect(formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi"))) + events = [("token", "hi"), ("mutations", [])] + frames = await collect(formatter, _stream(*events)) + assert isinstance(frames[0], WsFloatingDomain) assert frames[0].domain == "tasks" + + +@pytest.mark.asyncio +async def test_floating_formatter_mutations_in_stream_end(): + req_id = "pop-6" + muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}] + events = [ + ("token", "Updated"), + ("mutations", muts), + ] + formatter = FloatingFormatter(request_id=req_id) + frames = await collect(formatter, _stream(*events)) + + end_frame = frames[-1] + assert isinstance(end_frame, WsStreamEnd) + assert len(end_frame.mutations) == 1 diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 9c25d85..e73d704 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -88,7 +88,7 @@ class TestPluginRegistry: async def test_list_filter_by_query( self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin] ) -> None: - result = await reg.list_plugins(db_session, query="time") + result = await reg.list_plugins(db_session, query="time tracker") assert result.total == 1 assert result.plugins[0].id == "plugin-time-tracker" diff --git a/tests/test_ws_unified.py b/tests/test_ws_unified.py index f4e6387..c770448 100644 --- a/tests/test_ws_unified.py +++ b/tests/test_ws_unified.py @@ -45,14 +45,16 @@ def _recv_until_end(ws, max_frames: int = 20) -> list[dict]: return frames -async def _mock_home_stream(user_id, message, context, reg=None): - yield "task_agent", "" - yield "task_agent", '{"type": "text", "content": "Hello"}' +async def _mock_home_stream(user_id, message, context, db_session_factory=None): + yield "tool_end", {"name": "task_agent", "result": "Found tasks"} + yield "token", "Hello" + yield "mutations", [] -async def _mock_floating_stream(user_id, message, context, reg=None): - yield "task_agent", "" - yield "task_agent", "Here is a summary" +async def _mock_floating_stream(user_id, message, context, scope=None, db_session_factory=None): + yield "tool_end", {"name": "task_agent", "result": "ok"} + yield "token", "Here is a summary" + yield "mutations", [] # ── tests ───────────────────────────────────────────────────────────────────── @@ -61,7 +63,7 @@ def test_home_request_produces_stream_frames(client): """home_request → stream_start, stream_text+, stream_end.""" token = make_jwt("power", user_id=USER_ID) - with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_home_stream): + with patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_home_stream): with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: ws.send_text(json.dumps({ "type": "device_hello", "device_id": "dev-1", "agent_ids": [] @@ -84,7 +86,7 @@ def test_floating_request_produces_domain_frame(client): """floating_request → floating_domain first, then stream_text*, stream_end.""" token = make_jwt("power", user_id=USER_ID) - with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_floating_stream): + with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream): with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: ws.send_text(json.dumps({ "type": "device_hello", "device_id": "dev-2", "agent_ids": [] @@ -112,11 +114,12 @@ def test_home_request_request_id_propagated(client): token = make_jwt("power", user_id=USER_ID) req_id = "my-unique-req-id" - async def _stream(user_id, message, context, reg=None): - yield "note_agent", "" - yield "note_agent", '{"type": "text", "content": "ok"}' + async def _stream(user_id, message, context, db_session_factory=None): + yield "tool_end", {"name": "note_agent", "result": "ok"} + yield "token", "ok" + yield "mutations", [] - with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_stream): + with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream): with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: ws.send_text(json.dumps({ "type": "device_hello", "device_id": "dev-3", "agent_ids": []