refactor: replace orchestrator with LangGraph deep-agent supervisors
- Add app/core/deep_agent.py with Home and Floating supervisor graphs using LangGraph create_react_agent (hierarchical pattern)
- Strip ChatAgent classes from all 4 agent files, keep @tool functions
- Rewrite output_formatter.py for event-based (token/tool_end/mutations) stream
- Update device_ws.py to use run_home_stream/run_floating_stream
- Rewrite chat.py REST route to use run_home
- Add update_core_memory tool to both supervisors
- Add langgraph>=0.3.0 to requirements.txt
- Remove orchestrator.py, execution_plan.py, agent_registry.py, plans.py
- Remove PlanAction, PlanStep, ExecutionPlan, execution_mode from schemas
- Update all affected tests to match new API
- Remove 6 deprecated test files for deleted modules
- Clean up stale docstrings referencing removed orchestrator
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Import all agent modules to trigger @registry.register decorators."""
|
||||
"""Agent tool modules — imported by deep_agent.py to build sub-agent graphs."""
|
||||
|
||||
from app.agents import timeline_agent, note_agent, project_agent, task_agent
|
||||
|
||||
|
||||
@@ -1,31 +1,14 @@
|
||||
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
||||
"""Note agent — tool definitions for Markdown note CRUD."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.agent_registry import ChatAgent, registry
|
||||
from app.core.llm import embed, get_llm
|
||||
from app.core.llm import embed
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
# System prompt for the note agent. NoteAgent.handle sends it as the
# SystemMessage that precedes the user's query. Adjacent string literals are
# concatenated into one prompt, so every fragment carries its own "\n".
_SYSTEM_PROMPT = (
    "You are a note-taking assistant. You help users create, retrieve, update,\n"
    "and delete Markdown notes in their workspace.\n\n"
    "Rules:\n"
    " - content is always Markdown; preserve formatting when updating\n"
    " - project_id is optional; link a note to a project when mentioned\n"
    " - When updating, call get_note first if you need to read existing content\n"
    " before appending or replacing sections\n"
    " - list_notes without project_id returns all notes; scope with project_id\n"
    " when the user is working within a specific project\n"
    " - Do not fabricate note content — reflect what the user provides or what\n"
    " is already in the note (retrieved via get_note)."
)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_notes(project_id: str = "") -> str:
|
||||
@@ -122,23 +105,4 @@ async def delete_note(note_id: str) -> str:
|
||||
return f"Note {note_id} deleted."
|
||||
|
||||
|
||||
@registry.register
class NoteAgent(ChatAgent):
    """Chat agent that exposes the note tools through the shared tool loop."""

    def get_name(self) -> str:
        """Return the name this agent is registered under."""
        return "note_agent"

    def get_description(self) -> str:
        """Return the one-line capability summary for this agent."""
        return "Manages notes: list, get, create, update, delete"

    def get_tools(self) -> list[Any]:
        """Return the note tools the model is allowed to call."""
        note_tools: list[Any] = [
            list_notes,
            get_note,
            create_note,
            update_note,
            delete_note,
        ]
        return note_tools

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Answer *query* by running the tool loop with the note system prompt.

        The context dict is JSON-encoded and cut to its first 1000 characters
        before being embedded in the human message.
        """
        model = get_llm()
        context_snippet = json.dumps(context)[:1000]
        user_turn = HumanMessage(
            content=f"User query: {query}\nContext: {context_snippet}"
        )
        conversation = [SystemMessage(content=_SYSTEM_PROMPT), user_turn]
        return await self._tool_loop(model, conversation, self.get_tools())
|
||||
|
||||
@@ -1,33 +1,13 @@
|
||||
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
|
||||
"""Project agent — tool definitions for project lifecycle CRUD."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.agent_registry import ChatAgent, registry
|
||||
from app.core.llm import get_llm
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
# System prompt for the project agent. ProjectAgent.handle sends it as the
# SystemMessage that precedes the user's query. Adjacent string literals are
# concatenated into one prompt, so every fragment carries its own "\n".
_SYSTEM_PROMPT = (
    "You are a project management assistant. You help users create, find,\n"
    "update, and archive projects in their workspace.\n\n"
    "Rules:\n"
    " - status must be one of: active, archived\n"
    " - client_id is optional; link to a client only when explicitly mentioned\n"
    " - ai_summary is populated only when the user asks for a project summary;\n"
    " derive it from context data — do not fabricate content\n"
    " - Use list_projects for scoped queries; list_all_projects only when the\n"
    " user wants a complete cross-client view including archived projects\n"
    " - get_project requires a project UUID; resolve the ID first by calling\n"
    " list_projects if you only have a project name\n"
    " - Prefer archiving (update_project status=archived) over deletion;\n"
    " only call delete_project when the user explicitly confirms deletion."
)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_projects(
|
||||
@@ -137,30 +117,4 @@ async def delete_project(project_id: str) -> str:
|
||||
return f"Project {project_id} permanently deleted."
|
||||
|
||||
|
||||
@registry.register
class ProjectAgent(ChatAgent):
    """Chat agent that exposes the project tools through the shared tool loop."""

    def get_name(self) -> str:
        """Return the name this agent is registered under."""
        return "project_agent"

    def get_description(self) -> str:
        """Return the one-line capability summary for this agent."""
        return "Manages projects: list, get, create, update, archive, delete"

    def get_tools(self) -> list[Any]:
        """Return the project tools the model is allowed to call."""
        project_tools: list[Any] = [list_projects, list_all_projects, get_project]
        project_tools += [create_project, update_project, delete_project]
        return project_tools

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Answer *query* by running the tool loop with the project system prompt.

        The context dict is JSON-encoded and cut to its first 1000 characters
        before being embedded in the human message.
        """
        model = get_llm()
        context_snippet = json.dumps(context)[:1000]
        user_turn = HumanMessage(
            content=f"User query: {query}\nContext: {context_snippet}"
        )
        conversation = [SystemMessage(content=_SYSTEM_PROMPT), user_turn]
        return await self._tool_loop(model, conversation, self.get_tools())
|
||||
|
||||
@@ -1,35 +1,14 @@
|
||||
"""Task agent — full CRUD for tasks and task comments."""
|
||||
"""Task agent — tool definitions for task and task comment CRUD."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.agent_registry import ChatAgent, registry
|
||||
from app.core.llm import get_llm
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
# System prompt for the task agent. TaskAgent.handle sends it as the
# SystemMessage that precedes the user's query. Adjacent string literals are
# concatenated into one prompt, so every fragment carries its own "\n".
_SYSTEM_PROMPT = (
    "You are a task management assistant for a project workspace.\n"
    "You create, update, list, and track tasks and their comments.\n\n"
    "Rules:\n"
    " - status must be one of: todo, in_progress, done\n"
    " - priority must be one of: high, medium, low\n"
    " - due_date is a Unix timestamp in milliseconds; convert human dates\n"
    " - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
    " - project_id is optional; link to a project when the user mentions one\n"
    " - is_ai_suggested: 1 only when proactively proposing a task the user\n"
    " did not explicitly request; 0 otherwise\n"
    " - is_approved defaults to 0; set to 1 only when the user confirms\n"
    " - Use list_tasks_due_today for 'what's due today' queries\n"
    " - For update_task, use -1 for integer fields you do not want to change\n"
    " - Always confirm the action in plain, user-friendly language."
)
|
||||
|
||||
|
||||
# ── Task tools ────────────────────────────────────────────────────────
|
||||
|
||||
@@ -220,35 +199,4 @@ async def delete_task_comment(comment_id: str) -> str:
|
||||
return f"Comment {comment_id} deleted."
|
||||
|
||||
|
||||
# ── Agent ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@registry.register
class TaskAgent(ChatAgent):
    """Chat agent that exposes the task and comment tools through the shared tool loop."""

    def get_name(self) -> str:
        """Return the name this agent is registered under."""
        return "task_agent"

    def get_description(self) -> str:
        """Return the one-line capability summary for this agent."""
        return "Manages tasks and comments: list, create, update, delete, due-today, comments"

    def get_tools(self) -> list[Any]:
        """Return the task tools followed by the comment tools."""
        task_tools: list[Any] = [
            list_tasks,
            create_task,
            update_task,
            delete_task,
            list_tasks_due_today,
        ]
        comment_tools: list[Any] = [
            list_task_comments,
            add_task_comment,
            delete_task_comment,
        ]
        return task_tools + comment_tools

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Answer *query* by running the tool loop with the task system prompt.

        The context dict is JSON-encoded and cut to its first 1000 characters
        before being embedded in the human message.
        """
        model = get_llm()
        context_snippet = json.dumps(context)[:1000]
        user_turn = HumanMessage(
            content=f"User query: {query}\nContext: {context_snippet}"
        )
        conversation = [SystemMessage(content=_SYSTEM_PROMPT), user_turn]
        return await self._tool_loop(model, conversation, self.get_tools())
|
||||
|
||||
@@ -1,30 +1,13 @@
|
||||
"""Timeline agent — project milestone management (list, create, update, delete)."""
|
||||
"""Timeline agent — tool definitions for project milestone CRUD."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.agent_registry import ChatAgent, registry
|
||||
from app.core.llm import get_llm
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
# System prompt for the timeline agent. TimelineAgent.handle sends it as the
# SystemMessage that precedes the user's query. Adjacent string literals are
# concatenated into one prompt, so every fragment carries its own "\n".
_SYSTEM_PROMPT = (
    "You are a project timeline assistant. Timelines are milestone dates that\n"
    "track progress on a project — they are not calendar events.\n\n"
    "Rules:\n"
    " - project_id is REQUIRED for every create; confirm with the user if unknown\n"
    " - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
    " - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
    " - is_approved: 0 until the user explicitly confirms; then 1\n"
    " - For update_timeline, use -1 for integer fields you do not want to change\n"
    " - Listing without a project_id returns all timelines across projects\n"
    " - Always echo the title and formatted date in your confirmation."
)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_timelines(project_id: str = "") -> str:
|
||||
@@ -106,23 +89,4 @@ async def delete_timeline(timeline_id: str) -> str:
|
||||
return f"Timeline {timeline_id} deleted."
|
||||
|
||||
|
||||
@registry.register
class TimelineAgent(ChatAgent):
    """Chat agent that exposes the timeline tools through the shared tool loop."""

    def get_name(self) -> str:
        """Return the name this agent is registered under."""
        return "timeline_agent"

    def get_description(self) -> str:
        """Return the one-line capability summary for this agent."""
        return "Manages project timelines (milestones): list, create, update, delete"

    def get_tools(self) -> list[Any]:
        """Return the timeline tools the model is allowed to call."""
        timeline_tools: list[Any] = [
            list_timelines,
            create_timeline,
            update_timeline,
            delete_timeline,
        ]
        return timeline_tools

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Answer *query* by running the tool loop with the timeline system prompt.

        The context dict is JSON-encoded and cut to its first 1000 characters
        before being embedded in the human message.
        """
        model = get_llm()
        context_snippet = json.dumps(context)[:1000]
        user_turn = HumanMessage(
            content=f"User query: {query}\nContext: {context_snippet}"
        )
        conversation = [SystemMessage(content=_SYSTEM_PROMPT), user_turn]
        return await self._tool_loop(model, conversation, self.get_tools())
|
||||
|
||||
@@ -9,8 +9,10 @@ from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.orchestrator import orchestrate
|
||||
from app.schemas import ChatRequest, UserProfile
|
||||
from app.core.deep_agent import run_home
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import async_session
|
||||
from app.schemas import ChatRequest, ChatResponse, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
@@ -20,10 +22,21 @@ async def chat(
|
||||
body: ChatRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> JSONResponse:
|
||||
"""Route a chat message through the orchestrator.
|
||||
"""Route a chat message through the Home deep agent (non-streaming)."""
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(current_user.id, body.message)
|
||||
|
||||
Returns ``ChatResponse`` for ``execution_mode='direct'``,
|
||||
or ``ExecutionPlan`` for ``execution_mode='plan'``.
|
||||
"""
|
||||
result = await orchestrate(body)
|
||||
context = {
|
||||
**body.context.model_dump(),
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
response_text = await run_home(
|
||||
user_id=current_user.id,
|
||||
message=body.message,
|
||||
context=context,
|
||||
db_session_factory=async_session,
|
||||
)
|
||||
result = ChatResponse(response=response_text)
|
||||
return JSONResponse(content=result.model_dump())
|
||||
|
||||
@@ -43,7 +43,7 @@ from app.config.settings import settings
|
||||
from app.core.agent_runner import trigger_pending_runs
|
||||
from app.core.device_manager import device_manager
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.core.orchestrator import orchestrate_v3_stream
|
||||
from app.core.deep_agent import run_home_stream, run_floating_stream
|
||||
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||
from app.db import async_session
|
||||
@@ -204,9 +204,17 @@ async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
||||
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
||||
async def _executor(payload: dict) -> dict:
|
||||
payload["type"] = WsFrameType.tool_call
|
||||
call_id = payload["id"]
|
||||
logger.info("ws_executor: sending tool_call id=%s action=%s", call_id, payload.get("action"))
|
||||
await websocket.send_text(json.dumps(payload))
|
||||
future = device_manager.create_pending_call(user_id, payload["id"])
|
||||
return await future
|
||||
future = device_manager.create_pending_call(user_id, call_id)
|
||||
result = await future
|
||||
logger.info("ws_executor: tool_result id=%s result_type=%s result_keys=%s",
|
||||
call_id, type(result).__name__,
|
||||
list(result.keys()) if isinstance(result, dict) else "N/A")
|
||||
if result is None:
|
||||
logger.error("ws_executor: future resolved to None for call_id=%s user=%s", call_id, user_id)
|
||||
return result
|
||||
return _executor
|
||||
|
||||
|
||||
@@ -233,21 +241,13 @@ async def _handle_home_request(
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
agent_holder: list = []
|
||||
try:
|
||||
token_stream = orchestrate_v3_stream(
|
||||
user_id, message, context, agent_holder=agent_holder
|
||||
event_stream = run_home_stream(
|
||||
user_id, message, context, db_session_factory=async_session
|
||||
)
|
||||
formatter = HomeFormatter(request_id=request_id, tool_results=[])
|
||||
async for ws_frame in formatter.format(token_stream):
|
||||
# Inject mutations from agent tool_results into stream_end
|
||||
if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr]
|
||||
ws_frame.mutations = [ # type: ignore[union-attr]
|
||||
{"action": r["action"], "table": r["table"], "data": r["data"]}
|
||||
for r in getattr(agent_holder[0], "tool_results", [])
|
||||
]
|
||||
formatter = HomeFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
# Collect text chunks to build the full response for episode storage
|
||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||
except Exception as exc:
|
||||
@@ -287,18 +287,13 @@ async def _handle_floating_request(
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
agent_holder: list = []
|
||||
try:
|
||||
token_stream = orchestrate_v3_stream(
|
||||
user_id, message, context, agent_holder=agent_holder
|
||||
event_stream = run_floating_stream(
|
||||
user_id, message, context, scope=scope,
|
||||
db_session_factory=async_session,
|
||||
)
|
||||
formatter = FloatingFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(token_stream):
|
||||
if ws_frame.type == "stream_end" and agent_holder: # type: ignore[union-attr]
|
||||
ws_frame.mutations = [ # type: ignore[union-attr]
|
||||
{"action": r["action"], "table": r["table"], "data": r["data"]}
|
||||
for r in getattr(agent_holder[0], "tool_results", [])
|
||||
]
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.execution_plan import plan_cache
|
||||
from app.schemas import ExecutionPlan, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/plans", tags=["plans"])
|
||||
|
||||
|
||||
@router.get("/playbook", response_model=list[ExecutionPlan])
async def list_playbooks(
    current_user: UserProfile = Depends(get_current_user),
) -> list[ExecutionPlan]:
    """Return all cached execution plan playbooks for the authenticated user.

    TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
    """
    # current_user is declared only so the auth dependency runs; the cache
    # lookup below does not take it as an argument.
    playbooks = plan_cache.get_all_playbooks()
    return playbooks
|
||||
|
||||
|
||||
@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
async def get_playbook(
    plan_id: str,
    current_user: UserProfile = Depends(get_current_user),
) -> ExecutionPlan:
    """Return a specific execution plan playbook by ID."""
    cached_plan = plan_cache.get_plan(plan_id)
    # Happy path first: a cache hit is returned as-is.
    if cached_plan is not None:
        return cached_plan
    # A miss is reported to the client as a 404 carrying the requested id.
    not_found_detail = f"Plan not found: {plan_id}"
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=not_found_detail,
    )
|
||||
@@ -1,217 +0,0 @@
|
||||
"""Agent Registry — base classes and singleton registry for chat agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAgent(ABC):
    """Abstract root of the agent hierarchy.

    Stores the per-agent state (user id, shared memory, vector-store context)
    and obliges subclasses to provide a name and a description.
    """

    def __init__(
        self,
        user_id: str = "",
        shared_memory: dict[str, Any] | None = None,
        vector_store_context: list[str] | None = None,
    ) -> None:
        # Falsy arguments (None or an empty container) are swapped for a
        # fresh empty container; truthy ones are stored by reference.
        memory: dict[str, Any] = shared_memory if shared_memory else {}
        snippets: list[str] = vector_store_context if vector_store_context else []
        self.user_id = user_id
        self.shared_memory = memory
        self.vector_store_context = snippets

    @abstractmethod
    def get_name(self) -> str:
        """Return the agent's name."""
        ...

    @abstractmethod
    def get_description(self) -> str:
        """Return a short description of what the agent does."""
        ...

    @property
    def skills(self) -> list[str]:
        """Capabilities advertised by the agent; empty unless a subclass overrides it."""
        advertised: list[str] = []
        return advertised
|
||||
|
||||
|
||||
class ChatAgent(BaseAgent):
    """Base class for LLM-powered chat agents.

    Subclasses implement ``handle`` and ``get_tools``; this class supplies the
    shared tool-calling loops (``_tool_loop`` and its streaming variant) that
    drive the model until it stops requesting tools.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results.
        self.tool_results: list[dict] = []

    @abstractmethod
    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Process a user query and return a text response."""
        ...

    async def handle_stream(
        self, query: str, context: dict[str, Any]
    ) -> AsyncGenerator[str, None]:
        """Streaming variant of handle().

        Default: calls handle() and yields the full response as one chunk.
        Override in subclasses for true token-level streaming via _tool_loop_stream.
        """
        yield await self.handle(query, context)

    @abstractmethod
    def get_tools(self) -> list[Any]:
        """Return LangChain tool definitions available to this agent."""
        ...

    async def _tool_loop(
        self,
        llm: Any,
        messages: list[Any],
        tools: list[Any],
        max_iter: int = 5,
    ) -> str:
        """Shared tool-calling loop.

        Binds *tools* to *llm*, invokes iteratively until the model stops
        requesting tool calls or *max_iter* is reached, and returns the
        final text response. Captures raw execute_on_client results in
        ``self.tool_results``.

        Note that *messages* is mutated in place: every model response and
        every tool result is appended to the caller's list.
        """
        # Imported lazily inside the function rather than at module level.
        from langchain_core.messages import AIMessage, ToolMessage

        from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector

        # A fresh list is registered as the collector for the duration of the
        # loop and stored on the instance in ``finally``.
        collector: list[dict] = []
        set_tool_result_collector(collector)
        try:
            # With no tools the plain model is used as-is.
            llm_with_tools = llm.bind_tools(tools) if tools else llm

            for _ in range(max_iter):
                response: AIMessage = await llm_with_tools.ainvoke(messages)
                messages.append(response)

                # No tool calls means the model has produced its final answer.
                if not response.tool_calls:
                    return str(response.content)

                # Execute each requested tool call
                tool_map = {t.name: t for t in tools}
                for call in response.tool_calls:
                    tool_fn = tool_map.get(call["name"])
                    if tool_fn is None:
                        # Report the unknown name back to the model instead of raising.
                        result = f"Unknown tool: {call['name']}"
                    else:
                        result = await tool_fn.ainvoke(call["args"])
                    messages.append(
                        ToolMessage(content=str(result), tool_call_id=call["id"])
                    )

            # Exhausted iterations — ask model for a final answer without tools
            response = await llm.ainvoke(messages)
            return str(response.content)
        finally:
            # Runs on every exit path (return or exception): unregister the
            # collector and expose whatever it gathered.
            clear_tool_result_collector()
            self.tool_results = collector

    async def _tool_loop_stream(
        self,
        llm: Any,
        messages: list[Any],
        tools: list[Any],
        max_iter: int = 5,
    ) -> AsyncGenerator[str, None]:
        """Streaming variant of ``_tool_loop``.

        Behaves identically for tool-calling iterations (uses ainvoke to parse
        tool calls). For the final response — when the model produces no further
        tool calls — switches to ``llm.astream()`` and yields text tokens.
        Captures raw execute_on_client results in ``self.tool_results``.
        """
        # Imported lazily inside the function rather than at module level.
        from langchain_core.messages import AIMessage, ToolMessage

        from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector

        collector: list[dict] = []
        set_tool_result_collector(collector)
        try:
            llm_with_tools = llm.bind_tools(tools) if tools else llm

            for _ in range(max_iter):
                response: AIMessage = await llm_with_tools.ainvoke(messages)

                if not response.tool_calls:
                    # Stream the final answer — don't keep the ainvoke result.
                    # NOTE(review): this issues a second model call (on the
                    # un-bound ``llm``) for the same messages, so the streamed
                    # text is not the response inspected above.
                    async for chunk in llm.astream(messages):
                        if chunk.content:
                            yield str(chunk.content)
                    return

                # Unlike _tool_loop, the response is appended only when it
                # carries tool calls; a final answer is never added to *messages*.
                messages.append(response)

                # Execute each requested tool call
                tool_map = {t.name: t for t in tools}
                for call in response.tool_calls:
                    tool_fn = tool_map.get(call["name"])
                    if tool_fn is None:
                        # Report the unknown name back to the model instead of raising.
                        result = f"Unknown tool: {call['name']}"
                    else:
                        result = await tool_fn.ainvoke(call["args"])
                    messages.append(
                        ToolMessage(content=str(result), tool_call_id=call["id"])
                    )

            # Exhausted iterations — stream a final answer without tools
            async for chunk in llm.astream(messages):
                if chunk.content:
                    yield str(chunk.content)
        finally:
            # Runs when the generator finishes, raises, or is closed early.
            clear_tool_result_collector()
            self.tool_results = collector
|
||||
|
||||
|
||||
class AgentRegistry:
    """Singleton registry for ChatAgent subclasses.

    ``AgentRegistry()`` always returns the same object. All state is created
    exactly once in ``__new__``. ``__init__`` is deliberately a no-op: Python
    re-runs it on every ``AgentRegistry()`` call, even when ``__new__`` hands
    back the existing instance, so resetting ``_agents`` there would silently
    drop every agent registered so far.
    """

    _instance: AgentRegistry | None = None

    def __new__(cls) -> AgentRegistry:
        if cls._instance is None:
            instance = super().__new__(cls)
            # Maps agent name -> agent class; initialised only on first creation.
            instance._agents = {}
            cls._instance = instance
        return cls._instance

    def __init__(self) -> None:
        """Do nothing; state lives in ``__new__`` so repeat calls keep registrations."""

    # ── public API ───────────────────────────────────────────────────

    def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
        """Class decorator — registers an agent by its name.

        The class is instantiated once (with no arguments) to read its name,
        then stored under that name. Returns the class unchanged so it can be
        used as ``@registry.register``.
        """
        name = agent_class().get_name()
        self._agents[name] = agent_class
        return agent_class

    def get(self, name: str) -> ChatAgent:
        """Return a fresh instance of the named agent.

        Raises:
            KeyError: if no agent is registered under *name*.
        """
        agent_class = self._agents.get(name)
        if agent_class is None:
            raise KeyError(f"Agent not found: {name}")
        return agent_class()

    def list_agents(self) -> list[dict[str, str]]:
        """Return ``[{name, description}]`` for the orchestrator prompt.

        Each registered class is instantiated to read both fields.
        """
        result: list[dict[str, str]] = []
        for agent_class in self._agents.values():
            inst = agent_class()
            result.append(
                {"name": inst.get_name(), "description": inst.get_description()}
            )
        return result

    async def call_agent(
        self, name: str, query: str, context: dict[str, Any]
    ) -> str:
        """Instantiate the named agent and call its ``handle`` method.

        Raises:
            KeyError: if no agent is registered under *name*.
        """
        agent = self.get(name)
        return await agent.handle(query, context)
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
registry = AgentRegistry()
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Agent run orchestrator.
|
||||
"""Agent run manager.
|
||||
|
||||
Drives two agent types:
|
||||
|
||||
|
||||
429
app/core/deep_agent.py
Normal file
429
app/core/deep_agent.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""Deep Agent — LangGraph hierarchical supervisors for home and floating modes.
|
||||
|
||||
Two supervisor graphs (both ``create_react_agent``):
|
||||
* **HomeSupervisor** — gathers data from multiple domains, presents
|
||||
structured overview with tool-result blocks.
|
||||
* **FloatingSupervisor** — focused, scoped assistant for a single entity/domain.
|
||||
|
||||
Each supervisor delegates to four sub-agent tools, each a compiled
|
||||
``create_react_agent`` wrapping the domain CRUD tools (task, project, note,
|
||||
timeline). The sub-agents talk to Electron via ``execute_on_client``.
|
||||
|
||||
Streaming uses ``astream(stream_mode=["messages", "updates"])`` so that
|
||||
callers can sniff:
|
||||
* ``("messages", (token, metadata))`` — text tokens for streaming
|
||||
* ``("updates", ...)`` — tool call results for mutations
|
||||
|
||||
An ``update_core_memory`` tool is available to both supervisors for
|
||||
persisting user preferences mid-conversation (MemGPT-style).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from langchain_core.messages import AIMessageChunk, HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
from app.core.llm import get_llm
|
||||
from app.core.ws_context import (
|
||||
clear_tool_result_collector,
|
||||
set_tool_result_collector,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Sub-agent tool imports ────────────────────────────────────────────
|
||||
|
||||
from app.agents.task_agent import ( # noqa: E402
|
||||
add_task_comment,
|
||||
create_task,
|
||||
delete_task,
|
||||
delete_task_comment,
|
||||
list_task_comments,
|
||||
list_tasks,
|
||||
list_tasks_due_today,
|
||||
update_task,
|
||||
)
|
||||
from app.agents.note_agent import ( # noqa: E402
|
||||
create_note,
|
||||
delete_note,
|
||||
get_note,
|
||||
list_notes,
|
||||
update_note,
|
||||
)
|
||||
from app.agents.project_agent import ( # noqa: E402
|
||||
create_project,
|
||||
delete_project,
|
||||
get_project,
|
||||
list_all_projects,
|
||||
list_projects,
|
||||
update_project,
|
||||
)
|
||||
from app.agents.timeline_agent import ( # noqa: E402
|
||||
create_timeline,
|
||||
delete_timeline,
|
||||
list_timelines,
|
||||
update_timeline,
|
||||
)
|
||||
|
||||
# ── Sub-agent definitions ─────────────────────────────────────────────
|
||||
|
||||
# Per-domain tool lists. Each list is handed to one sub-agent in
# _make_subagent_tools, so a sub-agent only sees the tools of its own domain.

# Task tools plus the task-comment tools.
_TASK_TOOLS = [
    list_tasks,
    create_task,
    update_task,
    delete_task,
    list_tasks_due_today,
    list_task_comments,
    add_task_comment,
    delete_task_comment,
]

# Note tools.
_NOTE_TOOLS = [list_notes, get_note, create_note, update_note, delete_note]

# Project tools.
_PROJECT_TOOLS = [
    list_projects,
    list_all_projects,
    get_project,
    create_project,
    update_project,
    delete_project,
]

# Timeline (milestone) tools.
_TIMELINE_TOOLS = [list_timelines, create_timeline, update_timeline, delete_timeline]
|
||||
|
||||
|
||||
def _build_subagent_tool(
    name: str,
    description: str,
    system_prompt: str,
    tools: list,
):
    """Build a compiled sub-agent graph and wrap it as a LangChain tool.

    Args:
        name: Name given to both the compiled sub-agent graph and the tool.
        description: Tool description attached to the wrapping tool.
        system_prompt: Prompt passed to ``create_react_agent`` for the sub-agent.
        tools: Domain tools the sub-agent is allowed to call.

    Returns:
        An async LangChain tool that takes a single ``query`` string, runs the
        sub-agent graph on it, and returns the sub-agent's reply as text.
    """
    subgraph = create_react_agent(
        model=get_llm(),
        tools=tools,
        prompt=system_prompt,
        name=name,
    )

    # The ``tool`` decorator takes the tool name as its first positional
    # argument; it has no ``name`` keyword, so ``tool(name=...)`` raises
    # TypeError as soon as this function is called.
    @tool(name, description=description)
    async def _run(query: str) -> str:
        result = await subgraph.ainvoke(
            {"messages": [HumanMessage(content=query)]}
        )
        messages = result["messages"]
        # Walk backwards to the most recent message that has content and is
        # not itself a tool-call request.
        # NOTE(review): this filter does not check the message type, so when
        # no AI message has content it can return a tool result or the
        # original query — confirm that fallback is intended.
        for msg in reversed(messages):
            if hasattr(msg, "content") and msg.content and not getattr(msg, "tool_calls", None):
                return str(msg.content)
        return "No response from sub-agent."

    return _run
|
||||
|
||||
|
||||
def _make_subagent_tools() -> list:
    """Build the four domain sub-agent tools handed to a supervisor.

    Returns the tools in a fixed order: task, note, project, timeline. Each
    entry of the spec table below is (name, description, system prompt, tools).
    """
    specs = (
        (
            "task_agent",
            (
                "Manages tasks and comments: list, create, update, delete, "
                "due-today, comments. Delegate task-related queries here."
            ),
            (
                "You are a task management assistant. You create, update, list, "
                "and track tasks and their comments.\n\n"
                "Rules:\n"
                " - status must be one of: todo, in_progress, done\n"
                " - priority must be one of: high, medium, low\n"
                " - due_date is a Unix timestamp in milliseconds\n"
                " - assignees is a JSON-encoded array of strings\n"
                " - is_approved defaults to 0; set to 1 only when the user confirms\n"
                " - For update_task, use -1 for integer fields you do not want to change\n"
                " - Always confirm the action in plain, user-friendly language."
            ),
            _TASK_TOOLS,
        ),
        (
            "note_agent",
            (
                "Manages notes: list, get, create, update, delete. "
                "Delegate note-related queries here."
            ),
            (
                "You are a note-taking assistant. You help users create, retrieve, "
                "update, and delete Markdown notes in their workspace.\n\n"
                "Rules:\n"
                " - content is always Markdown; preserve formatting when updating\n"
                " - When updating, call get_note first if you need to read existing "
                "content before appending or replacing sections\n"
                " - Do not fabricate note content."
            ),
            _NOTE_TOOLS,
        ),
        (
            "project_agent",
            (
                "Manages projects: list, get, create, update, archive, delete. "
                "Delegate project-related queries here."
            ),
            (
                "You are a project management assistant. You help users create, "
                "find, update, and archive projects.\n\n"
                "Rules:\n"
                " - status must be one of: active, archived\n"
                " - Prefer archiving over deletion\n"
                " - ai_summary is populated only when the user asks for a summary."
            ),
            _PROJECT_TOOLS,
        ),
        (
            "timeline_agent",
            (
                "Manages project timelines (milestones): list, create, update, "
                "delete. Delegate timeline/milestone queries here."
            ),
            (
                "You are a project timeline assistant. Timelines are milestone "
                "dates that track progress on a project.\n\n"
                "Rules:\n"
                " - project_id is REQUIRED for every create\n"
                " - date is a Unix timestamp in milliseconds\n"
                " - For update_timeline, use -1 for integer fields you do not "
                "want to change."
            ),
            _TIMELINE_TOOLS,
        ),
    )

    # Sub-agents are built in table order, one graph per domain.
    return [
        _build_subagent_tool(
            name=agent_name,
            description=summary,
            system_prompt=prompt,
            tools=toolset,
        )
        for agent_name, summary, prompt, toolset in specs
    ]
|
||||
|
||||
|
||||
# ── Update core memory tool ──────────────────────────────────────────
|
||||
|
||||
def _make_update_core_memory_tool(user_id: str, db_session_factory):
    """Build the ``update_core_memory`` tool bound to a single user.

    The returned tool closes over *user_id* and *db_session_factory*, so the
    model only supplies the key/value pair.

    Args:
        user_id: Forwarded unchanged to ``MemoryMiddleware.update_core``.
        db_session_factory: Zero-argument callable whose result is entered
            with ``async with`` to obtain the session handed to
            ``MemoryMiddleware``.

    Returns:
        The ``@tool``-decorated coroutine function.
    """

    @tool
    async def update_core_memory(key: str, value: str) -> str:
        """Save a user preference or fact to long-term core memory.
        key: short label for the memory (e.g. 'preferred_language', 'timezone')
        value: the value to remember
        Use this when the user states a preference or fact worth remembering.
        """
        # Function-scope import: resolved when the tool runs, not when this
        # module is loaded.
        from app.core.memory_middleware import MemoryMiddleware

        async with db_session_factory() as session:
            middleware = MemoryMiddleware(session)
            await middleware.update_core(user_id, key, value)

        # The confirmation string is what the supervisor sees as the tool result.
        return f"Remembered: {key} = {value}"

    return update_core_memory
|
||||
|
||||
|
||||
# ── System prompts ────────────────────────────────────────────────────
|
||||
|
||||
_HOME_SYSTEM = (
|
||||
"You are Adiuva, a smart workspace assistant on the Home dashboard.\n"
|
||||
"Your job is to help the user by gathering data from their workspace and "
|
||||
"presenting a comprehensive overview.\n\n"
|
||||
"You have sub-agent tools (task_agent, note_agent, project_agent, "
|
||||
"timeline_agent) that can query and mutate workspace data. Delegate to "
|
||||
"the appropriate sub-agent(s) based on the user's request. You can call "
|
||||
"multiple sub-agents if needed.\n\n"
|
||||
"You also have an update_core_memory tool — use it when the user states "
|
||||
"a preference or important fact worth remembering long-term.\n\n"
|
||||
"After gathering data, synthesize a clear, helpful response for the user.\n\n"
|
||||
"Memory context:\n{memory_context}"
|
||||
)
|
||||
|
||||
_FLOATING_SYSTEM = (
|
||||
"You are Adiuva, a focused workspace assistant in the floating panel.\n"
|
||||
"The user is currently working in the '{scope_type}' section"
|
||||
"{scope_detail}.\n\n"
|
||||
"You have sub-agent tools (task_agent, note_agent, project_agent, "
|
||||
"timeline_agent) that can query and mutate workspace data. Focus your "
|
||||
"help on the user's current scope, but you can use other sub-agents "
|
||||
"if the request requires it.\n\n"
|
||||
"You also have an update_core_memory tool — use it when the user states "
|
||||
"a preference or important fact worth remembering long-term.\n\n"
|
||||
"Provide direct, conversational responses.\n\n"
|
||||
"Memory context:\n{memory_context}"
|
||||
)
|
||||
|
||||
|
||||
def _format_memory_context(memory: dict[str, Any]) -> str:
|
||||
"""Format the memory dict into a readable string for the system prompt."""
|
||||
if not memory:
|
||||
return "(no memory available)"
|
||||
parts = []
|
||||
if memory.get("core_memory"):
|
||||
parts.append("Preferences: " + json.dumps(memory["core_memory"]))
|
||||
if memory.get("associative_memory"):
|
||||
parts.append("Related memories: " + "; ".join(memory["associative_memory"][:3]))
|
||||
if memory.get("episodic_memory"):
|
||||
parts.append("Recent sessions: " + "; ".join(memory["episodic_memory"][:3]))
|
||||
if memory.get("proactive_hints"):
|
||||
parts.append("Patterns: " + "; ".join(memory["proactive_hints"][:3]))
|
||||
return "\n".join(parts) if parts else "(no memory available)"
|
||||
|
||||
|
||||
# ── Graph builders ────────────────────────────────────────────────────
|
||||
|
||||
def build_home_graph(
    user_id: str,
    memory_context: dict[str, Any],
    db_session_factory,
):
    """Assemble the Home supervisor agent.

    Args:
        user_id: Bound into the ``update_core_memory`` tool.
        memory_context: Memory dict rendered into the system prompt via
            ``_format_memory_context``.
        db_session_factory: Passed to ``_make_update_core_memory_tool``.

    Returns:
        The agent produced by ``create_react_agent``, named
        ``"home_supervisor"``.
    """
    # Sub-agent tools first, the per-user memory tool last.
    toolbox = _make_subagent_tools() + [
        _make_update_core_memory_tool(user_id, db_session_factory)
    ]
    rendered_memory = _format_memory_context(memory_context)
    system_prompt = _HOME_SYSTEM.format(memory_context=rendered_memory)

    return create_react_agent(
        model=get_llm(),
        tools=toolbox,
        prompt=system_prompt,
        name="home_supervisor",
    )
|
||||
|
||||
|
||||
def build_floating_graph(
    user_id: str,
    memory_context: dict[str, Any],
    scope: dict[str, Any],
    db_session_factory,
):
    """Assemble the Floating-panel supervisor agent.

    Args:
        user_id: Bound into the ``update_core_memory`` tool.
        memory_context: Memory dict rendered into the system prompt via
            ``_format_memory_context``.
        scope: Where the user currently is. ``scope["type"]`` defaults to
            ``"general"``; a truthy ``scope["id"]`` is appended to the prompt
            as `` (id: ...)``.
        db_session_factory: Passed to ``_make_update_core_memory_tool``.

    Returns:
        The agent produced by ``create_react_agent``, named
        ``"floating_supervisor"``.
    """
    # Sub-agent tools first, the per-user memory tool last.
    toolbox = _make_subagent_tools() + [
        _make_update_core_memory_tool(user_id, db_session_factory)
    ]

    current_id = scope.get("id")
    if current_id:
        detail = f" (id: {current_id})"
    else:
        # No id (or a falsy one): mention the section only.
        detail = ""

    system_prompt = _FLOATING_SYSTEM.format(
        scope_type=scope.get("type", "general"),
        scope_detail=detail,
        memory_context=_format_memory_context(memory_context),
    )

    return create_react_agent(
        model=get_llm(),
        tools=toolbox,
        prompt=system_prompt,
        name="floating_supervisor",
    )
|
||||
|
||||
|
||||
# ── Stream event type ────────────────────────────────────────────────
|
||||
|
||||
# Events yielded by run_*_stream:
#   ("token", str)       — text token for streaming
#   ("tool_end", dict)   — {"name": "task_agent", "result": "..."}
#   ("mutations", list)  — collected mutations, yielded once after the stream ends
# NOTE: ("tool_start", dict) — {"name": "task_agent", "args": {...}} — is not
# emitted by _run_graph_stream.
|
||||
|
||||
|
||||
# ── Stream runners ────────────────────────────────────────────────────
|
||||
|
||||
async def _run_graph_stream(
    graph,
    message: str,
) -> AsyncGenerator[tuple[str, Any], None]:
    """Drive a supervisor graph and translate its stream into event tuples.

    The graph is streamed with ``stream_mode=["messages", "updates"]``:
    ``messages`` chunks supply text tokens and ``updates`` chunks supply the
    tool node's results.

    Args:
        graph: Object exposing ``astream(inputs, stream_mode=...)``.
        message: User text, sent as a single ``HumanMessage``.

    Yields:
        ``("token", str)`` for each qualifying AI text chunk,
        ``("tool_end", {"name", "result"})`` for each message produced by the
        ``"tools"`` node, and finally one ``("mutations", list)`` carrying
        whatever was gathered in the tool-result collector.
    """
    inputs = {"messages": [HumanMessage(content=message)]}

    mutations: list[dict] = []
    set_tool_result_collector(mutations)
    try:
        async for mode, chunk in graph.astream(
            inputs,
            stream_mode=["messages", "updates"],
        ):
            if mode == "messages":
                msg, metadata = chunk
                # Intended to keep only the supervisor's own text (no tool-call
                # chunks, no sub-agent chatter).
                # NOTE(review): confirm nested sub-agent graphs do not also
                # report langgraph_node == "agent".
                is_supervisor_text = (
                    isinstance(msg, AIMessageChunk)
                    and msg.content
                    and not msg.tool_calls
                    and metadata.get("langgraph_node") == "agent"
                )
                if is_supervisor_text:
                    yield ("token", str(msg.content))
                continue

            # "updates" chunks are {node_name: state_update}; only the tool
            # node's update is of interest here.
            if mode != "updates" or not isinstance(chunk, dict):
                continue
            if "tools" not in chunk:
                continue

            for tool_msg in chunk["tools"].get("messages", []):
                if not (hasattr(tool_msg, "name") and hasattr(tool_msg, "content")):
                    continue
                yield (
                    "tool_end",
                    {"name": tool_msg.name, "result": str(tool_msg.content)},
                )
    finally:
        # Always detach the collector, including on error or early close.
        clear_tool_result_collector()

    # Yield the collected mutations so callers can attach them to stream_end
    yield ("mutations", mutations)
|
||||
|
||||
|
||||
async def run_home_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
    db_session_factory,
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream events from a freshly built Home supervisor.

    Args:
        user_id: Forwarded to ``build_home_graph``.
        message: User text to send to the supervisor.
        context: Memory context dict forwarded to ``build_home_graph``.
        db_session_factory: Forwarded to ``build_home_graph``.

    Yields:
        Every event tuple produced by ``_run_graph_stream``, unchanged.
    """
    supervisor = build_home_graph(user_id, context, db_session_factory)
    events = _run_graph_stream(supervisor, message)
    async for item in events:
        yield item
|
||||
|
||||
|
||||
async def run_floating_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
    scope: dict[str, Any],
    db_session_factory,
) -> AsyncGenerator[tuple[str, Any], None]:
    """Stream events from a freshly built Floating supervisor.

    Args:
        user_id: Forwarded to ``build_floating_graph``.
        message: User text to send to the supervisor.
        context: Memory context dict forwarded to ``build_floating_graph``.
        scope: Current UI scope forwarded to ``build_floating_graph``.
        db_session_factory: Forwarded to ``build_floating_graph``.

    Yields:
        Every event tuple produced by ``_run_graph_stream``, unchanged.
    """
    supervisor = build_floating_graph(user_id, context, scope, db_session_factory)
    events = _run_graph_stream(supervisor, message)
    async for item in events:
        yield item
|
||||
|
||||
|
||||
async def run_home(
    user_id: str,
    message: str,
    context: dict[str, Any],
    db_session_factory,
) -> str:
    """Run the Home supervisor (non-streaming) and return its reply text.

    Args:
        user_id: Forwarded to ``build_home_graph``.
        message: User text, sent as a single ``HumanMessage``.
        context: Memory context dict forwarded to ``build_home_graph``.
        db_session_factory: Forwarded to ``build_home_graph``.

    Returns:
        The content of the most recent message that has content, carries no
        tool calls and is not a human/system/tool/function message; ``""``
        when no such message exists.
    """
    graph = build_home_graph(user_id, context, db_session_factory)
    result = await graph.ainvoke(
        {"messages": [HumanMessage(content=message)]}
    )

    # Message kinds that can never be the assistant's answer. Without this
    # filter an empty final AI message lets the backward scan fall through to
    # a raw tool result, or to the user's own prompt.
    non_assistant_types = {"human", "system", "tool", "function"}

    # Newest first, so the latest qualifying reply wins.
    for msg in reversed(result["messages"]):
        if getattr(msg, "type", None) in non_assistant_types:
            continue
        content = getattr(msg, "content", None)
        if content and not getattr(msg, "tool_calls", None):
            return str(content)
    return ""
|
||||
@@ -1,222 +0,0 @@
|
||||
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from app.schemas import ExecutionPlan, PlanStep
|
||||
|
||||
|
||||
# ── Prompt Template Registry ──────────────────────────────────────────
|
||||
|
||||
|
||||
class PromptTemplateRegistry:
    """In-process mapping from template ID to prompt text.

    Callers outside the server are meant to handle only the IDs
    (e.g. ``"tpl_task_agent_default"``); the text is looked up here, which
    keeps prompt wording out of API responses.
    """

    def __init__(self) -> None:
        # template_id -> prompt text
        self._templates: dict[str, str] = {}

    def register(self, template_id: str, prompt_text: str) -> None:
        """Store *prompt_text* under *template_id*, replacing any previous entry."""
        self._templates[template_id] = prompt_text

    def get(self, template_id: str) -> str:
        """Return the prompt text registered for *template_id*.

        Raises:
            KeyError: If nothing is registered under *template_id*.
        """
        prompt = self._templates.get(template_id)
        if prompt is not None:
            return prompt
        raise KeyError(f"Template not found: {template_id!r}")

    def has(self, template_id: str) -> bool:
        """Tell whether *template_id* is registered."""
        return template_id in self._templates

    def list_ids(self) -> list[str]:
        """List the registered template IDs, never the prompt text."""
        return list(self._templates)
|
||||
|
||||
|
||||
# ── Execution Plan Builder ────────────────────────────────────────────
|
||||
|
||||
|
||||
class ExecutionPlanBuilder:
    """Chainable builder that accumulates steps into an ``ExecutionPlan``.

    Example::

        plan = (
            ExecutionPlanBuilder("task_agent")
            .add_llm_step("tpl_task_agent_default", {"message": user_msg})
            .add_data_step("create_record", data_from_step=0)
            .build()
        )
    """

    def __init__(self, agent: str) -> None:
        self._agent = agent
        self._steps: list[PlanStep] = []

    def _push(self, step: PlanStep) -> ExecutionPlanBuilder:
        """Append *step* and return the builder so calls can be chained."""
        self._steps.append(step)
        return self

    # ── step adders ──────────────────────────────────────────────────

    def add_step(
        self, action: str, params: dict[str, Any] | None = None
    ) -> ExecutionPlanBuilder:
        """Add a generic step; *params* is stored as the step's ``variables``."""
        return self._push(PlanStep(action=action, variables=params))

    def add_llm_step(
        self, template_id: str, variables: dict[str, Any] | None = None
    ) -> ExecutionPlanBuilder:
        """Add an ``"llm"`` step that refers to a prompt template by ID."""
        llm_step = PlanStep(
            action="llm", prompt_template=template_id, variables=variables
        )
        return self._push(llm_step)

    def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
        """Add a step that takes its input from the step at *data_from_step*."""
        return self._push(PlanStep(action=action, data_from_step=data_from_step))

    # ── build ────────────────────────────────────────────────────────

    def build(self) -> ExecutionPlan:
        """Check step references and produce the ``ExecutionPlan``.

        Raises:
            ValueError: If a step's ``data_from_step`` is negative or does not
                point at a strictly earlier step.
        """
        for index, step in enumerate(self._steps):
            source = step.data_from_step
            if source is None or 0 <= source < index:
                continue
            raise ValueError(
                f"Step {index}: data_from_step={source} must "
                f"reference a preceding step index in range 0..{index - 1}"
            )
        # Copy the list so later builder calls cannot alter the returned plan.
        return ExecutionPlan(agent=self._agent, steps=list(self._steps))
|
||||
|
||||
|
||||
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class PlanCache:
    """Bounded in-memory LRU store of ``ExecutionPlan`` objects.

    Entries double as playbooks (see ``get_all_playbooks``) and as a
    memoisation layer, so a repeated identical intent classification can
    reuse a plan instead of rebuilding it.
    """

    def __init__(self, maxsize: int = 1000) -> None:
        self._maxsize = maxsize
        # Ordered oldest -> newest; the front is the eviction candidate.
        self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()

    def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
        """Insert *plan* under *key* as the newest entry.

        An existing *key* is replaced without evicting anything else; a new
        key at capacity evicts the least-recently-used entry first.
        """
        entries = self._cache
        if key in entries:
            # Drop the old slot so the re-insert lands at the newest end.
            entries.pop(key)
        elif len(entries) >= self._maxsize:
            entries.popitem(last=False)
        entries[key] = plan

    def get_plan(self, key: str) -> ExecutionPlan | None:
        """Look up *key*, marking it most-recently used on a hit.

        Returns:
            The cached plan, or ``None`` when *key* is absent.
        """
        entries = self._cache
        if key in entries:
            entries.move_to_end(key)
            return entries[key]
        return None

    def get_all_playbooks(self) -> list[ExecutionPlan]:
        """Return every cached plan, least-recently used first."""
        return [plan for plan in self._cache.values()]
|
||||
|
||||
|
||||
# ── Module-level singletons ───────────────────────────────────────────
|
||||
|
||||
# Module-level instances created once at import; the loader functions in this
# module populate them.
template_registry = PromptTemplateRegistry()
plan_cache = PlanCache()
|
||||
|
||||
|
||||
def _register_builtin_templates() -> None:
    """Load the built-in prompt templates into ``template_registry``.

    The prompt text stays server-side; clients are only given the IDs.
    """
    builtin: tuple[tuple[str, str], ...] = (
        (
            "tpl_task_agent_default",
            "You are a task management assistant. Help the user create, update, "
            "list, and track tasks. Use correct status values (todo, in_progress, "
            "done) and priority values (high, medium, low) from the workspace model.",
        ),
        (
            "tpl_timeline_agent_default",
            "You are a project timeline assistant. Help the user create and manage "
            "milestone timelines on their projects. Every timeline requires a "
            "project_id and a date expressed as a Unix timestamp in milliseconds.",
        ),
        (
            "tpl_project_agent_default",
            "You are a project management assistant. Help the user create, find, "
            "update, and archive projects. Projects have a name, an optional client, "
            "and a status of either active or archived.",
        ),
        (
            "tpl_note_agent_default",
            "You are a note-taking assistant. Help the user create, retrieve, update, "
            "and delete Markdown notes. Notes can optionally be linked to a project.",
        ),
        (
            "tpl_task_extract_from_project",
            "Extract all actionable tasks from the provided project context. "
            "Return a structured list of tasks, each with a title, inferred priority "
            "(high, medium, or low), suggested status (todo), and a due_date in "
            "milliseconds where a deadline can be inferred.",
        ),
        (
            "tpl_note_weekly_summary",
            "Generate a weekly project summary note from the provided workspace data. "
            "Include: tasks completed this week, tasks due soon, active projects, "
            "and upcoming timelines. Format the output as clean Markdown.",
        ),
    )
    # Registered in declaration order.
    for template_id, prompt_text in builtin:
        template_registry.register(template_id, prompt_text)
|
||||
|
||||
|
||||
def _load_playbooks() -> None:
    """Build the built-in playbooks and store them in ``plan_cache``.

    Each playbook is an LLM step followed by a ``create_record`` step that
    consumes the output of step 0.
    """
    # Both plans are built before either is cached.
    tasks_from_project = (
        ExecutionPlanBuilder("project_agent")
        .add_llm_step("tpl_task_extract_from_project", {"source": "project_context"})
        .add_data_step("create_record", data_from_step=0)
        .build()
    )
    weekly_note = (
        ExecutionPlanBuilder("note_agent")
        .add_llm_step("tpl_note_weekly_summary", {"period": "last_7_days"})
        .add_data_step("create_record", data_from_step=0)
        .build()
    )

    plan_cache.cache_plan("create_tasks_from_project", tasks_from_project)
    plan_cache.cache_plan("generate_weekly_note", weekly_note)
|
||||
|
||||
|
||||
# Initialise on module load: importing this module registers the built-in
# templates and then caches the built-in playbooks.
_register_builtin_templates()
_load_playbooks()
|
||||
@@ -1,6 +1,6 @@
|
||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||
|
||||
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
|
||||
Every agent and the deep-agent supervisors call ``get_llm()`` or ``get_router_llm()``
|
||||
instead of directly constructing a provider-specific class. The model string
|
||||
follows the `LiteLLM model naming convention
|
||||
<https://docs.litellm.ai/docs/providers>`_:
|
||||
|
||||
@@ -43,7 +43,7 @@ _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
||||
|
||||
|
||||
class MemoryMiddleware:
|
||||
"""Enrich orchestrator context with memory and persist interactions after."""
|
||||
"""Enrich agent context with memory and persist interactions after."""
|
||||
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self._db = db
|
||||
@@ -51,7 +51,7 @@ class MemoryMiddleware:
|
||||
# ── Public API ────────────────────────────────────────────────────────────
|
||||
|
||||
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
|
||||
"""Build memory context dict to inject into the orchestrator before LLM call.
|
||||
"""Build memory context dict to inject into the agent before LLM call.
|
||||
|
||||
Returns a dict with keys:
|
||||
core_memory — {key: plaintext_value, ...}
|
||||
|
||||
@@ -1,210 +0,0 @@
|
||||
"""Orchestrator — LLM-based intent router and agent pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
||||
from app.core.llm import get_router_llm
|
||||
from app.core.agent_registry import registry as _default_registry
|
||||
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
||||
|
||||
# Agent used when the registry is empty or the classifier names an agent that
# is not registered.
_FALLBACK_AGENT = "task_agent"

# System prompt for intent classification; ``{agents}`` receives the JSON dump
# of the registry's agent list.
_CLASSIFY_SYSTEM = (
    "You are an intent classifier. Given the user message and context, decide "
    "which agent to route to.\n"
    "Available agents: {agents}\n"
    "Respond with just the agent name, nothing else."
)

# Human prompt for the pipeline's final synthesis call; placeholders are
# ``{results}`` (labelled agent outputs) and ``{message}`` (the user text).
_SYNTHESIZE_HUMAN = (
    "Combine the following agent results into one coherent response.\n\n"
    "Agent results:\n{results}\n\n"
    "Original message: {message}"
)
|
||||
|
||||
|
||||
def _make_llm():
    """Return the router LLM; the single place this module obtains its model."""
    router_llm = get_router_llm()
    return router_llm
|
||||
|
||||
|
||||
async def classify_intent(
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry,
) -> str:
    """Ask the router LLM which registered agent should handle *message*.

    Args:
        message: The user's text.
        context: Request context; its JSON dump, cut to 500 characters, is
            included in the classification prompt.
        reg: Registry whose ``list_agents()`` supplies the candidate agents.

    Returns:
        The chosen agent name, or ``task_agent`` when the registry is empty
        or the model's answer does not match a registered name.
    """
    agents = reg.list_agents()
    if not agents:
        return _FALLBACK_AGENT

    system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
    # Truncate context to keep the classification prompt short
    human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}"

    llm = _make_llm()
    response = await llm.ainvoke(
        [SystemMessage(content=system), HumanMessage(content=human)]
    )

    # Strip quotes, backticks and a trailing period around the name; otherwise
    # a correct but decorated answer would be discarded for the fallback.
    agent_name = str(response.content).strip().strip("\"'`.").strip().lower()
    known = {a["name"] for a in agents}
    return agent_name if agent_name in known else _FALLBACK_AGENT
|
||||
|
||||
|
||||
async def route_single(
    agent_name: str,
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry,
) -> ChatResponse:
    """Call one agent through the registry and wrap its text reply.

    Args:
        agent_name: Name passed to ``reg.call_agent``.
        message: The user's text.
        context: Context dict handed to the agent.
        reg: Registry used to dispatch the call.

    Returns:
        A ``ChatResponse`` whose ``response`` is the agent's reply.
    """
    reply = await reg.call_agent(agent_name, message, context)
    wrapped = ChatResponse(response=reply)
    return wrapped
|
||||
|
||||
|
||||
async def route_pipeline(
    agent_names: list[str],
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry,
) -> ChatResponse:
    """Run several agents one after another, then merge their answers.

    Each agent gets *context* plus a ``previous_results`` list holding the
    replies of the agents that ran before it. A final LLM call combines all
    replies into a single response.

    Args:
        agent_names: Agents to call, in order.
        message: The user's text, sent to every agent and to the synthesis.
        context: Base context dict; it is copied per agent, not mutated.
        reg: Registry used to dispatch each call.

    Returns:
        A ``ChatResponse`` carrying the synthesised text.
    """
    outputs: list[str] = []
    for current in agent_names:
        # Snapshot the earlier replies so later appends do not leak into
        # a context that was already handed out.
        step_context = dict(context)
        step_context["previous_results"] = outputs.copy()
        outputs.append(await reg.call_agent(current, message, step_context))

    labelled = [f"[{name}]: {res}" for name, res in zip(agent_names, outputs)]
    human = _SYNTHESIZE_HUMAN.format(results="\n\n".join(labelled), message=message)

    synthesis = await _make_llm().ainvoke([HumanMessage(content=human)])
    return ChatResponse(response=str(synthesis.content))
|
||||
|
||||
|
||||
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
    """Create a one-step ``ExecutionPlan`` for *agent_name*.

    If ``tpl_<agent_name>_default`` is registered, the step is an LLM step
    referring to that template ID; otherwise it is a plain ``handle`` step.
    Either way the step's variables are ``{"message": message}``.
    """
    # Function-scope import: resolved on first call, not at module load.
    from app.core.execution_plan import ExecutionPlanBuilder, template_registry

    builder = ExecutionPlanBuilder(agent_name)
    default_template = f"tpl_{agent_name}_default"
    variables = {"message": message}

    if not template_registry.has(default_template):
        builder.add_step("handle", variables)
    else:
        builder.add_llm_step(default_template, variables)
    return builder.build()
|
||||
|
||||
|
||||
async def orchestrate(
    request: ChatRequest,
    reg: AgentRegistry | None = None,
) -> ChatResponse | ExecutionPlan:
    """Classify the request, then either execute it or describe it as a plan.

    Args:
        request: Incoming chat request; ``request.context`` is dumped to a
            dict and ``request.execution_mode`` selects the outcome.
        reg: Registry to use; the module default when ``None``.

    Returns:
        A ``ChatResponse`` from the selected agent when ``execution_mode`` is
        ``'direct'``; for any other mode an ``ExecutionPlan`` that is returned
        without being executed.
    """
    registry_ = _default_registry if reg is None else reg

    context = request.context.model_dump()
    chosen = await classify_intent(request.message, context, registry_)

    if request.execution_mode != "direct":
        # plan mode — return plan, do not execute
        return _build_plan(chosen, request.message)

    return await route_single(chosen, request.message, context, registry_)
|
||||
|
||||
|
||||
async def orchestrate_v3(
    user_id: str,
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry | None = None,
) -> tuple[str, ChatAgent]:
    """Pick an agent for *message* without running it.

    Args:
        user_id: Accepted for interface compatibility; not read here.
        message: The user's text.
        context: Context dict passed to the classifier.
        reg: Registry to use; the module default when ``None``.

    Returns:
        ``(agent_name, agent_instance)``. Invoking the agent is left to the
        caller.
    """
    registry_ = _default_registry if reg is None else reg
    chosen = await classify_intent(message, context, registry_)
    instance = registry_.get(chosen)
    return chosen, instance
|
||||
|
||||
|
||||
async def orchestrate_v3_stream(
    user_id: str,
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry | None = None,
    agent_holder: list | None = None,
) -> AsyncGenerator[tuple[str, str], None]:
    """Pick an agent and stream its reply as ``(agent_name, token)`` pairs.

    The first pair always has an empty token; it exists only to announce
    which agent was selected before any text arrives.

    Args:
        user_id: Accepted for interface compatibility; not read here.
        message: The user's text.
        context: Context dict passed to the classifier and to the agent.
        reg: Registry to use; the module default when ``None``.
        agent_holder: Optional list; the selected agent instance is appended
            to it so the caller can inspect the agent after streaming.

    Yields:
        ``(agent_name, "")`` once, then ``(agent_name, token)`` for every
        token from the agent's ``handle_stream``.
    """
    registry_ = _default_registry if reg is None else reg
    routed_to = await classify_intent(message, context, registry_)
    instance = registry_.get(routed_to)

    if agent_holder is not None:
        agent_holder.append(instance)

    # domain signal — no token yet
    yield routed_to, ""

    async for piece in instance.handle_stream(message, context):
        yield routed_to, piece
|
||||
|
||||
|
||||
async def orchestrate_stream(
    request: ChatRequest,
    reg: AgentRegistry | None = None,
) -> AsyncGenerator[str, None]:
    """Yield the selected agent's reply as plain-text slices.

    This is not token-level streaming: the complete reply is obtained from
    the agent first and then cut into slices of at most 50 characters.
    Framing the slices for the transport is left to the caller.

    Args:
        request: Incoming chat request; ``request.context`` is dumped to a
            dict for the classifier and the agent.
        reg: Registry to use; the module default when ``None``.

    Yields:
        Consecutive slices of the reply; nothing when the reply is empty.
    """
    registry_ = _default_registry if reg is None else reg

    context = request.context.model_dump()
    chosen = await classify_intent(request.message, context, registry_)
    full_text = await registry_.call_agent(chosen, request.message, context)

    slice_len = 50
    offset = 0
    total = len(full_text)
    while offset < total:
        yield full_text[offset : offset + slice_len]
        offset += slice_len
|
||||
@@ -1,12 +1,23 @@
|
||||
"""Output Formatter — transforms orchestrator token streams into WS frame sequences.
|
||||
"""Output Formatter — transforms deep-agent event streams into WS frame sequences.
|
||||
|
||||
HomeFormatter: produces stream_start, stream_text / stream_block, stream_end
|
||||
FloatingFormatter: produces floating_domain, stream_text, stream_end
|
||||
Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
|
||||
* ``("token", str)`` — supervisor text token
|
||||
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
|
||||
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
|
||||
|
||||
HomeFormatter:
|
||||
* Sniffs ``tool_end`` events → emits ``WsStreamBlock`` (entity_ref with raw data)
|
||||
* Streams text tokens → emits ``WsStreamText``
|
||||
* Attaches mutations → injects into ``WsStreamEnd``
|
||||
|
||||
FloatingFormatter:
|
||||
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
|
||||
* Streams text tokens → emits ``WsStreamText``
|
||||
* Attaches mutations → injects into ``WsStreamEnd``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
@@ -21,10 +32,7 @@ from app.schemas import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Valid chart types (matching shadcn/ui Recharts wrappers in Electron)
|
||||
_VALID_CHART_TYPES = {"area", "bar", "line", "pie", "radar", "radial"}
|
||||
|
||||
# Map agent name → floating domain
|
||||
# Map sub-agent tool name → floating domain / entity type
|
||||
_AGENT_DOMAIN: dict[str, str] = {
|
||||
"task_agent": "tasks",
|
||||
"timeline_agent": "timelines",
|
||||
@@ -36,180 +44,74 @@ WsFrame = WsStreamStart | WsStreamText | WsStreamBlock | WsStreamEnd | WsFloatin
|
||||
|
||||
|
||||
class HomeFormatter:
|
||||
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
|
||||
"""Consumes a deep-agent event stream and yields WS frames for the Home view.
|
||||
|
||||
The LLM is expected to output a newline-delimited sequence of JSON objects,
|
||||
each with a ``type`` field:
|
||||
- ``text`` → yields WsStreamText immediately (word-by-word)
|
||||
- ``chart`` → buffers full JSON, validates, yields WsStreamBlock
|
||||
- ``entity_ref`` → resolves from tool_results, yields WsStreamBlock
|
||||
- ``table`` → buffers full JSON, validates, yields WsStreamBlock
|
||||
- ``timeline`` → buffers full JSON, validates, yields WsStreamBlock
|
||||
|
||||
Invalid or unknown blocks are logged and skipped — stream never crashes.
|
||||
"""
|
||||
|
||||
def __init__(self, request_id: str, tool_results: list[dict]) -> None:
|
||||
self.request_id = request_id
|
||||
self.tool_results = tool_results
|
||||
|
||||
async def format(
|
||||
self,
|
||||
token_stream: AsyncGenerator[tuple[str, str], None],
|
||||
) -> AsyncGenerator[WsFrame, None]:
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
|
||||
buffer = ""
|
||||
async for _agent_name, token in token_stream:
|
||||
if not token:
|
||||
continue
|
||||
buffer += token
|
||||
# Flush any complete JSON objects from the buffer
|
||||
async for frame in self._flush_complete_objects(buffer):
|
||||
buffer = "" # reset after flush
|
||||
yield frame
|
||||
break # only one flush per iteration; rest accumulates
|
||||
|
||||
# Flush any remaining content
|
||||
if buffer.strip():
|
||||
async for frame in self._flush_complete_objects(buffer, final=True):
|
||||
yield frame
|
||||
|
||||
yield WsStreamEnd(request_id=self.request_id)
|
||||
|
||||
async def _flush_complete_objects(
|
||||
self, text: str, final: bool = False
|
||||
) -> AsyncGenerator[WsFrame, None]:
|
||||
"""Try to parse and yield all complete JSON objects from *text*.
|
||||
|
||||
Yields nothing if text is incomplete JSON (unless *final* is True,
|
||||
in which case remaining text is emitted as plain stream_text).
|
||||
"""
|
||||
remaining = text.strip()
|
||||
while remaining:
|
||||
# Fast path: plain text (not JSON)
|
||||
if not remaining.startswith("{"):
|
||||
# Yield as plain text chunk
|
||||
newline_idx = remaining.find("\n")
|
||||
if newline_idx == -1:
|
||||
if final:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
||||
remaining = ""
|
||||
else:
|
||||
return # accumulate more
|
||||
else:
|
||||
line = remaining[:newline_idx].strip()
|
||||
remaining = remaining[newline_idx + 1:].strip()
|
||||
if line:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=line)
|
||||
continue
|
||||
|
||||
# Try to decode a JSON object
|
||||
try:
|
||||
obj, end_idx = _try_parse_json(remaining)
|
||||
except ValueError:
|
||||
if final:
|
||||
# Emit as raw text if we can't parse
|
||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
||||
remaining = ""
|
||||
return
|
||||
|
||||
if obj is None:
|
||||
if final:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=remaining)
|
||||
remaining = ""
|
||||
return # incomplete — need more tokens
|
||||
|
||||
remaining = remaining[end_idx:].strip()
|
||||
block_type = obj.get("type")
|
||||
|
||||
frame = self._dispatch_block(obj, block_type)
|
||||
if frame is not None:
|
||||
yield frame
|
||||
|
||||
def _dispatch_block(self, obj: dict, block_type: str | None) -> WsFrame | None:
|
||||
if block_type == "text":
|
||||
content = obj.get("content", "")
|
||||
if content:
|
||||
return WsStreamText(request_id=self.request_id, chunk=str(content))
|
||||
return None
|
||||
|
||||
if block_type == "chart":
|
||||
chart_type = obj.get("chartType")
|
||||
if chart_type not in _VALID_CHART_TYPES:
|
||||
logger.warning("HomeFormatter: invalid chartType=%r — skipping", chart_type)
|
||||
return None
|
||||
if not isinstance(obj.get("data"), list):
|
||||
logger.warning("HomeFormatter: chart missing data array — skipping")
|
||||
return None
|
||||
return WsStreamBlock(
|
||||
request_id=self.request_id,
|
||||
block_type="chart",
|
||||
data=obj,
|
||||
)
|
||||
|
||||
if block_type == "entity_ref":
|
||||
entity = obj.get("entity")
|
||||
resolved = self._resolve_entity(entity)
|
||||
if resolved is None:
|
||||
logger.warning("HomeFormatter: entity_ref %r not found in tool_results — skipping", entity)
|
||||
return None
|
||||
return WsStreamBlock(
|
||||
request_id=self.request_id,
|
||||
block_type="entity_ref",
|
||||
data={"entity": entity, "items": resolved},
|
||||
)
|
||||
|
||||
if block_type == "table":
|
||||
if not isinstance(obj.get("headers"), list) or not isinstance(obj.get("rows"), list):
|
||||
logger.warning("HomeFormatter: table missing headers/rows — skipping")
|
||||
return None
|
||||
return WsStreamBlock(
|
||||
request_id=self.request_id,
|
||||
block_type="table",
|
||||
data=obj,
|
||||
)
|
||||
|
||||
if block_type == "timeline":
|
||||
if not isinstance(obj.get("timelines"), list):
|
||||
logger.warning("HomeFormatter: timeline missing timelines — skipping")
|
||||
return None
|
||||
return WsStreamBlock(
|
||||
request_id=self.request_id,
|
||||
block_type="timeline",
|
||||
data=obj,
|
||||
)
|
||||
|
||||
logger.warning("HomeFormatter: unknown block type=%r — skipping", block_type)
|
||||
return None
|
||||
|
||||
def _resolve_entity(self, entity: str | None) -> list[dict] | None:
|
||||
"""Find matching items in tool_results by entity type."""
|
||||
if not entity:
|
||||
return None
|
||||
matches = [r for r in self.tool_results if r.get("entity") == entity]
|
||||
return matches if matches else None
|
||||
|
||||
|
||||
class FloatingFormatter:
|
||||
"""Parses a token stream from orchestrate_v3_stream and yields WS frames.
|
||||
|
||||
Emits floating_domain immediately (from agent_name), then streams all tokens
|
||||
as plain stream_text — no block parsing for floating context.
|
||||
``tool_end`` events from sub-agents are emitted as ``WsStreamBlock``
|
||||
(entity_ref) so the client can render structured data. Text tokens are
|
||||
forwarded as ``WsStreamText``. Mutations are attached to ``WsStreamEnd``.
|
||||
"""
|
||||
|
||||
def __init__(self, request_id: str) -> None:
|
||||
self.request_id = request_id
|
||||
self._mutations: list[dict] = []
|
||||
|
||||
async def format(
|
||||
self,
|
||||
token_stream: AsyncGenerator[tuple[str, str], None],
|
||||
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||
) -> AsyncGenerator[WsFrame, None]:
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
|
||||
async for event_type, data in event_stream:
|
||||
if event_type == "token":
|
||||
if data:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||
|
||||
elif event_type == "tool_end":
|
||||
# Sub-agent finished — emit its result as an entity_ref block
|
||||
name = data.get("name", "")
|
||||
entity = _AGENT_DOMAIN.get(name)
|
||||
if entity:
|
||||
yield WsStreamBlock(
|
||||
request_id=self.request_id,
|
||||
block_type="entity_ref",
|
||||
data={"entity": entity, "result": data.get("result", "")},
|
||||
)
|
||||
|
||||
elif event_type == "mutations":
|
||||
self._mutations = data or []
|
||||
|
||||
yield WsStreamEnd(
|
||||
request_id=self.request_id,
|
||||
mutations=[
|
||||
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
||||
for m in self._mutations
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class FloatingFormatter:
|
||||
"""Consumes a deep-agent event stream and yields WS frames for the Floating view.
|
||||
|
||||
Sniffs the first ``tool_end`` event name to derive the domain (e.g.
|
||||
``task_agent`` → ``"tasks"``), then streams text tokens as plain
|
||||
``WsStreamText``. No block parsing for floating context.
|
||||
"""
|
||||
|
||||
def __init__(self, request_id: str) -> None:
|
||||
self.request_id = request_id
|
||||
self._mutations: list[dict] = []
|
||||
|
||||
async def format(
|
||||
self,
|
||||
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||
) -> AsyncGenerator[WsFrame, None]:
|
||||
domain_sent = False
|
||||
|
||||
async for agent_name, token in token_stream:
|
||||
if not domain_sent:
|
||||
domain = _AGENT_DOMAIN.get(agent_name, "tasks")
|
||||
async for event_type, data in event_stream:
|
||||
if event_type == "tool_end" and not domain_sent:
|
||||
# Sniff domain from the first sub-agent that completes
|
||||
name = data.get("name", "")
|
||||
domain = _AGENT_DOMAIN.get(name, "tasks")
|
||||
yield WsFloatingDomain(
|
||||
request_id=self.request_id,
|
||||
domain=domain, # type: ignore[arg-type]
|
||||
@@ -217,28 +119,33 @@ class FloatingFormatter:
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
domain_sent = True
|
||||
|
||||
if token:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=token)
|
||||
elif event_type == "token":
|
||||
if not domain_sent:
|
||||
# First token arrived before any tool_end — default domain
|
||||
yield WsFloatingDomain(
|
||||
request_id=self.request_id,
|
||||
domain="tasks", # type: ignore[arg-type]
|
||||
)
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
domain_sent = True
|
||||
if data:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||
|
||||
yield WsStreamEnd(request_id=self.request_id)
|
||||
elif event_type == "mutations":
|
||||
self._mutations = data or []
|
||||
|
||||
# If no events triggered domain_sent (edge case), still emit structure
|
||||
if not domain_sent:
|
||||
yield WsFloatingDomain(
|
||||
request_id=self.request_id,
|
||||
domain="tasks", # type: ignore[arg-type]
|
||||
)
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
|
||||
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _try_parse_json(text: str) -> tuple[dict[str, Any] | None, int]:
|
||||
"""Attempt to parse the first complete JSON object from *text*.
|
||||
|
||||
Returns ``(parsed_dict, end_index)`` on success, ``(None, 0)`` when the
|
||||
object is incomplete, and raises ``ValueError`` when text is not JSON.
|
||||
"""
|
||||
decoder = json.JSONDecoder()
|
||||
try:
|
||||
obj, end_idx = decoder.raw_decode(text)
|
||||
if not isinstance(obj, dict):
|
||||
raise ValueError("Expected JSON object")
|
||||
return obj, end_idx
|
||||
except json.JSONDecodeError as exc:
|
||||
# Incomplete JSON — need more tokens
|
||||
if "Unterminated" in str(exc) or exc.pos == len(text):
|
||||
return None, 0
|
||||
raise ValueError(str(exc)) from exc
|
||||
yield WsStreamEnd(
|
||||
request_id=self.request_id,
|
||||
mutations=[
|
||||
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
||||
for m in self._mutations
|
||||
],
|
||||
)
|
||||
|
||||
@@ -7,18 +7,21 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Coroutine
|
||||
from uuid import uuid4
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Holds the execute callback for the current WS session.
|
||||
# Set by the chat WS handler before the orchestrator runs; cleared after.
|
||||
# Set by the chat WS handler before the deep agent runs; cleared after.
|
||||
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||
"_client_executor"
|
||||
)
|
||||
|
||||
# Optional collector that captures raw execute_on_client results.
|
||||
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
|
||||
# Set by the deep agent tool loop to capture CRUD mutations.
|
||||
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||
"_tool_result_collector", default=None
|
||||
)
|
||||
@@ -81,7 +84,12 @@ async def execute_on_client(
|
||||
if limit is not None:
|
||||
payload["limit"] = limit
|
||||
|
||||
logger.info("execute_on_client: sending payload action=%s table=%s id=%s", action, table, payload["id"])
|
||||
result = await callback(payload)
|
||||
if result is None:
|
||||
logger.error("execute_on_client: callback returned None for action=%s table=%s id=%s", action, table, payload["id"])
|
||||
else:
|
||||
logger.info("execute_on_client: got result type=%s keys=%s", type(result).__name__, list(result.keys()) if isinstance(result, dict) else "N/A")
|
||||
collector = _tool_result_collector.get(None)
|
||||
if collector is not None:
|
||||
collector.append({
|
||||
|
||||
@@ -18,10 +18,7 @@ from app.config.settings import settings
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup: initialise DB connection pool and agent registry
|
||||
from app.core.agent_registry import registry # noqa: F401 — triggers module load
|
||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
||||
|
||||
# Startup: initialise DB connection pool
|
||||
yield
|
||||
|
||||
# Shutdown: dispose SQLAlchemy connection pool
|
||||
@@ -51,11 +48,10 @@ def create_app() -> FastAPI:
|
||||
app.add_middleware(SanitizerMiddleware)
|
||||
app.add_middleware(TierRateLimitMiddleware)
|
||||
|
||||
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plans, plugins, storage, vectors
|
||||
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
||||
|
||||
app.include_router(auth.router, prefix="/api/v1")
|
||||
app.include_router(chat.router, prefix="/api/v1")
|
||||
app.include_router(plans.router, prefix="/api/v1")
|
||||
app.include_router(storage.router, prefix="/api/v1")
|
||||
app.include_router(vectors.router, prefix="/api/v1")
|
||||
app.include_router(backup.router, prefix="/api/v1")
|
||||
|
||||
@@ -41,41 +41,13 @@ class ChatContext(BaseModel):
|
||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PlanAction(BaseModel):
|
||||
type: Literal[
|
||||
"create_record",
|
||||
"update_record",
|
||||
"delete_record",
|
||||
"index_document",
|
||||
"send_notification",
|
||||
]
|
||||
table: str | None = None
|
||||
data: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
message: str
|
||||
context: ChatContext = Field(default_factory=ChatContext)
|
||||
execution_mode: Literal["direct", "plan"] = "direct"
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
response: str
|
||||
actions: list[PlanAction] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ── Execution Plans ──────────────────────────────────────────────────
|
||||
|
||||
class PlanStep(BaseModel):
|
||||
action: str
|
||||
prompt_template: str | None = None
|
||||
variables: dict[str, Any] | None = None
|
||||
data_from_step: int | None = None
|
||||
|
||||
|
||||
class ExecutionPlan(BaseModel):
|
||||
agent: str
|
||||
steps: list[PlanStep] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ── Backup ───────────────────────────────────────────────────────────
|
||||
|
||||
@@ -4,6 +4,7 @@ gunicorn>=22.0.0
|
||||
langchain>=0.3.0
|
||||
langchain-openai>=0.3.0
|
||||
langchain-litellm>=0.1.0
|
||||
langgraph>=0.3.0
|
||||
litellm>=1.50.0
|
||||
pydantic>=2.10.0
|
||||
pydantic-settings>=2.7.0
|
||||
|
||||
@@ -1,214 +0,0 @@
|
||||
"""Unit tests for the agent registry, base classes, and tool loop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
class _StubAgent(ChatAgent):
|
||||
"""Minimal concrete agent for testing."""
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "stub"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "A stub agent for tests"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
return f"echo: {query}"
|
||||
|
||||
|
||||
class _AnotherAgent(ChatAgent):
|
||||
def get_name(self) -> str:
|
||||
return "another"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "Another stub"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
return "another"
|
||||
|
||||
|
||||
# ── Fixtures ─────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _fresh_registry():
|
||||
"""Reset the singleton between tests."""
|
||||
AgentRegistry._instance = None
|
||||
yield
|
||||
AgentRegistry._instance = None
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def reg() -> AgentRegistry:
|
||||
return AgentRegistry()
|
||||
|
||||
|
||||
# ── Tests ────────────────────────────────────────────────────────────
|
||||
|
||||
class TestRegisterAndGet:
|
||||
def test_register_decorator(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
agent = reg.get("stub")
|
||||
assert isinstance(agent, _StubAgent)
|
||||
|
||||
def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
|
||||
with pytest.raises(KeyError, match="not found"):
|
||||
reg.get("nonexistent")
|
||||
|
||||
def test_register_multiple(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
reg.register(_AnotherAgent)
|
||||
assert reg.get("stub").get_name() == "stub"
|
||||
assert reg.get("another").get_name() == "another"
|
||||
|
||||
|
||||
class TestListAgents:
|
||||
def test_empty(self, reg: AgentRegistry) -> None:
|
||||
assert reg.list_agents() == []
|
||||
|
||||
def test_list_after_register(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
agents = reg.list_agents()
|
||||
assert len(agents) == 1
|
||||
assert agents[0] == {"name": "stub", "description": "A stub agent for tests"}
|
||||
|
||||
def test_list_multiple(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
reg.register(_AnotherAgent)
|
||||
names = {a["name"] for a in reg.list_agents()}
|
||||
assert names == {"stub", "another"}
|
||||
|
||||
|
||||
class TestCallAgent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_agent(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
result = await reg.call_agent("stub", "hello", {})
|
||||
assert result == "echo: hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
|
||||
with pytest.raises(KeyError):
|
||||
await reg.call_agent("nope", "hi", {})
|
||||
|
||||
|
||||
class TestSingleton:
|
||||
def test_singleton_identity(self) -> None:
|
||||
a = AgentRegistry()
|
||||
b = AgentRegistry()
|
||||
assert a is b
|
||||
|
||||
|
||||
class TestToolLoop:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tool_calls(self) -> None:
|
||||
"""When the LLM responds without tool calls, return content directly."""
|
||||
agent = _StubAgent()
|
||||
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.content = "final answer"
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm)
|
||||
llm.ainvoke = AsyncMock(return_value=ai_msg)
|
||||
|
||||
result = await agent._tool_loop(llm, [], [])
|
||||
assert result == "final answer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_then_answer(self) -> None:
|
||||
"""LLM requests one tool call, gets result, then answers."""
|
||||
agent = _StubAgent()
|
||||
|
||||
# First response: tool call
|
||||
tool_call_msg = MagicMock()
|
||||
tool_call_msg.content = ""
|
||||
tool_call_msg.tool_calls = [
|
||||
{"id": "call_1", "name": "my_tool", "args": {"x": 1}}
|
||||
]
|
||||
|
||||
# Second response: final answer
|
||||
final_msg = MagicMock()
|
||||
final_msg.content = "done"
|
||||
final_msg.tool_calls = []
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm)
|
||||
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||
|
||||
# Mock tool
|
||||
tool = AsyncMock()
|
||||
tool.name = "my_tool"
|
||||
tool.ainvoke = AsyncMock(return_value="tool_result")
|
||||
|
||||
result = await agent._tool_loop(llm, [], [tool])
|
||||
assert result == "done"
|
||||
tool.ainvoke.assert_called_once_with({"x": 1})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_tool_handled(self) -> None:
|
||||
"""Unknown tool names produce an error message instead of crashing."""
|
||||
agent = _StubAgent()
|
||||
|
||||
tool_call_msg = MagicMock()
|
||||
tool_call_msg.content = ""
|
||||
tool_call_msg.tool_calls = [
|
||||
{"id": "call_1", "name": "missing", "args": {}}
|
||||
]
|
||||
|
||||
final_msg = MagicMock()
|
||||
final_msg.content = "recovered"
|
||||
final_msg.tool_calls = []
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm)
|
||||
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||
|
||||
result = await agent._tool_loop(llm, [], [])
|
||||
assert result == "recovered"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_iter_reached(self) -> None:
|
||||
"""When max iterations are exhausted, a final no-tools call is made."""
|
||||
agent = _StubAgent()
|
||||
|
||||
# Every response requests a tool call
|
||||
loop_msg = MagicMock()
|
||||
loop_msg.content = ""
|
||||
loop_msg.tool_calls = [
|
||||
{"id": "call_x", "name": "t", "args": {}}
|
||||
]
|
||||
|
||||
final_msg = MagicMock()
|
||||
final_msg.content = "gave up"
|
||||
final_msg.tool_calls = []
|
||||
|
||||
tool = AsyncMock()
|
||||
tool.name = "t"
|
||||
tool.ainvoke = AsyncMock(return_value="ok")
|
||||
|
||||
llm_with_tools = AsyncMock()
|
||||
llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg)
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
||||
|
||||
result = await agent._tool_loop(llm, [], [tool], max_iter=2)
|
||||
assert result == "gave up"
|
||||
assert llm_with_tools.ainvoke.call_count == 2
|
||||
@@ -1,416 +0,0 @@
|
||||
"""Tests for ChatAgent streaming and tool result capture (Step 2)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
from app.core.agent_registry import ChatAgent, registry
|
||||
|
||||
|
||||
# ── Minimal concrete agent for testing ───────────────────────────────
|
||||
|
||||
|
||||
class _EchoAgent(ChatAgent):
|
||||
def get_name(self) -> str:
|
||||
return "_echo"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "Echo agent for tests"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
return query
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_ai_message(content: str = "", tool_calls: list | None = None) -> AIMessage:
|
||||
msg = AIMessage(content=content)
|
||||
if tool_calls:
|
||||
msg.tool_calls = tool_calls
|
||||
else:
|
||||
msg.tool_calls = []
|
||||
return msg
|
||||
|
||||
|
||||
def _make_tool(name: str, return_value: Any) -> MagicMock:
|
||||
t = MagicMock()
|
||||
t.name = name
|
||||
t.ainvoke = AsyncMock(return_value=return_value)
|
||||
return t
|
||||
|
||||
|
||||
def _make_stream_chunks(tokens: list[str]) -> list[MagicMock]:
|
||||
chunks = []
|
||||
for tok in tokens:
|
||||
c = MagicMock()
|
||||
c.content = tok
|
||||
chunks.append(c)
|
||||
return chunks
|
||||
|
||||
|
||||
async def _collect_stream(agent: ChatAgent, llm: Any, messages: list, tools: list) -> list[str]:
|
||||
tokens: list[str] = []
|
||||
async for tok in agent._tool_loop_stream(llm, messages, tools):
|
||||
tokens.append(tok)
|
||||
return tokens
|
||||
|
||||
|
||||
# ── tool_results initialised ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tool_results_init():
|
||||
agent = _EchoAgent()
|
||||
assert agent.tool_results == []
|
||||
|
||||
|
||||
# ── _tool_loop: no tool calls ────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_loop_no_tools():
|
||||
agent = _EchoAgent()
|
||||
llm = AsyncMock()
|
||||
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Hello!"))
|
||||
|
||||
result = await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
|
||||
assert result == "Hello!"
|
||||
assert agent.tool_results == []
|
||||
|
||||
|
||||
# ── _tool_loop: with one tool call + result capture ──────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_loop_captures_tool_results():
|
||||
agent = _EchoAgent()
|
||||
|
||||
# Mock execute_on_client to return structured data via the tool
|
||||
raw_result = {"rows": [{"id": "t-1", "title": "Fix bug", "status": "todo"}]}
|
||||
|
||||
async def fake_executor(payload: dict) -> dict:
|
||||
return raw_result
|
||||
|
||||
# AIMessage with a tool call, then a final answer
|
||||
tool_call_msg = _make_ai_message(
|
||||
tool_calls=[{"name": "list_tasks", "args": {}, "id": "call-1", "type": "tool_call"}]
|
||||
)
|
||||
final_msg = _make_ai_message("Here are your tasks.")
|
||||
|
||||
llm = MagicMock()
|
||||
llm_with_tools = MagicMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
||||
|
||||
mock_tool = _make_tool("list_tasks", "- Fix bug (todo)")
|
||||
|
||||
from app.core.ws_context import set_client_executor, clear_client_executor
|
||||
set_client_executor(fake_executor)
|
||||
try:
|
||||
# Patch the tool to actually call execute_on_client
|
||||
async def tool_side_effect(args: dict) -> str:
|
||||
from app.core.ws_context import execute_on_client
|
||||
res = await execute_on_client(action="select", table="tasks")
|
||||
rows = res.get("rows", [])
|
||||
return "\n".join(r["title"] for r in rows)
|
||||
|
||||
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
||||
|
||||
result = await agent._tool_loop(
|
||||
llm, [HumanMessage(content="list my tasks")], [mock_tool]
|
||||
)
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
assert result == "Here are your tasks."
|
||||
assert len(agent.tool_results) == 1
|
||||
assert agent.tool_results[0] == raw_result
|
||||
|
||||
|
||||
# ── _tool_loop: tool_results reset on each call ──────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_loop_resets_tool_results():
|
||||
agent = _EchoAgent()
|
||||
agent.tool_results = [{"stale": True}] # pre-populated from a previous call
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.ainvoke = AsyncMock(return_value=_make_ai_message("Done."))
|
||||
|
||||
await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
|
||||
assert agent.tool_results == []
|
||||
|
||||
|
||||
# ── _tool_loop: unknown tool name ────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_loop_unknown_tool():
|
||||
agent = _EchoAgent()
|
||||
|
||||
# No known tools — model still calls a non-existent one; loop handles gracefully
|
||||
tool_call_msg = _make_ai_message(
|
||||
tool_calls=[{"name": "nonexistent", "args": {}, "id": "c1", "type": "tool_call"}]
|
||||
)
|
||||
final_msg = _make_ai_message("Handled.")
|
||||
|
||||
mock_tool = _make_tool("known", "ok") # a different tool, not "nonexistent"
|
||||
llm = MagicMock()
|
||||
llm_with_tools = MagicMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||
|
||||
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool])
|
||||
assert result == "Handled."
|
||||
|
||||
|
||||
# ── _tool_loop: max_iter exhaustion ──────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_loop_max_iter():
|
||||
agent = _EchoAgent()
|
||||
|
||||
always_tool = _make_ai_message(
|
||||
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
|
||||
)
|
||||
fallback = _make_ai_message("Fallback.")
|
||||
|
||||
llm = MagicMock()
|
||||
llm_with_tools = MagicMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||
# Returns tool_call_msg on every iteration
|
||||
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
|
||||
llm.ainvoke = AsyncMock(return_value=fallback)
|
||||
|
||||
mock_tool = _make_tool("t", "ok")
|
||||
|
||||
result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool], max_iter=2)
|
||||
assert result == "Fallback."
|
||||
assert llm_with_tools.ainvoke.call_count == 2
|
||||
|
||||
|
||||
# ── _tool_loop_stream: no tool calls — yields tokens ─────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_loop_stream_no_tools_yields_tokens():
|
||||
agent = _EchoAgent()
|
||||
|
||||
# No tools → llm used directly; ainvoke returns no tool calls → stream is used
|
||||
no_tool_msg = _make_ai_message("irrelevant")
|
||||
llm = AsyncMock()
|
||||
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
||||
|
||||
async def fake_astream(msgs):
|
||||
for tok in ["Hello", " ", "world"]:
|
||||
c = MagicMock()
|
||||
c.content = tok
|
||||
yield c
|
||||
|
||||
llm.astream = fake_astream
|
||||
|
||||
tokens = await _collect_stream(agent, llm, [HumanMessage(content="hi")], [])
|
||||
assert tokens == ["Hello", " ", "world"]
|
||||
assert agent.tool_results == []
|
||||
|
||||
|
||||
# ── _tool_loop_stream: one tool call then streaming final ─────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_loop_stream_with_tool_call():
|
||||
agent = _EchoAgent()
|
||||
|
||||
raw_result = {"row": {"id": "t-2", "title": "Deploy", "status": "in_progress"}}
|
||||
|
||||
async def fake_executor(payload: dict) -> dict:
|
||||
return raw_result
|
||||
|
||||
tool_call_msg = _make_ai_message(
|
||||
tool_calls=[{"name": "get_task", "args": {"id": "t-2"}, "id": "c1", "type": "tool_call"}]
|
||||
)
|
||||
# After tools run, ainvoke returns no more tool calls
|
||||
no_more_tools_msg = _make_ai_message("Task found.")
|
||||
|
||||
llm = MagicMock()
|
||||
llm_with_tools = MagicMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||
llm_with_tools.ainvoke = AsyncMock(side_effect=[tool_call_msg, no_more_tools_msg])
|
||||
|
||||
async def fake_astream(msgs):
|
||||
for tok in ["Task", " ", "found."]:
|
||||
c = MagicMock()
|
||||
c.content = tok
|
||||
yield c
|
||||
|
||||
llm.astream = fake_astream
|
||||
|
||||
async def tool_side_effect(args: dict) -> str:
|
||||
from app.core.ws_context import execute_on_client
|
||||
res = await execute_on_client(action="select", table="tasks", filters={"id": args.get("id")})
|
||||
return res.get("row", {}).get("title", "")
|
||||
|
||||
mock_tool = _make_tool("get_task", "Deploy")
|
||||
mock_tool.ainvoke = AsyncMock(side_effect=tool_side_effect)
|
||||
|
||||
from app.core.ws_context import set_client_executor, clear_client_executor
|
||||
set_client_executor(fake_executor)
|
||||
try:
|
||||
tokens = await _collect_stream(
|
||||
agent, llm, [HumanMessage(content="get task t-2")], [mock_tool]
|
||||
)
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
assert tokens == ["Task", " ", "found."]
|
||||
assert len(agent.tool_results) == 1
|
||||
assert agent.tool_results[0] == raw_result
|
||||
|
||||
|
||||
# ── _tool_loop_stream: tool_results reset on each call ───────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_loop_stream_resets_tool_results():
|
||||
agent = _EchoAgent()
|
||||
agent.tool_results = [{"old": True}]
|
||||
|
||||
no_tool_msg = _make_ai_message("")
|
||||
llm = AsyncMock()
|
||||
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
||||
|
||||
async def fake_astream(msgs):
|
||||
c = MagicMock()
|
||||
c.content = "ok"
|
||||
yield c
|
||||
|
||||
llm.astream = fake_astream
|
||||
|
||||
await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
|
||||
assert agent.tool_results == []
|
||||
|
||||
|
||||
# ── _tool_loop_stream: empty chunk content is skipped ────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_loop_stream_skips_empty_chunks():
|
||||
agent = _EchoAgent()
|
||||
no_tool_msg = _make_ai_message("")
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.ainvoke = AsyncMock(return_value=no_tool_msg)
|
||||
|
||||
async def fake_astream(msgs):
|
||||
for tok in ["", "hello", "", " world", ""]:
|
||||
c = MagicMock()
|
||||
c.content = tok
|
||||
yield c
|
||||
|
||||
llm.astream = fake_astream
|
||||
|
||||
tokens = await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
|
||||
assert tokens == ["hello", " world"]
|
||||
|
||||
|
||||
# ── _tool_loop_stream: max_iter exhaustion falls back to stream ───────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_loop_stream_max_iter():
|
||||
agent = _EchoAgent()
|
||||
|
||||
always_tool = _make_ai_message(
|
||||
tool_calls=[{"name": "t", "args": {}, "id": "c1", "type": "tool_call"}]
|
||||
)
|
||||
|
||||
llm = MagicMock()
|
||||
llm_with_tools = MagicMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||
llm_with_tools.ainvoke = AsyncMock(return_value=always_tool)
|
||||
|
||||
async def fake_astream(msgs):
|
||||
c = MagicMock()
|
||||
c.content = "fallback"
|
||||
yield c
|
||||
|
||||
llm.astream = fake_astream
|
||||
mock_tool = _make_tool("t", "ok")
|
||||
|
||||
tokens = await _collect_stream(
|
||||
agent, llm, [HumanMessage(content="x")], [mock_tool],
|
||||
)
|
||||
assert tokens == ["fallback"]
|
||||
assert llm_with_tools.ainvoke.call_count == 5 # exhausted default max_iter
|
||||
|
||||
|
||||
# ── _tool_loop_stream: multiple tool results captured ────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_loop_stream_multiple_tool_results():
    """Both client results are recorded on the agent when one LLM turn issues two tool calls."""
    from app.core.ws_context import set_client_executor, clear_client_executor

    agent = _EchoAgent()

    # The fake client executor hands out one canned result per call, in order.
    pending_results = iter(
        [
            {"rows": [{"id": "t-1"}]},
            {"rows": [{"id": "t-2"}]},
        ]
    )

    async def fake_executor(payload: dict) -> dict:
        return next(pending_results)

    # Two tool calls in one iteration, followed by a plain final answer.
    first_turn = _make_ai_message(
        tool_calls=[
            {"name": "tool_a", "args": {}, "id": "c1", "type": "tool_call"},
            {"name": "tool_b", "args": {}, "id": "c2", "type": "tool_call"},
        ]
    )
    final_turn = _make_ai_message("Done.")

    bound_llm = MagicMock()
    bound_llm.ainvoke = AsyncMock(side_effect=[first_turn, final_turn])
    llm = MagicMock()
    llm.bind_tools = MagicMock(return_value=bound_llm)

    async def fake_astream(msgs):
        yield MagicMock(content="Done.")

    llm.astream = fake_astream

    async def tool_side_effect(args: dict) -> str:
        # Each tool goes through the installed client executor.
        from app.core.ws_context import execute_on_client

        client_result = await execute_on_client(action="select", table="tasks")
        return str(client_result)

    tool_a = _make_tool("tool_a", "")
    tool_a.ainvoke = AsyncMock(side_effect=tool_side_effect)
    tool_b = _make_tool("tool_b", "")
    tool_b.ainvoke = AsyncMock(side_effect=tool_side_effect)

    set_client_executor(fake_executor)
    try:
        tokens = await _collect_stream(
            agent, llm, [HumanMessage(content="x")], [tool_a, tool_b]
        )
    finally:
        # Always remove the executor so later tests start clean.
        clear_client_executor()

    assert tokens == ["Done."]
    assert len(agent.tool_results) == 2
    assert agent.tool_results[0] == {"rows": [{"id": "t-1"}]}
    assert agent.tool_results[1] == {"rows": [{"id": "t-2"}]}
|
||||
@@ -1,761 +0,0 @@
|
||||
"""Unit tests for the four domain-specific chat agents with mocked LLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
||||
from app.agents.timeline_agent import TimelineAgent
|
||||
from app.agents.note_agent import NoteAgent
|
||||
from app.agents.project_agent import ProjectAgent
|
||||
from app.agents.task_agent import TaskAgent
|
||||
from app.core.agent_registry import registry
|
||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||
|
||||
|
||||
# ── WS executor mock ──────────────────────────────────────────────────
|
||||
#
|
||||
# Tools call execute_on_client() which reads a ContextVar set by the WS
|
||||
# handler. In unit tests there is no WS session, so we install a fake
|
||||
# executor that returns plausible data for each action type.
|
||||
|
||||
# Baseline row returned by the fake executor below; the insert/update/get
# branches overlay request-specific values on top of these defaults.
_FAKE_ROW: dict[str, Any] = {
    "id": "fake-id",
    "title": "Fake Title",
    "name": "Fake Name",
    "status": "todo",
    "priority": "medium",
    "content": "Fake content",
    "date": 1700000000000,
    "taskId": "fake-task-id",
    "author": "Alice",
    "projectId": None,
}
|
||||
|
||||
|
||||
async def _fake_executor(payload: dict) -> dict:
|
||||
action = payload.get("action", "")
|
||||
if action == "select":
|
||||
return {"rows": []}
|
||||
if action == "insert":
|
||||
data = payload.get("data", {})
|
||||
return {"row": {**_FAKE_ROW, **data}}
|
||||
if action == "update":
|
||||
data = payload.get("data", {})
|
||||
row = {**_FAKE_ROW, "id": data.get("id", "fake-id"), **data.get("updates", {})}
|
||||
return {"row": row}
|
||||
if action == "delete":
|
||||
return {"deleted": True}
|
||||
if action == "get":
|
||||
data = payload.get("data", {})
|
||||
return {"row": {**_FAKE_ROW, "id": data.get("id", "fake-id")}}
|
||||
if action == "vector_upsert":
|
||||
return {"ok": True}
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def ws_executor():
    """Autouse fixture: tools run against ``_fake_executor`` instead of a real WS session."""
    # Setup: register the fake executor before the test body runs.
    set_client_executor(_fake_executor)
    yield None
    # Teardown: remove it again once the test has finished.
    clear_client_executor()
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_llm(response_text: str) -> MagicMock:
|
||||
"""Return a mock LLM that responds with *response_text* (no tool calls)."""
|
||||
msg = MagicMock()
|
||||
msg.content = response_text
|
||||
msg.tool_calls = []
|
||||
llm = MagicMock()
|
||||
bound = MagicMock()
|
||||
bound.ainvoke = AsyncMock(return_value=msg)
|
||||
llm.bind_tools = MagicMock(return_value=bound)
|
||||
llm.ainvoke = AsyncMock(return_value=msg)
|
||||
return llm
|
||||
|
||||
|
||||
def _mock_llm_with_tool_call(
|
||||
tool_name: str, tool_args: dict[str, Any], final_text: str
|
||||
) -> MagicMock:
|
||||
"""Mock LLM that fires one tool call then returns *final_text*."""
|
||||
tool_msg = MagicMock()
|
||||
tool_msg.content = ""
|
||||
tool_msg.tool_calls = [{"id": "call_1", "name": tool_name, "args": tool_args}]
|
||||
|
||||
final_msg = MagicMock()
|
||||
final_msg.content = final_text
|
||||
final_msg.tool_calls = []
|
||||
|
||||
bound = MagicMock()
|
||||
bound.ainvoke = AsyncMock(side_effect=[tool_msg, final_msg])
|
||||
|
||||
llm = MagicMock()
|
||||
llm.bind_tools = MagicMock(return_value=bound)
|
||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
||||
return llm
|
||||
|
||||
|
||||
# ── Registration ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAgentRegistration:
    """Checks on the contents of the shared agent ``registry``."""

    def test_all_agents_registered(self) -> None:
        """The four domain agent names are all listed by the registry."""
        expected = {"task_agent", "timeline_agent", "project_agent", "note_agent"}
        registered = {entry["name"] for entry in registry.list_agents()}
        assert expected <= registered

    def test_registry_returns_correct_types(self) -> None:
        """Looking an agent up by name yields an instance of the matching class."""
        expected_types = (
            ("task_agent", TaskAgent),
            ("timeline_agent", TimelineAgent),
            ("project_agent", ProjectAgent),
            ("note_agent", NoteAgent),
        )
        for agent_name, agent_cls in expected_types:
            assert isinstance(registry.get(agent_name), agent_cls)

    def test_descriptions_present(self) -> None:
        """Every listed agent carries a non-empty description."""
        for entry in registry.list_agents():
            assert entry["description"], f"Empty description: {entry['name']}"
|
||||
|
||||
|
||||
# ── TaskAgent ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestTaskAgent:
    """Metadata and ``handle`` checks for ``TaskAgent`` with ``get_llm`` patched."""

    def test_name(self) -> None:
        """The agent reports its registry name."""
        agent = TaskAgent()
        assert agent.get_name() == "task_agent"

    def test_description(self) -> None:
        """The agent reports its exact description string."""
        expected = "Manages tasks and comments: list, create, update, delete, due-today, comments"
        assert TaskAgent().get_description() == expected

    def test_get_tools_count(self) -> None:
        """The agent exposes eight tools."""
        tools = TaskAgent().get_tools()
        assert len(tools) == 8

    def test_tool_names(self) -> None:
        """The exposed tool names match the expected set exactly."""
        expected = {
            "list_tasks",
            "create_task",
            "update_task",
            "delete_task",
            "list_tasks_due_today",
            "list_task_comments",
            "add_task_comment",
            "delete_task_comment",
        }
        assert {tool.name for tool in TaskAgent().get_tools()} == expected

    @pytest.mark.asyncio
    async def test_handle_returns_string(self) -> None:
        """``handle`` returns a string."""
        llm = _mock_llm("Task created.")
        with patch("app.agents.task_agent.get_llm", return_value=llm):
            reply = await TaskAgent().handle("create a task", {})
        assert isinstance(reply, str)

    @pytest.mark.asyncio
    async def test_handle_no_tool_calls(self) -> None:
        """Without tool calls the LLM text is returned as-is."""
        llm = _mock_llm("Here are your tasks.")
        with patch("app.agents.task_agent.get_llm", return_value=llm):
            reply = await TaskAgent().handle("list my tasks", {})
        assert reply == "Here are your tasks."

    @pytest.mark.asyncio
    async def test_handle_with_create_task_tool_call(self) -> None:
        """After one ``create_task`` tool call the final LLM text is returned."""
        llm = _mock_llm_with_tool_call(
            "create_task",
            {"title": "Buy groceries", "priority": "low"},
            "Task 'Buy groceries' created.",
        )
        with patch("app.agents.task_agent.get_llm", return_value=llm):
            reply = await TaskAgent().handle("add a grocery task", {})
        assert reply == "Task 'Buy groceries' created."

    @pytest.mark.asyncio
    async def test_handle_accepts_empty_context(self) -> None:
        """``handle`` works with an empty context dict."""
        llm = _mock_llm("Done.")
        with patch("app.agents.task_agent.get_llm", return_value=llm):
            reply = await TaskAgent().handle("help", {})
        assert isinstance(reply, str)

    @pytest.mark.asyncio
    async def test_handle_accepts_rich_context(self) -> None:
        """``handle`` works with a populated context dict."""
        context = {
            "user_profile": {"id": "u1", "tier": "pro"},
            "recent_tasks": [{"id": "t1", "title": "Old task"}],
        }
        llm = _mock_llm("Tasks listed.")
        with patch("app.agents.task_agent.get_llm", return_value=llm):
            reply = await TaskAgent().handle("show tasks", context)
        assert isinstance(reply, str)
|
||||
|
||||
|
||||
class TestTaskAgentTools:
    """Checks the payload each task tool sends through ``execute_on_client``."""

    @pytest.mark.asyncio
    async def test_list_tasks_defaults(self) -> None:
        """No arguments: a select on ``tasks`` with every filter ``None``."""
        from app.agents.task_agent import list_tasks

        with patch(
            "app.agents.task_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"rows": []},
        ) as client_call:
            outcome = await list_tasks.ainvoke({})

        client_call.assert_called_once_with(
            action="select",
            table="tasks",
            filters={"projectId": None, "status": None, "search": None, "orderBy": None},
        )
        assert outcome == "No tasks found matching the given filters."

    @pytest.mark.asyncio
    async def test_list_tasks_with_status_filter(self) -> None:
        """A ``status`` argument is forwarded in the filters."""
        from app.agents.task_agent import list_tasks

        with patch(
            "app.agents.task_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"rows": []},
        ) as client_call:
            await list_tasks.ainvoke({"status": "done"})

        sent = client_call.call_args.kwargs
        assert sent["filters"]["status"] == "done"

    @pytest.mark.asyncio
    async def test_create_task_defaults(self) -> None:
        """Title only: inserts with status ``todo`` and priority ``medium``."""
        from app.agents.task_agent import create_task

        row = {"id": "t1", "title": "Test task", "status": "todo", "priority": "medium"}
        with patch(
            "app.agents.task_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"row": row},
        ) as client_call:
            outcome = await create_task.ainvoke({"title": "Test task"})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "insert"
        assert sent["table"] == "tasks"
        assert sent["data"]["title"] == "Test task"
        assert sent["data"]["status"] == "todo"
        assert sent["data"]["priority"] == "medium"
        assert "Test task" in outcome

    @pytest.mark.asyncio
    async def test_create_task_with_all_fields(self) -> None:
        """Explicit arguments are mapped onto the camelCase data keys."""
        from app.agents.task_agent import create_task

        row = {"id": "t1", "title": "Deploy", "status": "in_progress", "priority": "high"}
        tool_input = {
            "title": "Deploy",
            "priority": "high",
            "status": "in_progress",
            "project_id": "p1",
            "is_ai_suggested": 1,
        }
        with patch(
            "app.agents.task_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"row": row},
        ) as client_call:
            await create_task.ainvoke(tool_input)

        data = client_call.call_args.kwargs["data"]
        assert data["priority"] == "high"
        assert data["status"] == "in_progress"
        assert data["projectId"] == "p1"
        assert data["isAiSuggested"] == 1

    @pytest.mark.asyncio
    async def test_update_task_with_status(self) -> None:
        """A status change is sent as an update keyed by the task id."""
        from app.agents.task_agent import update_task

        row = {"id": "t1", "title": "Buy groceries", "status": "done"}
        with patch(
            "app.agents.task_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"row": row},
        ) as client_call:
            outcome = await update_task.ainvoke({"task_id": "t1", "status": "done"})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "update"
        assert sent["data"]["id"] == "t1"
        assert sent["data"]["updates"]["status"] == "done"
        assert "t1" in outcome

    @pytest.mark.asyncio
    async def test_update_task_empty_updates(self) -> None:
        """Only a task id: the updates mapping is empty."""
        from app.agents.task_agent import update_task

        row = {"id": "t1", "title": "Task", "status": "todo"}
        with patch(
            "app.agents.task_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"row": row},
        ) as client_call:
            await update_task.ainvoke({"task_id": "t1"})

        sent = client_call.call_args.kwargs
        assert sent["data"]["updates"] == {}

    @pytest.mark.asyncio
    async def test_delete_task(self) -> None:
        """Deleting sends a delete on ``tasks`` for the given id."""
        from app.agents.task_agent import delete_task

        with patch(
            "app.agents.task_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"deleted": True},
        ) as client_call:
            outcome = await delete_task.ainvoke({"task_id": "t1"})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "delete"
        assert sent["table"] == "tasks"
        assert sent["data"]["id"] == "t1"
        assert "t1" in outcome

    @pytest.mark.asyncio
    async def test_list_tasks_due_today(self) -> None:
        """The due-today query selects from ``tasks`` with a ``dueDateFrom`` filter."""
        from app.agents.task_agent import list_tasks_due_today

        with patch(
            "app.agents.task_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"rows": []},
        ) as client_call:
            outcome = await list_tasks_due_today.ainvoke({})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "select"
        assert sent["table"] == "tasks"
        assert "dueDateFrom" in sent["filters"]
        assert outcome == "No tasks are due today."

    @pytest.mark.asyncio
    async def test_list_task_comments(self) -> None:
        """Comments are selected from ``taskComments`` filtered by task id."""
        from app.agents.task_agent import list_task_comments

        with patch(
            "app.agents.task_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"rows": []},
        ) as client_call:
            outcome = await list_task_comments.ainvoke({"task_id": "t1"})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "select"
        assert sent["table"] == "taskComments"
        assert sent["filters"]["taskId"] == "t1"
        assert "t1" in outcome

    @pytest.mark.asyncio
    async def test_add_task_comment(self) -> None:
        """Adding a comment inserts task id, author and content into ``taskComments``."""
        from app.agents.task_agent import add_task_comment

        row = {"id": "c1", "taskId": "t1", "author": "Alice", "content": "Looks good!"}
        tool_input = {"task_id": "t1", "author": "Alice", "content": "Looks good!"}
        with patch(
            "app.agents.task_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"row": row},
        ) as client_call:
            outcome = await add_task_comment.ainvoke(tool_input)

        sent = client_call.call_args.kwargs
        assert sent["action"] == "insert"
        assert sent["table"] == "taskComments"
        assert sent["data"]["taskId"] == "t1"
        assert sent["data"]["author"] == "Alice"
        assert sent["data"]["content"] == "Looks good!"
        assert "Alice" in outcome

    @pytest.mark.asyncio
    async def test_delete_task_comment(self) -> None:
        """Deleting a comment sends a delete on ``taskComments`` for the given id."""
        from app.agents.task_agent import delete_task_comment

        with patch(
            "app.agents.task_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"deleted": True},
        ) as client_call:
            outcome = await delete_task_comment.ainvoke({"comment_id": "c1"})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "delete"
        assert sent["table"] == "taskComments"
        assert sent["data"]["id"] == "c1"
        assert "c1" in outcome
|
||||
|
||||
|
||||
# ── TimelineAgent ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestTimelineAgent:
    """Metadata and ``handle`` checks for ``TimelineAgent`` with ``get_llm`` patched."""

    def test_name(self) -> None:
        """The agent reports its registry name."""
        agent = TimelineAgent()
        assert agent.get_name() == "timeline_agent"

    def test_description(self) -> None:
        """The agent reports its exact description string."""
        expected = "Manages project timelines (milestones): list, create, update, delete"
        assert TimelineAgent().get_description() == expected

    def test_get_tools_count(self) -> None:
        """The agent exposes four tools."""
        tools = TimelineAgent().get_tools()
        assert len(tools) == 4

    def test_tool_names(self) -> None:
        """The exposed tool names match the expected set exactly."""
        expected = {"list_timelines", "create_timeline", "update_timeline", "delete_timeline"}
        assert {tool.name for tool in TimelineAgent().get_tools()} == expected

    @pytest.mark.asyncio
    async def test_handle_no_tool_calls(self) -> None:
        """Without tool calls the LLM text is returned as-is."""
        llm = _mock_llm("No timelines found.")
        with patch("app.agents.timeline_agent.get_llm", return_value=llm):
            reply = await TimelineAgent().handle("list timelines", {})
        assert reply == "No timelines found."

    @pytest.mark.asyncio
    async def test_handle_with_create_tool_call(self) -> None:
        """After one ``create_timeline`` tool call the final LLM text is returned."""
        llm = _mock_llm_with_tool_call(
            "create_timeline",
            {"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
            "Timeline 'MVP Launch' created.",
        )
        with patch("app.agents.timeline_agent.get_llm", return_value=llm):
            reply = await TimelineAgent().handle("add MVP timeline", {})
        assert reply == "Timeline 'MVP Launch' created."

    @pytest.mark.asyncio
    async def test_handle_accepts_empty_context(self) -> None:
        """``handle`` works with an empty context dict."""
        llm = _mock_llm("Done.")
        with patch("app.agents.timeline_agent.get_llm", return_value=llm):
            reply = await TimelineAgent().handle("show milestones", {})
        assert isinstance(reply, str)
|
||||
|
||||
|
||||
class TestTimelineAgentTools:
    """Checks the payload each timeline tool sends through ``execute_on_client``."""

    @pytest.mark.asyncio
    async def test_list_timelines_no_project(self) -> None:
        """No project id: a select on ``timelines`` with ``projectId`` set to ``None``."""
        from app.agents.timeline_agent import list_timelines

        with patch(
            "app.agents.timeline_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"rows": []},
        ) as client_call:
            outcome = await list_timelines.ainvoke({})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "select"
        assert sent["table"] == "timelines"
        assert sent["filters"]["projectId"] is None
        assert outcome == "No timelines found."

    @pytest.mark.asyncio
    async def test_list_timelines_with_project(self) -> None:
        """A project id is forwarded in the filters."""
        from app.agents.timeline_agent import list_timelines

        with patch(
            "app.agents.timeline_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"rows": []},
        ) as client_call:
            await list_timelines.ainvoke({"project_id": "p1"})

        assert client_call.call_args.kwargs["filters"]["projectId"] == "p1"

    @pytest.mark.asyncio
    async def test_create_timeline(self) -> None:
        """Creating inserts project id, title and date into ``timelines``."""
        from app.agents.timeline_agent import create_timeline

        row = {"id": "cp1", "title": "Beta release", "date": 1700000000000}
        tool_input = {"project_id": "p1", "title": "Beta release", "date": 1700000000000}
        with patch(
            "app.agents.timeline_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"row": row},
        ) as client_call:
            outcome = await create_timeline.ainvoke(tool_input)

        sent = client_call.call_args.kwargs
        assert sent["action"] == "insert"
        assert sent["table"] == "timelines"
        assert sent["data"]["projectId"] == "p1"
        assert sent["data"]["title"] == "Beta release"
        assert sent["data"]["date"] == 1700000000000
        assert "Beta release" in outcome

    @pytest.mark.asyncio
    async def test_create_timeline_ai_suggested(self) -> None:
        """An AI-suggested timeline is inserted with ``isApproved`` equal to 0."""
        from app.agents.timeline_agent import create_timeline

        row = {"id": "cp1", "title": "Review", "date": 1700000000000}
        tool_input = {
            "project_id": "p1",
            "title": "Review",
            "date": 1700000000000,
            "is_ai_suggested": 1,
        }
        with patch(
            "app.agents.timeline_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"row": row},
        ) as client_call:
            await create_timeline.ainvoke(tool_input)

        data = client_call.call_args.kwargs["data"]
        assert data["isAiSuggested"] == 1
        assert data["isApproved"] == 0

    @pytest.mark.asyncio
    async def test_update_timeline_approve(self) -> None:
        """Approving sends an update with ``isApproved`` set to 1."""
        from app.agents.timeline_agent import update_timeline

        row = {"id": "c1", "title": "MVP", "isApproved": 1}
        with patch(
            "app.agents.timeline_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"row": row},
        ) as client_call:
            outcome = await update_timeline.ainvoke({"timeline_id": "c1", "is_approved": 1})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "update"
        assert sent["data"]["id"] == "c1"
        assert sent["data"]["updates"]["isApproved"] == 1
        assert "c1" in outcome

    @pytest.mark.asyncio
    async def test_update_timeline_empty_updates(self) -> None:
        """Only a timeline id: the updates mapping is empty."""
        from app.agents.timeline_agent import update_timeline

        row = {"id": "c1", "title": "MVP"}
        with patch(
            "app.agents.timeline_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"row": row},
        ) as client_call:
            await update_timeline.ainvoke({"timeline_id": "c1"})

        assert client_call.call_args.kwargs["data"]["updates"] == {}

    @pytest.mark.asyncio
    async def test_delete_timeline(self) -> None:
        """Deleting sends a delete on ``timelines`` for the given id."""
        from app.agents.timeline_agent import delete_timeline

        with patch(
            "app.agents.timeline_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"deleted": True},
        ) as client_call:
            outcome = await delete_timeline.ainvoke({"timeline_id": "c1"})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "delete"
        assert sent["table"] == "timelines"
        assert sent["data"]["id"] == "c1"
        assert "c1" in outcome
|
||||
|
||||
|
||||
# ── ProjectAgent ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestProjectAgent:
    """Metadata and ``handle`` checks for ``ProjectAgent`` with ``get_llm`` patched."""

    def test_name(self) -> None:
        """The agent reports its registry name."""
        agent = ProjectAgent()
        assert agent.get_name() == "project_agent"

    def test_description(self) -> None:
        """The agent reports its exact description string."""
        expected = "Manages projects: list, get, create, update, archive, delete"
        assert ProjectAgent().get_description() == expected

    def test_get_tools_count(self) -> None:
        """The agent exposes six tools."""
        tools = ProjectAgent().get_tools()
        assert len(tools) == 6

    def test_tool_names(self) -> None:
        """The exposed tool names match the expected set exactly."""
        expected = {
            "list_projects",
            "list_all_projects",
            "get_project",
            "create_project",
            "update_project",
            "delete_project",
        }
        assert {tool.name for tool in ProjectAgent().get_tools()} == expected

    @pytest.mark.asyncio
    async def test_handle_no_tool_calls(self) -> None:
        """Without tool calls the LLM text is returned as-is."""
        llm = _mock_llm("Project Alpha is active.")
        with patch("app.agents.project_agent.get_llm", return_value=llm):
            reply = await ProjectAgent().handle("show my projects", {})
        assert reply == "Project Alpha is active."

    @pytest.mark.asyncio
    async def test_handle_with_create_project_tool_call(self) -> None:
        """After one ``create_project`` tool call the final LLM text is returned."""
        llm = _mock_llm_with_tool_call(
            "create_project",
            {"name": "Pippo"},
            "Project 'Pippo' created.",
        )
        with patch("app.agents.project_agent.get_llm", return_value=llm):
            reply = await ProjectAgent().handle("create project Pippo", {})
        assert reply == "Project 'Pippo' created."

    @pytest.mark.asyncio
    async def test_handle_accepts_empty_context(self) -> None:
        """``handle`` works with an empty context dict."""
        llm = _mock_llm("Done.")
        with patch("app.agents.project_agent.get_llm", return_value=llm):
            reply = await ProjectAgent().handle("archive old project", {})
        assert isinstance(reply, str)
|
||||
|
||||
|
||||
class TestProjectAgentTools:
    """Checks the payload each project tool sends through ``execute_on_client``."""

    @pytest.mark.asyncio
    async def test_list_projects_defaults(self) -> None:
        """No arguments: a select on ``projects`` with ``includeArchived`` False."""
        from app.agents.project_agent import list_projects

        with patch(
            "app.agents.project_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"rows": []},
        ) as client_call:
            outcome = await list_projects.ainvoke({})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "select"
        assert sent["table"] == "projects"
        assert sent["filters"]["includeArchived"] is False
        assert outcome == "No projects found."

    @pytest.mark.asyncio
    async def test_list_projects_include_archived(self) -> None:
        """``include_archived=1`` is forwarded as ``includeArchived`` True."""
        from app.agents.project_agent import list_projects

        with patch(
            "app.agents.project_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"rows": []},
        ) as client_call:
            await list_projects.ainvoke({"include_archived": 1})

        assert client_call.call_args.kwargs["filters"]["includeArchived"] is True

    @pytest.mark.asyncio
    async def test_list_all_projects(self) -> None:
        """Listing all projects selects from ``projects``."""
        from app.agents.project_agent import list_all_projects

        with patch(
            "app.agents.project_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"rows": []},
        ) as client_call:
            outcome = await list_all_projects.ainvoke({})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "select"
        assert sent["table"] == "projects"
        assert outcome == "No projects found."

    @pytest.mark.asyncio
    async def test_get_project(self) -> None:
        """Fetching sends a get on ``projects`` and the result mentions the name."""
        from app.agents.project_agent import get_project

        row = {"id": "p1", "name": "Alpha", "status": "active", "clientId": None}
        with patch(
            "app.agents.project_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"row": row},
        ) as client_call:
            outcome = await get_project.ainvoke({"project_id": "p1"})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "get"
        assert sent["table"] == "projects"
        assert sent["data"]["id"] == "p1"
        assert "Alpha" in outcome

    @pytest.mark.asyncio
    async def test_create_project_name_only(self) -> None:
        """Name only: inserts with ``clientId`` set to ``None``."""
        from app.agents.project_agent import create_project

        row = {"id": "p1", "name": "Alpha"}
        with patch(
            "app.agents.project_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"row": row},
        ) as client_call:
            outcome = await create_project.ainvoke({"name": "Alpha"})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "insert"
        assert sent["data"]["name"] == "Alpha"
        assert sent["data"]["clientId"] is None
        assert "Alpha" in outcome

    @pytest.mark.asyncio
    async def test_create_project_with_client(self) -> None:
        """A client id is forwarded as ``clientId``."""
        from app.agents.project_agent import create_project

        row = {"id": "p1", "name": "Beta"}
        with patch(
            "app.agents.project_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"row": row},
        ) as client_call:
            await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})

        assert client_call.call_args.kwargs["data"]["clientId"] == "cl1"

    @pytest.mark.asyncio
    async def test_update_project_archive(self) -> None:
        """Archiving sends an update with status ``archived``."""
        from app.agents.project_agent import update_project

        row = {"id": "p1", "name": "Alpha", "status": "archived"}
        with patch(
            "app.agents.project_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"row": row},
        ) as client_call:
            outcome = await update_project.ainvoke({"project_id": "p1", "status": "archived"})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "update"
        assert sent["data"]["id"] == "p1"
        assert sent["data"]["updates"]["status"] == "archived"
        assert "p1" in outcome

    @pytest.mark.asyncio
    async def test_update_project_empty_updates(self) -> None:
        """Only a project id: the updates mapping is empty."""
        from app.agents.project_agent import update_project

        row = {"id": "p1", "name": "Alpha", "status": "active"}
        with patch(
            "app.agents.project_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"row": row},
        ) as client_call:
            await update_project.ainvoke({"project_id": "p1"})

        assert client_call.call_args.kwargs["data"]["updates"] == {}

    @pytest.mark.asyncio
    async def test_delete_project(self) -> None:
        """Deleting sends a delete for the given project id."""
        from app.agents.project_agent import delete_project

        with patch(
            "app.agents.project_agent.execute_on_client",
            new_callable=AsyncMock,
            return_value={"deleted": True},
        ) as client_call:
            outcome = await delete_project.ainvoke({"project_id": "p1"})

        sent = client_call.call_args.kwargs
        assert sent["action"] == "delete"
        assert sent["data"]["id"] == "p1"
        assert "p1" in outcome
|
||||
|
||||
|
||||
# ── NoteAgent ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestNoteAgent:
    """Metadata and ``handle`` checks for ``NoteAgent`` with ``get_llm`` patched."""

    def test_name(self) -> None:
        """The agent reports its registry name."""
        agent = NoteAgent()
        assert agent.get_name() == "note_agent"

    def test_description(self) -> None:
        """The agent reports its exact description string."""
        expected = "Manages notes: list, get, create, update, delete"
        assert NoteAgent().get_description() == expected

    def test_get_tools_count(self) -> None:
        """The agent exposes five tools."""
        tools = NoteAgent().get_tools()
        assert len(tools) == 5

    def test_tool_names(self) -> None:
        """The exposed tool names match the expected set exactly."""
        expected = {"list_notes", "get_note", "create_note", "update_note", "delete_note"}
        assert {tool.name for tool in NoteAgent().get_tools()} == expected

    @pytest.mark.asyncio
    async def test_handle_no_tool_calls(self) -> None:
        """Without tool calls the LLM text is returned as-is."""
        llm = _mock_llm("Note created.")
        with patch("app.agents.note_agent.get_llm", return_value=llm):
            reply = await NoteAgent().handle("create a note", {})
        assert reply == "Note created."

    @pytest.mark.asyncio
    async def test_handle_with_create_note_tool_call(self) -> None:
        """After one ``create_note`` tool call the final LLM text is returned."""
        llm = _mock_llm_with_tool_call(
            "create_note",
            {"title": "Daily log", "content": "# Today\nAll good."},
            "Note 'Daily log' created.",
        )
        with patch("app.agents.note_agent.get_llm", return_value=llm):
            reply = await NoteAgent().handle("log today's progress", {})
        assert reply == "Note 'Daily log' created."

    @pytest.mark.asyncio
    async def test_handle_accepts_empty_context(self) -> None:
        """``handle`` works with an empty context dict."""
        llm = _mock_llm("Done.")
        with patch("app.agents.note_agent.get_llm", return_value=llm):
            reply = await NoteAgent().handle("show notes", {})
        assert isinstance(reply, str)
|
||||
|
||||
|
||||
class TestNoteAgentTools:
|
||||
@pytest.mark.asyncio
async def test_list_notes_no_project(self) -> None:
    """No project id: a select on ``notes`` with ``projectId`` set to ``None``."""
    from app.agents.note_agent import list_notes

    with patch(
        "app.agents.note_agent.execute_on_client",
        new_callable=AsyncMock,
        return_value={"rows": []},
    ) as client_call:
        outcome = await list_notes.ainvoke({})

    sent = client_call.call_args.kwargs
    assert sent["action"] == "select"
    assert sent["table"] == "notes"
    assert sent["filters"]["projectId"] is None
    assert outcome == "No notes found."
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_notes_with_project(self) -> None:
    """A project id is forwarded in the filters."""
    from app.agents.note_agent import list_notes

    with patch(
        "app.agents.note_agent.execute_on_client",
        new_callable=AsyncMock,
        return_value={"rows": []},
    ) as client_call:
        await list_notes.ainvoke({"project_id": "p1"})

    assert client_call.call_args.kwargs["filters"]["projectId"] == "p1"
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_note(self) -> None:
    """Fetching sends a get on ``notes`` and the result mentions the title."""
    from app.agents.note_agent import get_note

    row = {"id": "n1", "title": "Daily log", "content": "# Today\nAll good."}
    with patch(
        "app.agents.note_agent.execute_on_client",
        new_callable=AsyncMock,
        return_value={"row": row},
    ) as client_call:
        outcome = await get_note.ainvoke({"note_id": "n1"})

    sent = client_call.call_args.kwargs
    assert sent["action"] == "get"
    assert sent["table"] == "notes"
    assert sent["data"]["id"] == "n1"
    assert "Daily log" in outcome
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_note_minimal(self) -> None:
    """Title and content only: the first client call inserts with ``projectId`` ``None``."""
    from app.agents.note_agent import create_note

    row = {"id": "n1", "title": "Daily log", "projectId": None}
    tool_input = {"title": "Daily log", "content": "# Today\nAll good."}
    with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as client_call:
        with patch("app.agents.note_agent.embed", new_callable=AsyncMock) as embed_mock:
            client_call.return_value = {"row": row}
            embed_mock.return_value = [0.0] * 1536
            outcome = await create_note.ainvoke(tool_input)

    # First call: insert; second call: vector_upsert
    insert_call = client_call.call_args_list[0].kwargs
    assert insert_call["action"] == "insert"
    assert insert_call["table"] == "notes"
    assert insert_call["data"]["title"] == "Daily log"
    assert insert_call["data"]["content"] == "# Today\nAll good."
    assert insert_call["data"]["projectId"] is None
    assert "Daily log" in outcome
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_note_with_project(self) -> None:
|
||||
from app.agents.note_agent import create_note
|
||||
fake_row = {"id": "n1", "title": "Sprint notes", "projectId": "p1"}
|
||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
|
||||
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
|
||||
m.return_value = {"row": fake_row}
|
||||
me.return_value = [0.0] * 1536
|
||||
await create_note.ainvoke({"title": "Sprint notes", "content": "## Sprint 1", "project_id": "p1"})
|
||||
first_call = m.call_args_list[0].kwargs
|
||||
assert first_call["data"]["projectId"] == "p1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_note_content_only(self) -> None:
|
||||
from app.agents.note_agent import update_note
|
||||
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
|
||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m, \
|
||||
patch("app.agents.note_agent.embed", new_callable=AsyncMock) as me:
|
||||
m.return_value = {"row": fake_row}
|
||||
me.return_value = [0.0] * 1536
|
||||
result = await update_note.ainvoke({"note_id": "n1", "content": "# Updated content"})
|
||||
first_call = m.call_args_list[0].kwargs
|
||||
assert first_call["action"] == "update"
|
||||
assert first_call["data"]["id"] == "n1"
|
||||
assert first_call["data"]["updates"]["content"] == "# Updated content"
|
||||
assert "title" not in first_call["data"]["updates"]
|
||||
assert "n1" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_note_empty_updates(self) -> None:
|
||||
from app.agents.note_agent import update_note
|
||||
fake_row = {"id": "n1", "title": "Daily log", "projectId": None}
|
||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||
m.return_value = {"row": fake_row}
|
||||
await update_note.ainvoke({"note_id": "n1"})
|
||||
assert m.call_args.kwargs["data"]["updates"] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_note(self) -> None:
|
||||
from app.agents.note_agent import delete_note
|
||||
with patch("app.agents.note_agent.execute_on_client", new_callable=AsyncMock) as m:
|
||||
m.return_value = {"deleted": True}
|
||||
result = await delete_note.ainvoke({"note_id": "n1"})
|
||||
call_kwargs = m.call_args.kwargs
|
||||
assert call_kwargs["action"] == "delete"
|
||||
assert call_kwargs["table"] == "notes"
|
||||
assert call_kwargs["data"]["id"] == "n1"
|
||||
assert "n1" in result
|
||||
@@ -1,286 +0,0 @@
|
||||
"""Tests for execution_plan: PromptTemplateRegistry, ExecutionPlanBuilder, PlanCache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.execution_plan import (
|
||||
ExecutionPlanBuilder,
|
||||
PlanCache,
|
||||
PromptTemplateRegistry,
|
||||
plan_cache,
|
||||
template_registry,
|
||||
)
|
||||
from app.schemas import ExecutionPlan
|
||||
|
||||
|
||||
# ── PromptTemplateRegistry ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPromptTemplateRegistry:
    """Behaviour of a freshly constructed ``PromptTemplateRegistry``."""

    def test_register_and_get(self) -> None:
        """A registered template is returned verbatim by ``get``."""
        templates = PromptTemplateRegistry()
        templates.register("tpl_foo", "You are a foo agent.")
        assert templates.get("tpl_foo") == "You are a foo agent."

    def test_get_unknown_raises_key_error(self) -> None:
        """Looking up an unregistered id raises ``KeyError`` naming that id."""
        templates = PromptTemplateRegistry()
        with pytest.raises(KeyError, match="tpl_missing"):
            templates.get("tpl_missing")

    def test_has_returns_true_for_registered(self) -> None:
        """``has`` is exactly ``True`` once the id has been registered."""
        templates = PromptTemplateRegistry()
        templates.register("tpl_x", "prompt text")
        assert templates.has("tpl_x") is True

    def test_has_returns_false_for_unregistered(self) -> None:
        """``has`` is exactly ``False`` for an id that was never registered."""
        templates = PromptTemplateRegistry()
        assert templates.has("tpl_missing") is False

    def test_list_ids_returns_all_registered_ids(self) -> None:
        """``list_ids`` reports every registered id (order is not checked)."""
        templates = PromptTemplateRegistry()
        for template_id, text in (("tpl_a", "a"), ("tpl_b", "b")):
            templates.register(template_id, text)
        assert set(templates.list_ids()) == {"tpl_a", "tpl_b"}

    def test_list_ids_does_not_return_prompt_text(self) -> None:
        """The prompt body must not appear among the listed ids."""
        templates = PromptTemplateRegistry()
        templates.register("tpl_secret", "top secret prompt")
        listed = templates.list_ids()
        assert "top secret prompt" not in listed

    def test_overwrite_existing_template(self) -> None:
        """Registering the same id twice keeps the most recent text."""
        templates = PromptTemplateRegistry()
        for version in ("v1", "v2"):
            templates.register("tpl_x", version)
        assert templates.get("tpl_x") == "v2"

    def test_empty_registry_has_no_ids(self) -> None:
        """A new registry lists no ids at all."""
        templates = PromptTemplateRegistry()
        assert templates.list_ids() == []
|
||||
|
||||
|
||||
# ── ExecutionPlanBuilder ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestExecutionPlanBuilder:
    """Fluent construction and validation performed by ``ExecutionPlanBuilder``."""

    def test_builds_empty_plan(self) -> None:
        """Building with no steps yields the agent name and an empty step list."""
        empty_plan = ExecutionPlanBuilder("task_agent").build()
        assert empty_plan.agent == "task_agent"
        assert empty_plan.steps == []

    def test_add_step_basic(self) -> None:
        """A plain step records action and variables and leaves the rest unset."""
        builder = ExecutionPlanBuilder("task_agent").add_step(
            "create_task", {"priority": "high"}
        )
        built = builder.build()
        assert len(built.steps) == 1
        only_step = built.steps[0]
        assert only_step.action == "create_task"
        assert only_step.variables == {"priority": "high"}
        assert only_step.prompt_template is None
        assert only_step.data_from_step is None

    def test_add_step_no_params(self) -> None:
        """Omitting the variables argument leaves ``variables`` as ``None``."""
        built = ExecutionPlanBuilder("task_agent").add_step("fetch").build()
        assert built.steps[0].variables is None

    def test_add_llm_step(self) -> None:
        """An LLM step has action ``llm`` plus the given template id and variables."""
        builder = ExecutionPlanBuilder("task_agent").add_llm_step(
            "tpl_task_default", {"message": "hi"}
        )
        llm_step = builder.build().steps[0]
        assert llm_step.action == "llm"
        assert llm_step.prompt_template == "tpl_task_default"
        assert llm_step.variables == {"message": "hi"}

    def test_add_llm_step_no_variables(self) -> None:
        """An LLM step added without variables stores ``None`` for them."""
        built = ExecutionPlanBuilder("task_agent").add_llm_step("tpl_x").build()
        assert built.steps[0].variables is None

    def test_add_data_step(self) -> None:
        """A data step keeps its action and the index of the step it reads from."""
        builder = (
            ExecutionPlanBuilder("task_agent")
            .add_step("fetch_data")
            .add_data_step("transform", data_from_step=0)
        )
        second = builder.build().steps[1]
        assert second.action == "transform"
        assert second.data_from_step == 0

    def test_fluent_chaining_returns_builder(self) -> None:
        """``add_step`` hands back the very same builder instance."""
        builder = ExecutionPlanBuilder("analytics_agent")
        returned = builder.add_step("a")
        assert returned is builder

    def test_fluent_chain_multiple_steps(self) -> None:
        """Three chained additions produce a three-step plan."""
        builder = (
            ExecutionPlanBuilder("analytics_agent")
            .add_llm_step("tpl_analytics_default")
            .add_step("format_output")
            .add_data_step("store", data_from_step=0)
        )
        assert len(builder.build().steps) == 3

    def test_build_validates_data_from_step_out_of_range(self) -> None:
        """Referencing a step index that does not exist is rejected."""
        with pytest.raises(ValueError, match="data_from_step"):
            ExecutionPlanBuilder("task_agent").add_data_step(
                "bad", data_from_step=5
            ).build()

    def test_build_validates_data_from_step_self_reference(self) -> None:
        """data_from_step=0 on the first step (index 0) is invalid."""
        with pytest.raises(ValueError, match="data_from_step"):
            ExecutionPlanBuilder("task_agent").add_data_step(
                "bad", data_from_step=0
            ).build()

    def test_build_validates_data_from_step_negative(self) -> None:
        """A negative source index is rejected."""
        with pytest.raises(ValueError, match="data_from_step"):
            ExecutionPlanBuilder("task_agent").add_data_step(
                "bad", data_from_step=-1
            ).build()

    def test_valid_data_from_step_at_index_two(self) -> None:
        """The third step may read from the second one."""
        builder = (
            ExecutionPlanBuilder("task_agent")
            .add_step("step0")
            .add_step("step1")
            .add_data_step("step2", data_from_step=1)
        )
        assert builder.build().steps[2].data_from_step == 1

    def test_data_from_step_zero_valid_at_index_one(self) -> None:
        """The second step may read from the first one."""
        builder = (
            ExecutionPlanBuilder("task_agent")
            .add_step("step0")
            .add_data_step("step1", data_from_step=0)
        )
        assert builder.build().steps[1].data_from_step == 0

    def test_build_returns_new_plan_each_call(self) -> None:
        """Each ``build`` call returns a distinct plan object with equal steps."""
        builder = ExecutionPlanBuilder("task_agent").add_step("do_thing")
        first = builder.build()
        second = builder.build()
        assert first is not second
        assert first.steps == second.steps

    def test_plan_is_execution_plan_instance(self) -> None:
        """``build`` produces an ``ExecutionPlan``."""
        built = ExecutionPlanBuilder("task_agent").build()
        assert isinstance(built, ExecutionPlan)
|
||||
|
||||
|
||||
# ── PlanCache ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPlanCache:
    """Storage, lookup and eviction behaviour of ``PlanCache``."""

    def _plan(self, agent: str = "a") -> ExecutionPlan:
        """Return a step-less plan for *agent* to use as a cache value."""
        builder = ExecutionPlanBuilder(agent)
        return builder.build()

    def test_cache_and_get(self) -> None:
        """A cached plan comes back as the identical object."""
        store = PlanCache()
        stored = self._plan()
        store.cache_plan("key1", stored)
        assert store.get_plan("key1") is stored

    def test_get_missing_returns_none(self) -> None:
        """An unknown key yields ``None``."""
        store = PlanCache()
        assert store.get_plan("nonexistent") is None

    def test_get_all_playbooks_empty(self) -> None:
        """A new cache has no playbooks."""
        store = PlanCache()
        assert store.get_all_playbooks() == []

    def test_get_all_playbooks_returns_all_stored(self) -> None:
        """Every cached plan is reported by ``get_all_playbooks``."""
        store = PlanCache()
        plan_a, plan_b = self._plan("a"), self._plan("b")
        store.cache_plan("k1", plan_a)
        store.cache_plan("k2", plan_b)
        everything = store.get_all_playbooks()
        assert len(everything) == 2
        assert plan_a in everything
        assert plan_b in everything

    def test_lru_evicts_oldest_entry(self) -> None:
        """With room for two entries, adding a third drops the first one added."""
        store = PlanCache(maxsize=2)
        plan_a, plan_b, plan_c = self._plan("a"), self._plan("b"), self._plan("c")
        store.cache_plan("k1", plan_a)
        store.cache_plan("k2", plan_b)
        store.cache_plan("k3", plan_c)  # expected to push k1 out
        assert store.get_plan("k1") is None
        assert store.get_plan("k2") is plan_b
        assert store.get_plan("k3") is plan_c

    def test_lru_access_updates_recency(self) -> None:
        """Reading an entry protects it from the next eviction."""
        store = PlanCache(maxsize=2)
        plan_a, plan_b, plan_c = self._plan("a"), self._plan("b"), self._plan("c")
        store.cache_plan("k1", plan_a)
        store.cache_plan("k2", plan_b)
        store.get_plan("k1")  # touch k1 so that k2 becomes the eviction candidate
        store.cache_plan("k3", plan_c)  # expected to push k2 out
        assert store.get_plan("k1") is plan_a
        assert store.get_plan("k2") is None
        assert store.get_plan("k3") is plan_c

    def test_overwrite_existing_key(self) -> None:
        """Re-caching under the same key replaces the value without duplicating it."""
        store = PlanCache()
        original, replacement = self._plan("a"), self._plan("b")
        store.cache_plan("same_key", original)
        store.cache_plan("same_key", replacement)
        assert store.get_plan("same_key") is replacement
        assert len(store.get_all_playbooks()) == 1

    def test_overwrite_does_not_consume_capacity(self) -> None:
        """Overwriting a key leaves space for another entry in a two-slot cache."""
        store = PlanCache(maxsize=2)
        plan_a, plan_b = self._plan("a"), self._plan("b")
        store.cache_plan("k1", plan_a)
        store.cache_plan("k1", plan_b)  # same key again: replaces, no extra slot
        store.cache_plan("k2", plan_a)  # second distinct key must still fit
        assert store.get_plan("k1") is plan_b
        assert store.get_plan("k2") is plan_a
|
||||
|
||||
|
||||
# ── Module-level singletons ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestModuleSingletons:
    """Contents of the module-level ``template_registry`` and ``plan_cache``."""

    def test_template_registry_has_all_agent_defaults(self) -> None:
        """Each of the four agents has a ``tpl_<agent>_default`` template."""
        agent_names = ("task_agent", "timeline_agent", "project_agent", "note_agent")
        for agent in agent_names:
            template_id = f"tpl_{agent}_default"
            assert template_registry.has(template_id), (
                f"Missing template: {template_id}"
            )

    def test_template_registry_has_operation_templates(self) -> None:
        """The two operation-specific templates are registered."""
        assert template_registry.has("tpl_task_extract_from_project")
        assert template_registry.has("tpl_note_weekly_summary")

    def test_template_registry_get_returns_non_empty_string(self) -> None:
        """The task agent's default template is a non-empty ``str``."""
        prompt = template_registry.get("tpl_task_agent_default")
        assert isinstance(prompt, str)
        assert len(prompt) > 0

    def test_plan_cache_has_prebuilt_playbooks(self) -> None:
        """At least two playbooks are cached at import time."""
        prebuilt = plan_cache.get_all_playbooks()
        assert len(prebuilt) >= 2

    def test_playbook_create_tasks_from_project(self) -> None:
        """The project playbook is an LLM extraction feeding a second step."""
        playbook = plan_cache.get_plan("create_tasks_from_project")
        assert playbook is not None
        assert playbook.agent == "project_agent"
        assert len(playbook.steps) == 2
        first, second = playbook.steps[0], playbook.steps[1]
        assert first.prompt_template == "tpl_task_extract_from_project"
        assert second.data_from_step == 0

    def test_playbook_generate_weekly_note(self) -> None:
        """The weekly-note playbook is an LLM summary feeding a second step."""
        playbook = plan_cache.get_plan("generate_weekly_note")
        assert playbook is not None
        assert playbook.agent == "note_agent"
        assert len(playbook.steps) == 2
        first, second = playbook.steps[0], playbook.steps[1]
        assert first.prompt_template == "tpl_note_weekly_summary"
        assert second.data_from_step == 0

    def test_playbook_steps_have_no_raw_prompt_text(self) -> None:
        """Plans must not embed prompt text — only template IDs."""
        for playbook in plan_cache.get_all_playbooks():
            for step in playbook.steps:
                if step.prompt_template is None:
                    continue  # steps without a template have nothing to check
                assert step.prompt_template.startswith("tpl_"), (
                    f"prompt_template looks like raw text: {step.prompt_template!r}"
                )
|
||||
@@ -250,15 +250,15 @@ def test_home_request_calls_memory_middleware(client):
|
||||
token = make_jwt("power", user_id=USER_ID)
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
async def _mock_stream(user_id, message, context, reg=None):
|
||||
async def _mock_stream(user_id, message, context, db_session_factory=None):
|
||||
# Verify memory context was injected
|
||||
assert context.get("core_memory") == {"tz": "UTC"}
|
||||
yield "task_agent", ""
|
||||
yield "task_agent", '{"type": "text", "content": "Done"}'
|
||||
yield ("token", "Done")
|
||||
yield ("mutations", [])
|
||||
|
||||
with (
|
||||
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
|
||||
patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_stream),
|
||||
patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_stream),
|
||||
):
|
||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||
ws.send_text(json.dumps({
|
||||
|
||||
@@ -20,7 +20,6 @@ from jose import jwt
|
||||
from app.config.settings import settings
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.schemas import ChatResponse
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -50,7 +49,6 @@ _CHAT_BODY = {
|
||||
"recent_tasks": [],
|
||||
"conversation_history": [],
|
||||
},
|
||||
"execution_mode": "direct",
|
||||
}
|
||||
|
||||
|
||||
@@ -240,7 +238,7 @@ class TestRateLimitMiddleware:
|
||||
|
||||
|
||||
class TestSanitizerMiddleware:
|
||||
"""Mock ``orchestrate`` to inject controlled strings into chat responses."""
|
||||
"""Mock ``run_home`` to inject controlled strings into chat responses."""
|
||||
|
||||
_CHAT_PATH = "/api/v1/chat"
|
||||
|
||||
@@ -248,11 +246,10 @@ class TestSanitizerMiddleware:
|
||||
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
|
||||
|
||||
def _post_chat(self, client: TestClient, response_text: str) -> dict:
|
||||
mock_response = ChatResponse(response=response_text, actions=[])
|
||||
with patch(
|
||||
"app.api.routes.chat.orchestrate",
|
||||
"app.api.routes.chat.run_home",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
return_value=response_text,
|
||||
):
|
||||
resp = client.post(
|
||||
self._CHAT_PATH,
|
||||
|
||||
@@ -1,347 +0,0 @@
|
||||
"""Integration tests for the orchestrator module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
||||
from app.core.orchestrator import (
|
||||
classify_intent,
|
||||
orchestrate,
|
||||
orchestrate_stream,
|
||||
route_pipeline,
|
||||
route_single,
|
||||
)
|
||||
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
||||
|
||||
|
||||
# ── Stub agents ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _TaskAgent(ChatAgent):
    """Stub agent registered as ``task_agent`` that echoes the query back."""

    def get_name(self) -> str:
        """Return the registry name of this stub."""
        return "task_agent"

    def get_description(self) -> str:
        """Return the description string for this stub."""
        return "Manages tasks: create, update, list, suggest"

    def get_tools(self) -> list[Any]:
        """Return no tools; the stub never calls any."""
        return []

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Return *query* prefixed with ``task: ``; *context* is ignored."""
        return f"task: {query}"
|
||||
|
||||
|
||||
class _CalendarAgent(ChatAgent):
    """Stub agent registered as ``calendar_agent`` that echoes the query back."""

    def get_name(self) -> str:
        """Return the registry name of this stub."""
        return "calendar_agent"

    def get_description(self) -> str:
        """Return the description string for this stub."""
        return "Calendar management: events, conflicts, scheduling"

    def get_tools(self) -> list[Any]:
        """Return no tools; the stub never calls any."""
        return []

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Return *query* prefixed with ``calendar: ``; *context* is ignored."""
        return f"calendar: {query}"
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_llm(response_text: str) -> MagicMock:
    """Build a stand-in LLM whose awaited ``ainvoke`` answers with *response_text*.

    The returned mock exposes an ``ainvoke`` coroutine; awaiting it yields a
    message-like mock whose ``content`` attribute is *response_text*.
    """
    reply = MagicMock(content=response_text)
    return MagicMock(ainvoke=AsyncMock(return_value=reply))
|
||||
|
||||
|
||||
# ── Fixtures ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _fresh_registry():
    """Clear ``AgentRegistry._instance`` before and after every test."""
    # Drop any instance cached on the class before the test body runs.
    setattr(AgentRegistry, "_instance", None)
    yield
    # Clear it again afterwards so nothing the test registered survives.
    setattr(AgentRegistry, "_instance", None)
|
||||
|
||||
|
||||
@pytest.fixture()
def reg() -> AgentRegistry:
    """Provide a registry holding the two stub agents (task first, then calendar)."""
    registry = AgentRegistry()
    for stub_cls in (_TaskAgent, _CalendarAgent):
        registry.register(stub_cls)
    return registry
|
||||
|
||||
|
||||
# ── classify_intent ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestClassifyIntent:
    """``classify_intent`` with the orchestrator's ``_make_llm`` factory patched out."""

    @pytest.mark.asyncio
    async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None:
        """An LLM answer naming a registered agent is returned unchanged."""
        with patch("app.core.orchestrator._make_llm") as llm_factory:
            llm_factory.return_value = _mock_llm("task_agent")
            chosen = await classify_intent("add a task", {}, reg)
        assert chosen == "task_agent"

    @pytest.mark.asyncio
    async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None:
        """The calendar stub is selected when the LLM names it."""
        with patch("app.core.orchestrator._make_llm") as llm_factory:
            llm_factory.return_value = _mock_llm("calendar_agent")
            chosen = await classify_intent("schedule a meeting", {}, reg)
        assert chosen == "calendar_agent"

    @pytest.mark.asyncio
    async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None:
        """An unregistered agent name from the LLM resolves to ``task_agent``."""
        with patch("app.core.orchestrator._make_llm") as llm_factory:
            llm_factory.return_value = _mock_llm("nonexistent_agent")
            chosen = await classify_intent("do something", {}, reg)
        assert chosen == "task_agent"

    @pytest.mark.asyncio
    async def test_empty_registry_returns_fallback_without_llm_call(self) -> None:
        """With no agents registered the fallback is returned and no LLM is built."""
        no_agents = AgentRegistry()
        with patch("app.core.orchestrator._make_llm") as llm_factory:
            chosen = await classify_intent("anything", {}, no_agents)
            # The factory must never have been touched on this early-return path.
            llm_factory.assert_not_called()
        assert chosen == "task_agent"

    @pytest.mark.asyncio
    async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None:
        """Surrounding whitespace in the LLM answer does not prevent a match."""
        with patch("app.core.orchestrator._make_llm") as llm_factory:
            llm_factory.return_value = _mock_llm("  task_agent  \n")
            chosen = await classify_intent("create task", {}, reg)
        assert chosen == "task_agent"
|
||||
|
||||
|
||||
# ── route_single ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRouteSingle:
    """``route_single`` dispatching straight to one of the stub agents."""

    @pytest.mark.asyncio
    async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
        """The result is wrapped in a ``ChatResponse``."""
        reply = await route_single("task_agent", "create a task", {}, reg)
        assert isinstance(reply, ChatResponse)

    @pytest.mark.asyncio
    async def test_response_contains_agent_output(self, reg: AgentRegistry) -> None:
        """The response text is exactly what the stub agent returned."""
        reply = await route_single("task_agent", "create a task", {}, reg)
        assert reply.response == "task: create a task"

    @pytest.mark.asyncio
    async def test_unknown_agent_raises_key_error(self, reg: AgentRegistry) -> None:
        """Routing to an unregistered agent name raises ``KeyError``."""
        with pytest.raises(KeyError):
            await route_single("nonexistent", "hello", {}, reg)

    @pytest.mark.asyncio
    async def test_actions_default_empty(self, reg: AgentRegistry) -> None:
        """No actions are attached to the response."""
        reply = await route_single("task_agent", "hi", {}, reg)
        assert reply.actions == []
|
||||
|
||||
|
||||
# ── route_pipeline ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRoutePipeline:
    """``route_pipeline`` chaining agents, with ``_make_llm`` patched to a canned LLM."""

    @pytest.mark.asyncio
    async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
        """A two-agent pipeline produces a ``ChatResponse``."""
        agent_chain = ["task_agent", "calendar_agent"]
        with patch("app.core.orchestrator._make_llm") as llm_factory:
            llm_factory.return_value = _mock_llm("synthesized result")
            reply = await route_pipeline(agent_chain, "plan my week", {}, reg)
        assert isinstance(reply, ChatResponse)

    @pytest.mark.asyncio
    async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None:
        """The response text is the text produced by the mocked LLM."""
        agent_chain = ["task_agent", "calendar_agent"]
        with patch("app.core.orchestrator._make_llm") as llm_factory:
            llm_factory.return_value = _mock_llm("synthesized result")
            reply = await route_pipeline(agent_chain, "plan my week", {}, reg)
        assert reply.response == "synthesized result"

    @pytest.mark.asyncio
    async def test_passes_previous_results_to_subsequent_agents(
        self, reg: AgentRegistry
    ) -> None:
        """Each agent after the first should receive prior outputs in context."""
        seen_contexts: list[dict[str, Any]] = []

        class _CapturingAgent(ChatAgent):
            """Agent that records a copy of every context it is handed."""

            def get_name(self) -> str:
                return "capture"

            def get_description(self) -> str:
                return "captures context for testing"

            def get_tools(self) -> list[Any]:
                return []

            async def handle(self, query: str, context: dict[str, Any]) -> str:
                # Copy the dict so later mutation by the pipeline is not observed.
                seen_contexts.append(dict(context))
                return "captured"

        reg.register(_CapturingAgent)

        with patch("app.core.orchestrator._make_llm") as llm_factory:
            llm_factory.return_value = _mock_llm("done")
            await route_pipeline(["task_agent", "capture"], "hi", {}, reg)

        # The capturing agent ran second, so it must have seen the task stub's output.
        assert len(seen_contexts) == 1
        captured = seen_contexts[0]
        assert "previous_results" in captured
        assert captured["previous_results"] == ["task: hi"]

    @pytest.mark.asyncio
    async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None:
        """A one-agent pipeline still returns the mocked LLM's text."""
        with patch("app.core.orchestrator._make_llm") as llm_factory:
            llm_factory.return_value = _mock_llm("single result")
            reply = await route_pipeline(["task_agent"], "one agent", {}, reg)
        assert reply.response == "single result"
|
||||
|
||||
|
||||
# ── orchestrate ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestOrchestrate:
    """``orchestrate`` in ``direct`` and ``plan`` execution modes."""

    @staticmethod
    async def _dispatch(reg: AgentRegistry, llm_reply: str, **request_fields: Any):
        """Run ``orchestrate`` on a ``ChatRequest`` built from *request_fields*.

        ``_make_llm`` is patched for the duration of the call so that every LLM
        it hands out answers with *llm_reply*.
        """
        with patch("app.core.orchestrator._make_llm") as llm_factory:
            llm_factory.return_value = _mock_llm(llm_reply)
            request = ChatRequest(**request_fields)
            return await orchestrate(request, reg)

    @pytest.mark.asyncio
    async def test_direct_mode_returns_chat_response(
        self, reg: AgentRegistry
    ) -> None:
        """Direct mode yields a ``ChatResponse``."""
        outcome = await self._dispatch(
            reg, "task_agent", message="add a task", execution_mode="direct"
        )
        assert isinstance(outcome, ChatResponse)

    @pytest.mark.asyncio
    async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None:
        """Direct mode returns the selected stub agent's own output."""
        outcome = await self._dispatch(
            reg, "task_agent", message="add a task", execution_mode="direct"
        )
        assert isinstance(outcome, ChatResponse)
        assert outcome.response == "task: add a task"

    @pytest.mark.asyncio
    async def test_plan_mode_returns_execution_plan(
        self, reg: AgentRegistry
    ) -> None:
        """Plan mode yields an ``ExecutionPlan`` instead of a response."""
        outcome = await self._dispatch(
            reg, "task_agent", message="plan my tasks", execution_mode="plan"
        )
        assert isinstance(outcome, ExecutionPlan)

    @pytest.mark.asyncio
    async def test_plan_mode_agent_matches_classified(
        self, reg: AgentRegistry
    ) -> None:
        """The plan's agent is the one named by the mocked LLM."""
        outcome = await self._dispatch(
            reg,
            "calendar_agent",
            message="schedule something",
            execution_mode="plan",
        )
        assert isinstance(outcome, ExecutionPlan)
        assert outcome.agent == "calendar_agent"

    @pytest.mark.asyncio
    async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None:
        """A generated plan contains at least one step."""
        outcome = await self._dispatch(
            reg, "task_agent", message="plan tasks", execution_mode="plan"
        )
        assert isinstance(outcome, ExecutionPlan)
        assert len(outcome.steps) >= 1

    @pytest.mark.asyncio
    async def test_plan_mode_template_id_contains_agent_name(
        self, reg: AgentRegistry
    ) -> None:
        """The first step's template id embeds the agent name."""
        outcome = await self._dispatch(
            reg, "task_agent", message="plan tasks", execution_mode="plan"
        )
        assert isinstance(outcome, ExecutionPlan)
        first_template = outcome.steps[0].prompt_template
        assert first_template is not None
        assert "task_agent" in first_template

    @pytest.mark.asyncio
    async def test_default_execution_mode_is_direct(
        self, reg: AgentRegistry
    ) -> None:
        """Leaving ``execution_mode`` unset behaves like direct mode."""
        # No execution_mode is passed, so the request model's default applies.
        outcome = await self._dispatch(reg, "task_agent", message="help me")
        assert isinstance(outcome, ChatResponse)
|
||||
|
||||
|
||||
# ── orchestrate_stream ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestOrchestrateStream:
    """Chunks yielded by ``orchestrate_stream`` for direct-mode requests."""

    @staticmethod
    async def _stream_chunks(reg: AgentRegistry, message: str) -> list[Any]:
        """Collect every chunk streamed for a direct-mode request carrying *message*.

        ``_make_llm`` is patched so the LLM always answers ``task_agent``.
        """
        with patch("app.core.orchestrator._make_llm") as llm_factory:
            llm_factory.return_value = _mock_llm("task_agent")
            request = ChatRequest(message=message, execution_mode="direct")
            collected = []
            async for piece in orchestrate_stream(request, reg):
                collected.append(piece)
        return collected

    @pytest.mark.asyncio
    async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None:
        """The stream is never empty."""
        pieces = await self._stream_chunks(reg, "add a task")
        assert len(pieces) >= 1

    @pytest.mark.asyncio
    async def test_all_chunks_are_plain_text(
        self, reg: AgentRegistry
    ) -> None:
        """Every chunk is a ``str``; there is no structured final frame."""
        pieces = await self._stream_chunks(reg, "add a task")

        # orchestrate_stream is expected to emit text only, with no JSON trailer.
        for piece in pieces:
            assert isinstance(piece, str)

    @pytest.mark.asyncio
    async def test_concatenated_chunks_equal_full_response(
        self, reg: AgentRegistry
    ) -> None:
        """Joining the chunks reproduces the stub agent's full answer."""
        pieces = await self._stream_chunks(reg, "create a task")

        joined = "".join(pieces)
        assert joined == "task: create a task"

    @pytest.mark.asyncio
    async def test_text_chunks_before_final_frame(
        self, reg: AgentRegistry
    ) -> None:
        """No chunk other than the last may be a JSON frame flagged ``done``."""
        # 200 characters is meant to be long enough to span several chunks.
        pieces = await self._stream_chunks(reg, "x" * 200)

        # Every chunk except the last must not parse as a finished JSON frame.
        for piece in pieces[:-1]:
            try:
                frame = json.loads(piece)
            except json.JSONDecodeError:
                continue  # not JSON at all, which is the expected plain-text case
            assert frame.get("done") is not True
|
||||
@@ -1,236 +0,0 @@
|
||||
"""Tests for v3 orchestrator functions (Step 3)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from typing import Any
|
||||
|
||||
from app.core.agent_registry import ChatAgent, AgentRegistry
|
||||
from app.core.orchestrator import orchestrate_v3, orchestrate_v3_stream
|
||||
|
||||
|
||||
# ── Minimal agent for testing ─────────────────────────────────────────
|
||||
|
||||
|
||||
class _FixedAgent(ChatAgent):
    """Test agent that answers with a preset sequence of tokens."""

    def __init__(self, name: str = "_fixed", tokens: list[str] | None = None, **kwargs: Any) -> None:
        """Store the agent name and token list; extra kwargs go to ``ChatAgent``.

        A missing or empty *tokens* falls back to ``["Hello", " world"]``.
        """
        super().__init__(**kwargs)
        self._name = name
        self._tokens = tokens if tokens else ["Hello", " world"]

    def get_name(self) -> str:
        """Return the name supplied at construction."""
        return self._name

    def get_description(self) -> str:
        """Return a fixed description string."""
        return "Fixed agent for tests"

    def get_tools(self) -> list[Any]:
        """Return no tools."""
        return []

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Return all preset tokens joined together; the arguments are ignored."""
        return "".join(self._tokens)

    async def handle_stream(self, query: str, context: dict[str, Any]):
        """Yield the preset tokens one at a time; the arguments are ignored."""
        for token in self._tokens:
            yield token
|
||||
|
||||
|
||||
# ── Mock registry factory ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_registry(agent_name: str, agent: ChatAgent) -> MagicMock:
    """Build a mock ``AgentRegistry`` that exposes exactly one agent.

    ``list_agents()`` reports a single entry named *agent_name*, and
    ``get()`` returns *agent* whatever name is requested.
    """
    reg = MagicMock(spec=AgentRegistry)
    reg.list_agents.return_value = [{"name": agent_name, "description": "test"}]
    reg.get.return_value = agent
    return reg
|
||||
|
||||
|
||||
# ── orchestrate_v3 ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_orchestrate_v3_returns_agent_name_and_instance():
    """orchestrate_v3 returns the classified name and the registry's agent."""
    agent = _FixedAgent("task_agent")
    reg = _make_registry("task_agent", agent)

    with patch(
        "app.core.orchestrator.classify_intent",
        AsyncMock(return_value="task_agent"),
    ):
        name, inst = await orchestrate_v3(
            user_id="u-1", message="fix a bug", context={}, reg=reg
        )

    assert name == "task_agent"
    assert inst is agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_orchestrate_v3_classify_called_with_message_and_context():
    """classify_intent is awaited once with the message and context first."""
    agent = _FixedAgent("note_agent")
    reg = _make_registry("note_agent", agent)
    ctx = {"some": "context"}

    with patch(
        "app.core.orchestrator.classify_intent",
        AsyncMock(return_value="note_agent"),
    ) as mock_classify:
        await orchestrate_v3(user_id="u-1", message="take a note", context=ctx, reg=reg)

    mock_classify.assert_awaited_once()
    # call_args[0] is the tuple of positional arguments of the last call.
    positional = mock_classify.call_args[0]
    assert positional[0] == "take a note"
    assert positional[1] == ctx
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_orchestrate_v3_uses_default_registry_when_none():
    """With no ``reg`` argument the agent comes from ``_default_registry``."""
    agent = _FixedAgent("task_agent")

    classify = patch(
        "app.core.orchestrator.classify_intent",
        AsyncMock(return_value="task_agent"),
    )
    default_registry = patch("app.core.orchestrator._default_registry")
    with classify, default_registry as mock_reg:
        mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
        mock_reg.get.return_value = agent
        name, inst = await orchestrate_v3(user_id="u-1", message="hi", context={})

    assert name == "task_agent"
    assert inst is agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_orchestrate_v3_get_called_with_agent_name():
    """The registry's ``get`` is called exactly once with the classified name."""
    agent = _FixedAgent("timeline_agent")
    reg = _make_registry("timeline_agent", agent)

    with patch(
        "app.core.orchestrator.classify_intent",
        AsyncMock(return_value="timeline_agent"),
    ):
        await orchestrate_v3(user_id="u-2", message="schedule", context={}, reg=reg)

    reg.get.assert_called_once_with("timeline_agent")
|
||||
|
||||
|
||||
# ── orchestrate_v3_stream ─────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _collect(gen) -> list[tuple[str, str]]:
|
||||
results: list[tuple[str, str]] = []
|
||||
async for item in gen:
|
||||
results.append(item)
|
||||
return results
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_first_yield_is_domain_signal():
    """The first streamed item is the domain signal ``(agent_name, "")``."""
    agent = _FixedAgent("task_agent", tokens=["token1"])
    reg = _make_registry("task_agent", agent)

    with patch(
        "app.core.orchestrator.classify_intent",
        AsyncMock(return_value="task_agent"),
    ):
        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
        results = await _collect(gen)

    # First item must be (agent_name, "") — domain signal
    assert results[0] == ("task_agent", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_yields_agent_name_with_tokens():
    """Every streamed item pairs the agent name with a token, in order."""
    agent = _FixedAgent("task_agent", tokens=["Hello", " ", "world"])
    reg = _make_registry("task_agent", agent)

    with patch(
        "app.core.orchestrator.classify_intent",
        AsyncMock(return_value="task_agent"),
    ):
        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
        results = await _collect(gen)

    # All items are (agent_name, token) pairs
    assert all(name == "task_agent" for name, _ in results)
    tokens = [tok for _, tok in results]
    assert tokens[0] == ""  # domain signal
    assert tokens[1:] == ["Hello", " ", "world"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_different_agent():
    """The stream carries whichever agent name classification returned."""
    agent = _FixedAgent("note_agent", tokens=["note"])
    reg = _make_registry("note_agent", agent)

    with patch(
        "app.core.orchestrator.classify_intent",
        AsyncMock(return_value="note_agent"),
    ):
        gen = orchestrate_v3_stream(user_id="u-2", message="take note", context={}, reg=reg)
        results = await _collect(gen)

    assert results[0] == ("note_agent", "")
    assert ("note_agent", "note") in results
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_uses_default_registry_when_none():
    """With no ``reg`` argument the stream still reports the classified agent."""
    agent = _FixedAgent("task_agent", tokens=["x"])

    classify = patch(
        "app.core.orchestrator.classify_intent",
        AsyncMock(return_value="task_agent"),
    )
    default_registry = patch("app.core.orchestrator._default_registry")
    with classify, default_registry as mock_reg:
        mock_reg.list_agents.return_value = [{"name": "task_agent", "description": ""}]
        mock_reg.get.return_value = agent
        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={})
        results = await _collect(gen)

    assert results[0][0] == "task_agent"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_empty_token_list():
    """Agent with no tokens still emits the domain signal."""

    class _EmptyAgent(_FixedAgent):
        """Agent whose stream finishes without yielding anything."""

        async def handle_stream(self, query: str, context: dict[str, Any]):
            return
            yield  # unreachable; its presence makes this an async generator

    agent = _EmptyAgent("task_agent", tokens=[])
    reg = _make_registry("task_agent", agent)

    with patch(
        "app.core.orchestrator.classify_intent",
        AsyncMock(return_value="task_agent"),
    ):
        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
        results = await _collect(gen)

    assert results == [("task_agent", "")]  # only domain signal
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_orchestrate_v3_stream_full_text_correct():
    """Concatenating all non-domain tokens reconstructs the full response."""
    tokens = ["The", " ", "task", " ", "is", " ", "done."]
    agent = _FixedAgent("task_agent", tokens=tokens)
    reg = _make_registry("task_agent", agent)

    with patch(
        "app.core.orchestrator.classify_intent",
        AsyncMock(return_value="task_agent"),
    ):
        gen = orchestrate_v3_stream(user_id="u-1", message="hi", context={}, reg=reg)
        results = await _collect(gen)

    text = "".join(tok for _, tok in results[1:])  # skip domain signal
    assert text == "The task is done."
|
||||
|
||||
|
||||
# ── handle_stream default implementation ─────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_stream_default_yields_full_response():
    """Default handle_stream yields handle() result as a single chunk."""

    class _SimpleAgent(ChatAgent):
        """Minimal agent that does not override ``handle_stream``."""

        def get_name(self) -> str:
            return "_simple"

        def get_description(self) -> str:
            return ""

        def get_tools(self) -> list[Any]:
            return []

        async def handle(self, query: str, context: dict[str, Any]) -> str:
            return "simple response"

    agent = _SimpleAgent()
    tokens = [tok async for tok in agent.handle_stream("q", {})]
    assert tokens == ["simple response"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_stream_override_used_by_stream():
    """_FixedAgent.handle_stream override yields individual tokens."""
    agent = _FixedAgent("t", tokens=["a", "b", "c"])
    tokens = [tok async for tok in agent.handle_stream("q", {})]
    assert tokens == ["a", "b", "c"]
|
||||
@@ -16,15 +16,15 @@ from app.schemas import (
|
||||
|
||||
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
async def _stream(*pairs: tuple[str, str]):
|
||||
"""Async generator that yields (agent_name, token) pairs."""
|
||||
for pair in pairs:
|
||||
yield pair
|
||||
async def _stream(*events: tuple[str, object]):
|
||||
"""Async generator that yields (event_type, data) tuples."""
|
||||
for event in events:
|
||||
yield event
|
||||
|
||||
|
||||
async def collect(formatter, token_stream):
|
||||
async def collect(formatter, event_stream):
|
||||
frames = []
|
||||
async for frame in formatter.format(token_stream):
|
||||
async for frame in formatter.format(event_stream):
|
||||
frames.append(frame)
|
||||
return frames
|
||||
|
||||
@@ -32,13 +32,14 @@ async def collect(formatter, token_stream):
|
||||
# ── HomeFormatter ─────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_text_block():
|
||||
async def test_home_formatter_text_token():
|
||||
req_id = "req-1"
|
||||
tokens = [
|
||||
("task_agent", '{"type": "text", "content": "Hello world"}'),
|
||||
events = [
|
||||
("token", "Hello world"),
|
||||
("mutations", []),
|
||||
]
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
||||
frames = await collect(formatter, _stream(*tokens))
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
assert isinstance(frames[0], WsStreamStart)
|
||||
assert frames[0].request_id == req_id
|
||||
@@ -48,104 +49,94 @@ async def test_home_formatter_text_block():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_chart_block():
|
||||
async def test_home_formatter_entity_ref_from_tool_end():
|
||||
req_id = "req-2"
|
||||
chart_json = (
|
||||
'{"type": "chart", "chartType": "bar", '
|
||||
'"title": "Tasks", "data": [{"x": 1}], '
|
||||
'"config": {"x": {"label": "X", "color": "#fff"}}}'
|
||||
)
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
||||
frames = await collect(formatter, _stream(("task_agent", chart_json)))
|
||||
events = [
|
||||
("tool_end", {"name": "task_agent", "result": "Found 3 tasks."}),
|
||||
("token", "Here are your tasks."),
|
||||
("mutations", []),
|
||||
]
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 1
|
||||
assert block_frames[0].block_type == "chart"
|
||||
assert block_frames[0].data["chartType"] == "bar"
|
||||
assert block_frames[0].block_type == "entity_ref"
|
||||
assert block_frames[0].data["entity"] == "tasks"
|
||||
assert block_frames[0].data["result"] == "Found 3 tasks."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_invalid_chart_skipped():
|
||||
async def test_home_formatter_unknown_agent_no_block():
|
||||
req_id = "req-3"
|
||||
bad_chart = '{"type": "chart", "chartType": "unknown", "data": []}'
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
||||
frames = await collect(formatter, _stream(("task_agent", bad_chart)))
|
||||
events = [
|
||||
("tool_end", {"name": "unknown_agent", "result": "stuff"}),
|
||||
("mutations", []),
|
||||
]
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 0 # invalid chart skipped
|
||||
assert len(block_frames) == 0 # unknown agent → no entity mapping
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_entity_ref_resolved():
|
||||
async def test_home_formatter_mutations_in_stream_end():
|
||||
req_id = "req-4"
|
||||
tool_results = [{"entity": "task", "id": "t1", "title": "My Task"}]
|
||||
entity_json = '{"type": "entity_ref", "entity": "task"}'
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=tool_results)
|
||||
frames = await collect(formatter, _stream(("task_agent", entity_json)))
|
||||
muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
|
||||
events = [
|
||||
("token", "Done"),
|
||||
("mutations", muts),
|
||||
]
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 1
|
||||
assert block_frames[0].data["entity"] == "task"
|
||||
assert block_frames[0].data["items"][0]["id"] == "t1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_entity_ref_missing_skipped():
|
||||
req_id = "req-5"
|
||||
entity_json = '{"type": "entity_ref", "entity": "task"}'
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
||||
frames = await collect(formatter, _stream(("task_agent", entity_json)))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 0 # no tool results → skipped
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_table_block():
|
||||
req_id = "req-6"
|
||||
table_json = '{"type": "table", "headers": ["A", "B"], "rows": [["1", "2"]]}'
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
||||
frames = await collect(formatter, _stream(("task_agent", table_json)))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 1
|
||||
assert block_frames[0].block_type == "table"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_timeline_block():
|
||||
req_id = "req-7"
|
||||
timeline_json = '{"type": "timeline", "timelines": [{"id": "c1", "title": "M1", "date": 123}]}'
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
||||
frames = await collect(formatter, _stream(("task_agent", timeline_json)))
|
||||
|
||||
block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
|
||||
assert len(block_frames) == 1
|
||||
assert block_frames[0].block_type == "timeline"
|
||||
end_frame = frames[-1]
|
||||
assert isinstance(end_frame, WsStreamEnd)
|
||||
assert len(end_frame.mutations) == 1
|
||||
assert end_frame.mutations[0]["action"] == "insert"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_home_formatter_frame_order():
|
||||
"""stream_start is first, stream_end is last."""
|
||||
req_id = "req-8"
|
||||
formatter = HomeFormatter(request_id=req_id, tool_results=[])
|
||||
frames = await collect(formatter, _stream(("task_agent", '{"type": "text", "content": "Hi"}')))
|
||||
req_id = "req-5"
|
||||
formatter = HomeFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
|
||||
assert isinstance(frames[0], WsStreamStart)
|
||||
assert isinstance(frames[-1], WsStreamEnd)
|
||||
|
||||
|
||||
# ── FloatingFormatter ────────────────────────────────────────────────────────────
|
||||
@pytest.mark.asyncio
async def test_home_formatter_multiple_tool_ends():
    """Two ``tool_end`` events yield two block frames, one per entity."""
    req_id = "req-6"
    events = [
        ("tool_end", {"name": "task_agent", "result": "3 tasks"}),
        ("tool_end", {"name": "project_agent", "result": "2 projects"}),
        ("token", "Overview done."),
        ("mutations", []),
    ]
    formatter = HomeFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    block_frames = [f for f in frames if isinstance(f, WsStreamBlock)]
    assert len(block_frames) == 2
    entities = {b.data["entity"] for b in block_frames}
    assert entities == {"tasks", "projects"}
|
||||
|
||||
|
||||
# ── FloatingFormatter ─────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_floating_formatter_domain_emitted_first():
|
||||
async def test_floating_formatter_domain_from_tool_end():
|
||||
req_id = "pop-1"
|
||||
formatter = FloatingFormatter(request_id=req_id)
|
||||
tokens = [
|
||||
("task_agent", ""), # domain signal
|
||||
("task_agent", "Hello"),
|
||||
("task_agent", " there"),
|
||||
events = [
|
||||
("tool_end", {"name": "task_agent", "result": "ok"}),
|
||||
("token", "Hello"),
|
||||
("mutations", []),
|
||||
]
|
||||
frames = await collect(formatter, _stream(*tokens))
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
assert isinstance(frames[0], WsFloatingDomain)
|
||||
assert frames[0].domain == "tasks"
|
||||
@@ -156,8 +147,12 @@ async def test_floating_formatter_domain_emitted_first():
|
||||
async def test_floating_formatter_text_only():
|
||||
req_id = "pop-2"
|
||||
formatter = FloatingFormatter(request_id=req_id)
|
||||
tokens = [("timeline_agent", ""), ("timeline_agent", "Summary")]
|
||||
frames = await collect(formatter, _stream(*tokens))
|
||||
events = [
|
||||
("tool_end", {"name": "timeline_agent", "result": "done"}),
|
||||
("token", "Summary"),
|
||||
("mutations", []),
|
||||
]
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
|
||||
assert isinstance(frames[0], WsFloatingDomain)
|
||||
assert frames[0].domain == "timelines"
|
||||
@@ -171,11 +166,12 @@ async def test_floating_formatter_no_block_frames():
|
||||
"""FloatingFormatter must never emit WsStreamBlock."""
|
||||
req_id = "pop-3"
|
||||
formatter = FloatingFormatter(request_id=req_id)
|
||||
tokens = [
|
||||
("note_agent", ""),
|
||||
("note_agent", '{"type": "chart", "chartType": "bar", "data": []}'),
|
||||
events = [
|
||||
("tool_end", {"name": "note_agent", "result": "data"}),
|
||||
("token", "some text"),
|
||||
("mutations", []),
|
||||
]
|
||||
frames = await collect(formatter, _stream(*tokens))
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
assert not any(isinstance(f, WsStreamBlock) for f in frames)
|
||||
|
||||
|
||||
@@ -183,13 +179,37 @@ async def test_floating_formatter_no_block_frames():
|
||||
async def test_floating_formatter_end_frame():
|
||||
req_id = "pop-4"
|
||||
formatter = FloatingFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(("project_agent", ""), ("project_agent", "Done")))
|
||||
events = [
|
||||
("tool_end", {"name": "project_agent", "result": "ok"}),
|
||||
("token", "Done"),
|
||||
("mutations", []),
|
||||
]
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
assert isinstance(frames[-1], WsStreamEnd)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_floating_formatter_unknown_agent_defaults_to_tasks():
|
||||
async def test_floating_formatter_default_domain_on_early_token():
|
||||
"""When the first event is a token (no tool_end yet), default to 'tasks'."""
|
||||
req_id = "pop-5"
|
||||
formatter = FloatingFormatter(request_id=req_id)
|
||||
frames = await collect(formatter, _stream(("unknown_agent", ""), ("unknown_agent", "hi")))
|
||||
events = [("token", "hi"), ("mutations", [])]
|
||||
frames = await collect(formatter, _stream(*events))
|
||||
assert isinstance(frames[0], WsFloatingDomain)
|
||||
assert frames[0].domain == "tasks"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_floating_formatter_mutations_in_stream_end():
    """The final frame is a ``WsStreamEnd`` carrying the streamed mutation."""
    req_id = "pop-6"
    muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}]
    events = [
        ("token", "Updated"),
        ("mutations", muts),
    ]
    formatter = FloatingFormatter(request_id=req_id)
    frames = await collect(formatter, _stream(*events))

    end_frame = frames[-1]
    assert isinstance(end_frame, WsStreamEnd)
    assert len(end_frame.mutations) == 1
|
||||
|
||||
@@ -88,7 +88,7 @@ class TestPluginRegistry:
|
||||
async def test_list_filter_by_query(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
result = await reg.list_plugins(db_session, query="time")
|
||||
result = await reg.list_plugins(db_session, query="time tracker")
|
||||
assert result.total == 1
|
||||
assert result.plugins[0].id == "plugin-time-tracker"
|
||||
|
||||
|
||||
@@ -45,14 +45,16 @@ def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
|
||||
return frames
|
||||
|
||||
|
||||
async def _mock_home_stream(user_id, message, context, reg=None):
|
||||
yield "task_agent", ""
|
||||
yield "task_agent", '{"type": "text", "content": "Hello"}'
|
||||
async def _mock_home_stream(user_id, message, context, db_session_factory=None):
|
||||
yield "tool_end", {"name": "task_agent", "result": "Found tasks"}
|
||||
yield "token", "Hello"
|
||||
yield "mutations", []
|
||||
|
||||
|
||||
async def _mock_floating_stream(user_id, message, context, reg=None):
|
||||
yield "task_agent", ""
|
||||
yield "task_agent", "Here is a summary"
|
||||
async def _mock_floating_stream(user_id, message, context, scope=None, db_session_factory=None):
|
||||
yield "tool_end", {"name": "task_agent", "result": "ok"}
|
||||
yield "token", "Here is a summary"
|
||||
yield "mutations", []
|
||||
|
||||
|
||||
# ── tests ─────────────────────────────────────────────────────────────────────
|
||||
@@ -61,7 +63,7 @@ def test_home_request_produces_stream_frames(client):
|
||||
"""home_request → stream_start, stream_text+, stream_end."""
|
||||
token = make_jwt("power", user_id=USER_ID)
|
||||
|
||||
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_home_stream):
|
||||
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_home_stream):
|
||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||
ws.send_text(json.dumps({
|
||||
"type": "device_hello", "device_id": "dev-1", "agent_ids": []
|
||||
@@ -84,7 +86,7 @@ def test_floating_request_produces_domain_frame(client):
|
||||
"""floating_request → floating_domain first, then stream_text*, stream_end."""
|
||||
token = make_jwt("power", user_id=USER_ID)
|
||||
|
||||
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_mock_floating_stream):
|
||||
with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream):
|
||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||
ws.send_text(json.dumps({
|
||||
"type": "device_hello", "device_id": "dev-2", "agent_ids": []
|
||||
@@ -112,11 +114,12 @@ def test_home_request_request_id_propagated(client):
|
||||
token = make_jwt("power", user_id=USER_ID)
|
||||
req_id = "my-unique-req-id"
|
||||
|
||||
async def _stream(user_id, message, context, reg=None):
|
||||
yield "note_agent", ""
|
||||
yield "note_agent", '{"type": "text", "content": "ok"}'
|
||||
async def _stream(user_id, message, context, db_session_factory=None):
|
||||
yield "tool_end", {"name": "note_agent", "result": "ok"}
|
||||
yield "token", "ok"
|
||||
yield "mutations", []
|
||||
|
||||
with patch("app.api.routes.device_ws.orchestrate_v3_stream", side_effect=_stream):
|
||||
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
|
||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||
ws.send_text(json.dumps({
|
||||
"type": "device_hello", "device_id": "dev-3", "agent_ids": []
|
||||
|
||||
Reference in New Issue
Block a user