diff --git a/V3_MIGRATION_PLAN.md b/V3_MIGRATION_PLAN.md index 26844fa..d5da12e 100644 --- a/V3_MIGRATION_PLAN.md +++ b/V3_MIGRATION_PLAN.md @@ -7,6 +7,18 @@ --- +## General Rules + +**Code Cleanup**: As you implement each step, remove any code that becomes unused or obsolete. This includes: +- Old functions/methods that are superseded by new ones +- Deprecated imports or modules +- Dead code paths +- Old test files no longer needed + +This keeps the codebase clean and prevents confusion. When removing code, note it in the commit message if significant. + +--- + ## Decisions Log | Topic | Decision | @@ -74,7 +86,7 @@ pytest tests/test_agent_streaming.py ``` **Status**: -- [ ] Step 2 complete +- [x] Step 2 complete **Commit**: After tests pass, commit with: ``` @@ -222,8 +234,9 @@ git commit -m "step-4: add output formatting layer (output_formatter.py)" - Each request gets a `request_id` (UUID) for frame correlation. - Concurrent requests from same client are supported (each runs as an async task). - `app/api/routes/chat.py`: - - Remove `chat_stream` WS endpoint. + - Remove `chat_stream` WS endpoint and any related helper functions that were only used by it. - Keep `POST /chat` endpoint unchanged (REST fallback). + - Clean up any unused imports. - `app/main.py`: - No change needed (device_ws router already registered). diff --git a/app/core/agent_registry.py b/app/core/agent_registry.py index 1037c14..323e4ea 100644 --- a/app/core/agent_registry.py +++ b/app/core/agent_registry.py @@ -3,6 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator from typing import Any @@ -34,6 +35,11 @@ class BaseAgent(ABC): class ChatAgent(BaseAgent): """Base class for LLM-powered chat agents.""" + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + # Populated by _tool_loop / _tool_loop_stream with raw execute_on_client results. 
+ self.tool_results: list[dict] = [] + @abstractmethod async def handle(self, query: str, context: dict[str, Any]) -> str: """Process a user query and return a text response.""" @@ -55,34 +61,98 @@ class ChatAgent(BaseAgent): Binds *tools* to *llm*, invokes iteratively until the model stops requesting tool calls or *max_iter* is reached, and returns the - final text response. + final text response. Captures raw execute_on_client results in + ``self.tool_results``. """ from langchain_core.messages import AIMessage, ToolMessage - llm_with_tools = llm.bind_tools(tools) if tools else llm + from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector - for _ in range(max_iter): - response: AIMessage = await llm_with_tools.ainvoke(messages) - messages.append(response) + collector: list[dict] = [] + set_tool_result_collector(collector) + try: + llm_with_tools = llm.bind_tools(tools) if tools else llm - if not response.tool_calls: - return str(response.content) + for _ in range(max_iter): + response: AIMessage = await llm_with_tools.ainvoke(messages) + messages.append(response) - # Execute each requested tool call - tool_map = {t.name: t for t in tools} - for call in response.tool_calls: - tool_fn = tool_map.get(call["name"]) - if tool_fn is None: - result = f"Unknown tool: {call['name']}" - else: - result = await tool_fn.ainvoke(call["args"]) - messages.append( - ToolMessage(content=str(result), tool_call_id=call["id"]) - ) + if not response.tool_calls: + return str(response.content) - # Exhausted iterations — ask model for a final answer without tools - response = await llm.ainvoke(messages) - return str(response.content) + # Execute each requested tool call + tool_map = {t.name: t for t in tools} + for call in response.tool_calls: + tool_fn = tool_map.get(call["name"]) + if tool_fn is None: + result = f"Unknown tool: {call['name']}" + else: + result = await tool_fn.ainvoke(call["args"]) + messages.append( + ToolMessage(content=str(result), 
async def _tool_loop_stream(
    self,
    llm: Any,
    messages: list[Any],
    tools: list[Any],
    max_iter: int = 5,
) -> AsyncGenerator[str, None]:
    """Run the tool-calling loop, then stream the final answer as text.

    Tool-calling rounds use ``ainvoke`` on the tool-bound model so that
    tool calls can be read from a complete response.  As soon as a round
    requests no tools -- or after *max_iter* rounds -- the answer is
    produced by ``llm.astream(messages)`` on the *unbound* model and each
    non-empty chunk is yielded as ``str``.

    NOTE(review): the ``ainvoke`` reply that ends the loop is discarded and
    the answer is generated again by ``astream``, i.e. the final answer
    costs a second model request.  Kept as-is because it is the documented
    design of this method.

    Args:
        llm: Chat model exposing ``bind_tools``, ``ainvoke`` and ``astream``.
        messages: Conversation so far.  Mutated in place: every tool-calling
            response and one ``ToolMessage`` per tool call are appended.
        tools: Tools offered to the model; each needs ``name`` and
            ``ainvoke``.  An empty list skips ``bind_tools``.
        max_iter: Maximum number of tool-calling rounds.

    Yields:
        Non-empty text fragments of the final answer.

    Side effects:
        ``self.tool_results`` is replaced by a fresh list at the start of
        the call.  Results are appended to that same list through the
        collector registered with ``set_tool_result_collector``.
    """
    from langchain_core.messages import AIMessage, ToolMessage

    from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector

    collector: list[dict] = []
    # Publish the list up front (it is filled in place).  Assigning only in
    # ``finally`` left the previous call's results visible while tokens were
    # being streamed, and indefinitely if the consumer stopped iterating
    # without closing the generator.
    self.tool_results = collector
    set_tool_result_collector(collector)
    try:
        llm_with_tools = llm.bind_tools(tools) if tools else llm
        # The tool list does not change between rounds; build the lookup once.
        tool_map = {t.name: t for t in tools}

        for _ in range(max_iter):
            response: AIMessage = await llm_with_tools.ainvoke(messages)
            if not response.tool_calls:
                # Final answer reached; the non-streamed reply is not kept.
                break

            messages.append(response)

            # Execute each requested tool call and feed the result back.
            for call in response.tool_calls:
                tool_fn = tool_map.get(call["name"])
                if tool_fn is None:
                    # Tell the model instead of raising so it can recover.
                    result = f"Unknown tool: {call['name']}"
                else:
                    result = await tool_fn.ainvoke(call["args"])
                messages.append(
                    ToolMessage(content=str(result), tool_call_id=call["id"])
                )

        # Reached both when the model stopped asking for tools and when
        # max_iter was exhausted: stream the answer without tools bound.
        async for chunk in llm.astream(messages):
            if chunk.content:
                yield str(chunk.content)
    finally:
        clear_tool_result_collector()
        self.tool_results = collector
# Imported here because only the collector helpers below need it.
from contextvars import Token

# Optional per-context list that receives every raw execute_on_client result.
# ``None`` means "nobody is collecting".
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
    "_tool_result_collector", default=None
)


def set_tool_result_collector(lst: list[dict]) -> Token[list[dict] | None]:
    """Register *lst* as the collector for the current context.

    Returns:
        The ``Token`` produced by ``ContextVar.set``.  Pass it to
        ``clear_tool_result_collector`` to restore whatever collector was
        active before this call.  Callers that ignore the return value keep
        the previous behaviour.
    """
    return _tool_result_collector.set(lst)


def clear_tool_result_collector(token: Token[list[dict] | None] | None = None) -> None:
    """Stop collecting in the current context (best-effort, never raises).

    Args:
        token: Token returned by the matching ``set_tool_result_collector``
            call.  When given, the previously active collector is restored,
            so a nested tool loop does not wipe the collector of the loop
            that invoked it.  When omitted, the collector is set to ``None``.
    """
    if token is None:
        _tool_result_collector.set(None)
        return
    try:
        _tool_result_collector.reset(token)
    except (ValueError, RuntimeError):
        # ValueError: the token belongs to a different Context.
        # RuntimeError: the token was already used once.
        # Either way fall back to plain clearing rather than failing cleanup.
        _tool_result_collector.set(None)
def _make_ai_message(content: str = "", tool_calls: list | None = None) -> AIMessage:
    """Build an ``AIMessage`` whose ``tool_calls`` is always a list."""
    msg = AIMessage(content=content)
    # ``None`` and ``[]`` both mean "no tool calls".
    msg.tool_calls = tool_calls or []
    return msg


def _tool_call(name: str, call_id: str, args: dict | None = None) -> dict:
    """Return a tool-call dict in the shape the loops read (name/args/id)."""
    return {"name": name, "args": args or {}, "id": call_id, "type": "tool_call"}


def _make_tool(name: str, return_value: Any) -> MagicMock:
    """Return a mock tool with a ``name`` and an awaitable ``ainvoke``."""
    tool = MagicMock()
    # Assigned after construction: ``MagicMock(name=...)`` would name the
    # mock itself instead of creating a ``name`` attribute.
    tool.name = name
    tool.ainvoke = AsyncMock(return_value=return_value)
    return tool


def _make_stream_chunks(tokens: list[str]) -> list[MagicMock]:
    """Return one mock chunk per token, exposing it as ``.content``."""
    return [MagicMock(content=tok) for tok in tokens]


def _fake_astream(tokens: list[str]) -> Any:
    """Return a stand-in for ``llm.astream`` that yields *tokens* as chunks."""

    async def astream(_messages: list):
        for chunk in _make_stream_chunks(tokens):
            yield chunk

    return astream


def _make_plain_llm(reply: AIMessage) -> AsyncMock:
    """Return an LLM mock for the no-tools path (``bind_tools`` is not used)."""
    llm = AsyncMock()
    llm.ainvoke = AsyncMock(return_value=reply)
    return llm


def _make_bound_llm(bound_ainvoke: AsyncMock) -> tuple[MagicMock, MagicMock]:
    """Return ``(llm, llm_with_tools)`` where ``llm.bind_tools()`` gives the latter.

    *bound_ainvoke* becomes ``llm_with_tools.ainvoke`` so each test scripts
    the tool-calling rounds itself.
    """
    llm = MagicMock()
    llm_with_tools = MagicMock()
    llm_with_tools.ainvoke = bound_ainvoke
    llm.bind_tools = MagicMock(return_value=llm_with_tools)
    return llm, llm_with_tools


async def _collect_stream(agent: ChatAgent, llm: Any, messages: list, tools: list) -> list[str]:
    """Drain ``_tool_loop_stream`` and return the yielded tokens in order."""
    return [tok async for tok in agent._tool_loop_stream(llm, messages, tools)]


# ── tool_results initialised ─────────────────────────────────────────


def test_tool_results_init():
    agent = _EchoAgent()
    assert agent.tool_results == []


# ── _tool_loop: no tool calls ────────────────────────────────────────


@pytest.mark.asyncio
async def test_tool_loop_no_tools():
    agent = _EchoAgent()
    llm = _make_plain_llm(_make_ai_message("Hello!"))

    result = await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
    assert result == "Hello!"
    assert agent.tool_results == []


# ── _tool_loop: with one tool call + result capture ──────────────────


@pytest.mark.asyncio
async def test_tool_loop_captures_tool_results():
    from app.core.ws_context import (
        clear_client_executor,
        execute_on_client,
        set_client_executor,
    )

    agent = _EchoAgent()
    raw_result = {"rows": [{"id": "t-1", "title": "Fix bug", "status": "todo"}]}

    async def fake_executor(payload: dict) -> dict:
        return raw_result

    # The tool goes through execute_on_client so the collector sees the raw dict.
    async def list_tasks(args: dict) -> str:
        res = await execute_on_client(action="select", table="tasks")
        return "\n".join(row["title"] for row in res.get("rows", []))

    final_msg = _make_ai_message("Here are your tasks.")
    llm, _ = _make_bound_llm(
        AsyncMock(
            side_effect=[
                _make_ai_message(tool_calls=[_tool_call("list_tasks", "call-1")]),
                final_msg,
            ]
        )
    )
    llm.ainvoke = AsyncMock(return_value=final_msg)

    mock_tool = _make_tool("list_tasks", "- Fix bug (todo)")
    mock_tool.ainvoke = AsyncMock(side_effect=list_tasks)

    set_client_executor(fake_executor)
    try:
        result = await agent._tool_loop(
            llm, [HumanMessage(content="list my tasks")], [mock_tool]
        )
    finally:
        clear_client_executor()

    assert result == "Here are your tasks."
    assert len(agent.tool_results) == 1
    assert agent.tool_results[0] == raw_result


# ── _tool_loop: tool_results reset on each call ──────────────────────


@pytest.mark.asyncio
async def test_tool_loop_resets_tool_results():
    agent = _EchoAgent()
    agent.tool_results = [{"stale": True}]  # left over from a previous call

    llm = _make_plain_llm(_make_ai_message("Done."))

    await agent._tool_loop(llm, [HumanMessage(content="hi")], [])
    assert agent.tool_results == []


# ── _tool_loop: unknown tool name ────────────────────────────────────


@pytest.mark.asyncio
async def test_tool_loop_unknown_tool():
    agent = _EchoAgent()

    # The model asks for a tool that is not in the list; the loop must cope.
    llm, _ = _make_bound_llm(
        AsyncMock(
            side_effect=[
                _make_ai_message(tool_calls=[_tool_call("nonexistent", "c1")]),
                _make_ai_message("Handled."),
            ]
        )
    )
    mock_tool = _make_tool("known", "ok")  # a different tool, not "nonexistent"

    result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool])
    assert result == "Handled."


# ── _tool_loop: max_iter exhaustion ──────────────────────────────────


@pytest.mark.asyncio
async def test_tool_loop_max_iter():
    agent = _EchoAgent()

    # The bound model requests a tool on every round.
    always_tool = _make_ai_message(tool_calls=[_tool_call("t", "c1")])
    llm, llm_with_tools = _make_bound_llm(AsyncMock(return_value=always_tool))
    llm.ainvoke = AsyncMock(return_value=_make_ai_message("Fallback."))

    mock_tool = _make_tool("t", "ok")

    result = await agent._tool_loop(llm, [HumanMessage(content="x")], [mock_tool], max_iter=2)
    assert result == "Fallback."
    assert llm_with_tools.ainvoke.call_count == 2


# ── _tool_loop_stream: no tool calls — yields tokens ─────────────────


@pytest.mark.asyncio
async def test_tool_loop_stream_no_tools_yields_tokens():
    agent = _EchoAgent()

    # No tools → llm is used directly; ainvoke reports no tool calls → stream.
    llm = _make_plain_llm(_make_ai_message("irrelevant"))
    llm.astream = _fake_astream(["Hello", " ", "world"])

    tokens = await _collect_stream(agent, llm, [HumanMessage(content="hi")], [])
    assert tokens == ["Hello", " ", "world"]
    assert agent.tool_results == []


# ── _tool_loop_stream: one tool call then streaming final ─────────────


@pytest.mark.asyncio
async def test_tool_loop_stream_with_tool_call():
    from app.core.ws_context import (
        clear_client_executor,
        execute_on_client,
        set_client_executor,
    )

    agent = _EchoAgent()
    raw_result = {"row": {"id": "t-2", "title": "Deploy", "status": "in_progress"}}

    async def fake_executor(payload: dict) -> dict:
        return raw_result

    async def get_task(args: dict) -> str:
        res = await execute_on_client(
            action="select", table="tasks", filters={"id": args.get("id")}
        )
        return res.get("row", {}).get("title", "")

    # First round requests the tool, second round requests nothing more.
    llm, _ = _make_bound_llm(
        AsyncMock(
            side_effect=[
                _make_ai_message(tool_calls=[_tool_call("get_task", "c1", {"id": "t-2"})]),
                _make_ai_message("Task found."),
            ]
        )
    )
    llm.astream = _fake_astream(["Task", " ", "found."])

    mock_tool = _make_tool("get_task", "Deploy")
    mock_tool.ainvoke = AsyncMock(side_effect=get_task)

    set_client_executor(fake_executor)
    try:
        tokens = await _collect_stream(
            agent, llm, [HumanMessage(content="get task t-2")], [mock_tool]
        )
    finally:
        clear_client_executor()

    assert tokens == ["Task", " ", "found."]
    assert len(agent.tool_results) == 1
    assert agent.tool_results[0] == raw_result


# ── _tool_loop_stream: tool_results reset on each call ───────────────


@pytest.mark.asyncio
async def test_tool_loop_stream_resets_tool_results():
    agent = _EchoAgent()
    agent.tool_results = [{"old": True}]

    llm = _make_plain_llm(_make_ai_message(""))
    llm.astream = _fake_astream(["ok"])

    await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
    assert agent.tool_results == []


# ── _tool_loop_stream: empty chunk content is skipped ────────────────


@pytest.mark.asyncio
async def test_tool_loop_stream_skips_empty_chunks():
    agent = _EchoAgent()

    llm = _make_plain_llm(_make_ai_message(""))
    llm.astream = _fake_astream(["", "hello", "", " world", ""])

    tokens = await _collect_stream(agent, llm, [HumanMessage(content="x")], [])
    assert tokens == ["hello", " world"]


# ── _tool_loop_stream: max_iter exhaustion falls back to stream ───────


@pytest.mark.asyncio
async def test_tool_loop_stream_max_iter():
    agent = _EchoAgent()

    always_tool = _make_ai_message(tool_calls=[_tool_call("t", "c1")])
    llm, llm_with_tools = _make_bound_llm(AsyncMock(return_value=always_tool))
    llm.astream = _fake_astream(["fallback"])
    mock_tool = _make_tool("t", "ok")

    tokens = await _collect_stream(
        agent, llm, [HumanMessage(content="x")], [mock_tool],
    )
    assert tokens == ["fallback"]
    assert llm_with_tools.ainvoke.call_count == 5  # exhausted default max_iter


# ── _tool_loop_stream: multiple tool results captured ────────────────


@pytest.mark.asyncio
async def test_tool_loop_stream_multiple_tool_results():
    from app.core.ws_context import (
        clear_client_executor,
        execute_on_client,
        set_client_executor,
    )

    agent = _EchoAgent()

    call_results = [
        {"rows": [{"id": "t-1"}]},
        {"rows": [{"id": "t-2"}]},
    ]
    call_iter = iter(call_results)

    async def fake_executor(payload: dict) -> dict:
        return next(call_iter)

    async def select_tasks(args: dict) -> str:
        res = await execute_on_client(action="select", table="tasks")
        return str(res)

    # Two tool calls requested in a single round.
    llm, _ = _make_bound_llm(
        AsyncMock(
            side_effect=[
                _make_ai_message(
                    tool_calls=[_tool_call("tool_a", "c1"), _tool_call("tool_b", "c2")]
                ),
                _make_ai_message("Done."),
            ]
        )
    )
    llm.astream = _fake_astream(["Done."])

    tool_a = _make_tool("tool_a", "")
    tool_a.ainvoke = AsyncMock(side_effect=select_tasks)
    tool_b = _make_tool("tool_b", "")
    tool_b.ainvoke = AsyncMock(side_effect=select_tasks)

    set_client_executor(fake_executor)
    try:
        tokens = await _collect_stream(
            agent, llm, [HumanMessage(content="x")], [tool_a, tool_b]
        )
    finally:
        clear_client_executor()

    assert tokens == ["Done."]
    assert len(agent.tool_results) == 2
    assert agent.tool_results[0] == {"rows": [{"id": "t-1"}]}
    assert agent.tool_results[1] == {"rows": [{"id": "t-2"}]}