diff --git a/BACKEND_PLAN.md b/BACKEND_PLAN.md index 8424e3c..53a5200 100644 --- a/BACKEND_PLAN.md +++ b/BACKEND_PLAN.md @@ -181,8 +181,8 @@ adiuva-api/ - [x] Integration tests with mocked LLM and mocked agents - **Outcome:** Intelligent routing with single-agent and pipeline modes. -### Step 5 — Execution Plan generator -- [ ] `app/core/execution_plan.py`: +### Step 5 — Execution Plan generator ✅ +- [x] `app/core/execution_plan.py`: - `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs. - `ExecutionPlanBuilder`: - `add_step(action, params) -> self` diff --git a/app/core/execution_plan.py b/app/core/execution_plan.py new file mode 100644 index 0000000..a6edd3a --- /dev/null +++ b/app/core/execution_plan.py @@ -0,0 +1,218 @@ +"""Execution Plan generator — builder, template registry, and LRU plan cache.""" + +from __future__ import annotations + +from collections import OrderedDict +from typing import Any + +from app.schemas import ExecutionPlan, PlanStep + + +# ── Prompt Template Registry ────────────────────────────────────────── + + +class PromptTemplateRegistry: + """Server-side store mapping template IDs to prompt text. + + Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``). + The actual prompt text is resolved here on the server, keeping prompt IP + out of API responses. + """ + + def __init__(self) -> None: + self._templates: dict[str, str] = {} + + def register(self, template_id: str, prompt_text: str) -> None: + self._templates[template_id] = prompt_text + + def get(self, template_id: str) -> str: + """Resolve a template ID to its prompt text. + + Raises ``KeyError`` if the template is not registered. + """ + text = self._templates.get(template_id) + if text is None: + raise KeyError(f"Template not found: {template_id!r}") + return text + + def has(self, template_id: str) -> bool: + return template_id in self._templates + + def list_ids(self) -> list[str]: + """Return all registered template IDs (never the text).""" + return list(self._templates.keys()) + + +# ── Execution Plan Builder ──────────────────────────────────────────── + + +class ExecutionPlanBuilder: + """Fluent builder for ``ExecutionPlan`` objects. + + Example:: + + plan = ( + ExecutionPlanBuilder("task_agent") + .add_llm_step("tpl_task_agent_default", {"message": user_msg}) + .add_data_step("create_record", data_from_step=0) + .build() + ) + """ + + def __init__(self, agent: str) -> None: + self._agent = agent + self._steps: list[PlanStep] = [] + + # ── step adders ────────────────────────────────────────────────── + + def add_step( + self, action: str, params: dict[str, Any] | None = None + ) -> ExecutionPlanBuilder: + """Append a generic action step with optional parameters.""" + self._steps.append(PlanStep(action=action, variables=params)) + return self + + def add_llm_step( + self, template_id: str, variables: dict[str, Any] | None = None + ) -> ExecutionPlanBuilder: + """Append an LLM step referencing a server-side template by ID.""" + self._steps.append( + PlanStep(action="llm", prompt_template=template_id, variables=variables) + ) + return self + + def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder: + """Append a step whose input comes from the output of an earlier step.""" + self._steps.append(PlanStep(action=action, data_from_step=data_from_step)) + return self + + # ── build ──────────────────────────────────────────────────────── + + def build(self) -> ExecutionPlan: + """Validate step references and return the ``ExecutionPlan``. + + Raises ``ValueError`` if any ``data_from_step`` references a + non-existent or future step index. + """ + for i, step in enumerate(self._steps): + if step.data_from_step is not None: + if not (0 <= step.data_from_step < i): + raise ValueError( + f"Step {i}: data_from_step={step.data_from_step} must " + f"reference a preceding step index in range 0..{i - 1}" + ) + return ExecutionPlan(agent=self._agent, steps=list(self._steps)) + + +# ── Plan Cache (LRU) ────────────────────────────────────────────────── + + +class PlanCache: + """In-memory LRU cache for ``ExecutionPlan`` objects. + + Plans stored here are accessible as playbooks via ``get_all_playbooks()``. + The cache also serves as a runtime memoisation layer so that repeated + identical intent classifications can skip re-building the plan. + """ + + def __init__(self, maxsize: int = 1000) -> None: + self._maxsize = maxsize + self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict() + + def cache_plan(self, key: str, plan: ExecutionPlan) -> None: + """Store *plan* under *key*, evicting the LRU entry if at capacity.""" + if key in self._cache: + del self._cache[key] # remove so re-insertion places it at the end + elif len(self._cache) >= self._maxsize: + self._cache.popitem(last=False) # evict least-recently-used + self._cache[key] = plan + + def get_plan(self, key: str) -> ExecutionPlan | None: + """Return the cached plan for *key*, or ``None`` if not present. + + Accessing a plan marks it as most-recently used. + """ + if key not in self._cache: + return None + self._cache.move_to_end(key) + return self._cache[key] + + def get_all_playbooks(self) -> list[ExecutionPlan]: + """Return all cached plans (most-recently used last).""" + return list(self._cache.values()) + + +# ── Module-level singletons ─────────────────────────────────────────── + +template_registry = PromptTemplateRegistry() +plan_cache = PlanCache() + + +def _register_builtin_templates() -> None: + """Register the built-in server-side prompt templates. + + These strings never leave the server. Clients only receive the IDs. + """ + _tpls: dict[str, str] = { + "tpl_task_agent_default": ( + "You are a task management assistant. Help the user create, update, " + "and prioritize tasks based on their message and context." + ), + "tpl_calendar_agent_default": ( + "You are a calendar assistant. Help manage events, detect scheduling " + "conflicts, and suggest improvements based on the provided context." + ), + "tpl_email_agent_default": ( + "You are an email analysis assistant. Classify emails, extract action " + "items, and draft responses using only the metadata provided." + ), + "tpl_analytics_agent_default": ( + "You are a workspace analytics assistant. Calculate metrics, generate " + "reports, and surface trends from the data provided in context." + ), + "tpl_email_extract_action_items": ( + "Extract all action items from the provided email metadata. " + "Return a structured list of tasks, each with a title, inferred " + "priority, and suggested due date where possible." + ), + "tpl_analytics_weekly_summary": ( + "Generate a weekly performance summary from the provided analytics " + "data. Include task completion rate, overdue item count, top " + "priorities for the coming week, and notable trends." + ), + } + for tid, text in _tpls.items(): + template_registry.register(tid, text) + + +def _load_playbooks() -> None: + """Pre-build and cache the built-in playbooks.""" + playbooks: list[tuple[str, ExecutionPlan]] = [ + ( + "create_task_from_email", + ExecutionPlanBuilder("email_agent") + .add_llm_step( + "tpl_email_extract_action_items", + {"source": "email_metadata"}, + ) + .add_data_step("create_record", data_from_step=0) + .build(), + ), + ( + "generate_weekly_report", + ExecutionPlanBuilder("analytics_agent") + .add_llm_step( + "tpl_analytics_weekly_summary", + {"period": "last_7_days"}, + ) + .add_data_step("create_record", data_from_step=0) + .build(), + ), + ] + for key, plan in playbooks: + plan_cache.cache_plan(key, plan) + + +# Initialise on module load +_register_builtin_templates() +_load_playbooks() diff --git a/app/core/orchestrator.py b/app/core/orchestrator.py index 82e8f6c..77d7d9f 100644 --- a/app/core/orchestrator.py +++ b/app/core/orchestrator.py @@ -11,7 +11,7 @@ from langchain_openai import ChatOpenAI from app.config.settings import settings from app.core.agent_registry import AgentRegistry from app.core.agent_registry import registry as _default_registry -from app.schemas import ChatRequest, ChatResponse, ExecutionPlan, PlanStep +from app.schemas import ChatRequest, ChatResponse, ExecutionPlan _FALLBACK_AGENT = "task_agent" @@ -99,22 +99,21 @@ async def route_pipeline( def _build_plan(agent_name: str, message: str) -> ExecutionPlan: - """Build a minimal ``ExecutionPlan`` for the resolved agent. + """Build an ``ExecutionPlan`` for the resolved agent. - The full ``ExecutionPlanBuilder`` (with template registry and caching) is - implemented in Step 5. This function produces the single-step baseline - plan that the orchestrator returns in ``'plan'`` mode. + Uses ``ExecutionPlanBuilder`` with the server-side template registry. + If a default template exists for the agent, an LLM step is emitted; + otherwise a plain ``handle`` action step is used. """ - return ExecutionPlan( - agent=agent_name, - steps=[ - PlanStep( - action="handle", - prompt_template=f"tpl_{agent_name}_default", - variables={"message": message}, - ) - ], - ) + from app.core.execution_plan import ExecutionPlanBuilder, template_registry + + template_id = f"tpl_{agent_name}_default" + builder = ExecutionPlanBuilder(agent_name) + if template_registry.has(template_id): + builder.add_llm_step(template_id, {"message": message}) + else: + builder.add_step("handle", {"message": message}) + return builder.build() async def orchestrate( diff --git a/tests/test_execution_plan.py b/tests/test_execution_plan.py new file mode 100644 index 0000000..03e2db7 --- /dev/null +++ b/tests/test_execution_plan.py @@ -0,0 +1,286 @@ +"""Tests for execution_plan: PromptTemplateRegistry, ExecutionPlanBuilder, PlanCache.""" + +from __future__ import annotations + +import pytest + +from app.core.execution_plan import ( + ExecutionPlanBuilder, + PlanCache, + PromptTemplateRegistry, + plan_cache, + template_registry, +) +from app.schemas import ExecutionPlan + + +# ── PromptTemplateRegistry ──────────────────────────────────────────── + + +class TestPromptTemplateRegistry: + def test_register_and_get(self) -> None: + reg = PromptTemplateRegistry() + reg.register("tpl_foo", "You are a foo agent.") + assert reg.get("tpl_foo") == "You are a foo agent." + + def test_get_unknown_raises_key_error(self) -> None: + reg = PromptTemplateRegistry() + with pytest.raises(KeyError, match="tpl_missing"): + reg.get("tpl_missing") + + def test_has_returns_true_for_registered(self) -> None: + reg = PromptTemplateRegistry() + reg.register("tpl_x", "prompt text") + assert reg.has("tpl_x") is True + + def test_has_returns_false_for_unregistered(self) -> None: + reg = PromptTemplateRegistry() + assert reg.has("tpl_missing") is False + + def test_list_ids_returns_all_registered_ids(self) -> None: + reg = PromptTemplateRegistry() + reg.register("tpl_a", "a") + reg.register("tpl_b", "b") + assert set(reg.list_ids()) == {"tpl_a", "tpl_b"} + + def test_list_ids_does_not_return_prompt_text(self) -> None: + reg = PromptTemplateRegistry() + reg.register("tpl_secret", "top secret prompt") + ids = reg.list_ids() + assert "top secret prompt" not in ids + + def test_overwrite_existing_template(self) -> None: + reg = PromptTemplateRegistry() + reg.register("tpl_x", "v1") + reg.register("tpl_x", "v2") + assert reg.get("tpl_x") == "v2" + + def test_empty_registry_has_no_ids(self) -> None: + reg = PromptTemplateRegistry() + assert reg.list_ids() == [] + + +# ── ExecutionPlanBuilder ────────────────────────────────────────────── + + +class TestExecutionPlanBuilder: + def test_builds_empty_plan(self) -> None: + plan = ExecutionPlanBuilder("task_agent").build() + assert plan.agent == "task_agent" + assert plan.steps == [] + + def test_add_step_basic(self) -> None: + plan = ( + ExecutionPlanBuilder("task_agent") + .add_step("create_task", {"priority": "high"}) + .build() + ) + assert len(plan.steps) == 1 + assert plan.steps[0].action == "create_task" + assert plan.steps[0].variables == {"priority": "high"} + assert plan.steps[0].prompt_template is None + assert plan.steps[0].data_from_step is None + + def test_add_step_no_params(self) -> None: + plan = ExecutionPlanBuilder("task_agent").add_step("fetch").build() + assert plan.steps[0].variables is None + + def test_add_llm_step(self) -> None: + plan = ( + ExecutionPlanBuilder("task_agent") + .add_llm_step("tpl_task_default", {"message": "hi"}) + .build() + ) + assert plan.steps[0].action == "llm" + assert plan.steps[0].prompt_template == "tpl_task_default" + assert plan.steps[0].variables == {"message": "hi"} + + def test_add_llm_step_no_variables(self) -> None: + plan = ExecutionPlanBuilder("task_agent").add_llm_step("tpl_x").build() + assert plan.steps[0].variables is None + + def test_add_data_step(self) -> None: + plan = ( + ExecutionPlanBuilder("task_agent") + .add_step("fetch_data") + .add_data_step("transform", data_from_step=0) + .build() + ) + assert plan.steps[1].action == "transform" + assert plan.steps[1].data_from_step == 0 + + def test_fluent_chaining_returns_builder(self) -> None: + builder = ExecutionPlanBuilder("analytics_agent") + result = builder.add_step("a") + assert result is builder + + def test_fluent_chain_multiple_steps(self) -> None: + plan = ( + ExecutionPlanBuilder("analytics_agent") + .add_llm_step("tpl_analytics_default") + .add_step("format_output") + .add_data_step("store", data_from_step=0) + .build() + ) + assert len(plan.steps) == 3 + + def test_build_validates_data_from_step_out_of_range(self) -> None: + with pytest.raises(ValueError, match="data_from_step"): + ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=5).build() + + def test_build_validates_data_from_step_self_reference(self) -> None: + """data_from_step=0 on the first step (index 0) is invalid.""" + with pytest.raises(ValueError, match="data_from_step"): + ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=0).build() + + def test_build_validates_data_from_step_negative(self) -> None: + with pytest.raises(ValueError, match="data_from_step"): + ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=-1).build() + + def test_valid_data_from_step_at_index_two(self) -> None: + plan = ( + ExecutionPlanBuilder("task_agent") + .add_step("step0") + .add_step("step1") + .add_data_step("step2", data_from_step=1) + .build() + ) + assert plan.steps[2].data_from_step == 1 + + def test_data_from_step_zero_valid_at_index_one(self) -> None: + plan = ( + ExecutionPlanBuilder("task_agent") + .add_step("step0") + .add_data_step("step1", data_from_step=0) + .build() + ) + assert plan.steps[1].data_from_step == 0 + + def test_build_returns_new_plan_each_call(self) -> None: + builder = ExecutionPlanBuilder("task_agent").add_step("do_thing") + plan1 = builder.build() + plan2 = builder.build() + assert plan1 is not plan2 + assert plan1.steps == plan2.steps + + def test_plan_is_execution_plan_instance(self) -> None: + plan = ExecutionPlanBuilder("task_agent").build() + assert isinstance(plan, ExecutionPlan) + + +# ── PlanCache ───────────────────────────────────────────────────────── + + +class TestPlanCache: + def _plan(self, agent: str = "a") -> ExecutionPlan: + return ExecutionPlanBuilder(agent).build() + + def test_cache_and_get(self) -> None: + cache = PlanCache() + plan = self._plan() + cache.cache_plan("key1", plan) + assert cache.get_plan("key1") is plan + + def test_get_missing_returns_none(self) -> None: + cache = PlanCache() + assert cache.get_plan("nonexistent") is None + + def test_get_all_playbooks_empty(self) -> None: + cache = PlanCache() + assert cache.get_all_playbooks() == [] + + def test_get_all_playbooks_returns_all_stored(self) -> None: + cache = PlanCache() + p1, p2 = self._plan("a"), self._plan("b") + cache.cache_plan("k1", p1) + cache.cache_plan("k2", p2) + playbooks = cache.get_all_playbooks() + assert len(playbooks) == 2 + assert p1 in playbooks + assert p2 in playbooks + + def test_lru_evicts_oldest_entry(self) -> None: + cache = PlanCache(maxsize=2) + p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c") + cache.cache_plan("k1", p1) + cache.cache_plan("k2", p2) + cache.cache_plan("k3", p3) # k1 should be evicted + assert cache.get_plan("k1") is None + assert cache.get_plan("k2") is p2 + assert cache.get_plan("k3") is p3 + + def test_lru_access_updates_recency(self) -> None: + cache = PlanCache(maxsize=2) + p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c") + cache.cache_plan("k1", p1) + cache.cache_plan("k2", p2) + cache.get_plan("k1") # k1 is now most-recently used + cache.cache_plan("k3", p3) # k2 should be evicted (LRU) + assert cache.get_plan("k1") is p1 + assert cache.get_plan("k2") is None + assert cache.get_plan("k3") is p3 + + def test_overwrite_existing_key(self) -> None: + cache = PlanCache() + p1, p2 = self._plan("a"), self._plan("b") + cache.cache_plan("same_key", p1) + cache.cache_plan("same_key", p2) + assert cache.get_plan("same_key") is p2 + assert len(cache.get_all_playbooks()) == 1 + + def test_overwrite_does_not_consume_capacity(self) -> None: + cache = PlanCache(maxsize=2) + p1, p2 = self._plan("a"), self._plan("b") + cache.cache_plan("k1", p1) + cache.cache_plan("k1", p2) # overwrite, not a new slot + cache.cache_plan("k2", p1) # should fit without eviction + assert cache.get_plan("k1") is p2 + assert cache.get_plan("k2") is p1 + + +# ── Module-level singletons ─────────────────────────────────────────── + + +class TestModuleSingletons: + def test_template_registry_has_all_agent_defaults(self) -> None: + for agent in ("task_agent", "calendar_agent", "email_agent", "analytics_agent"): + assert template_registry.has(f"tpl_{agent}_default"), ( + f"Missing template: tpl_{agent}_default" + ) + + def test_template_registry_has_operation_templates(self) -> None: + assert template_registry.has("tpl_email_extract_action_items") + assert template_registry.has("tpl_analytics_weekly_summary") + + def test_template_registry_get_returns_non_empty_string(self) -> None: + text = template_registry.get("tpl_task_agent_default") + assert isinstance(text, str) + assert len(text) > 0 + + def test_plan_cache_has_prebuilt_playbooks(self) -> None: + assert len(plan_cache.get_all_playbooks()) >= 2 + + def test_playbook_create_task_from_email(self) -> None: + plan = plan_cache.get_plan("create_task_from_email") + assert plan is not None + assert plan.agent == "email_agent" + assert len(plan.steps) == 2 + assert plan.steps[0].prompt_template == "tpl_email_extract_action_items" + assert plan.steps[1].data_from_step == 0 + + def test_playbook_generate_weekly_report(self) -> None: + plan = plan_cache.get_plan("generate_weekly_report") + assert plan is not None + assert plan.agent == "analytics_agent" + assert len(plan.steps) == 2 + assert plan.steps[0].prompt_template == "tpl_analytics_weekly_summary" + assert plan.steps[1].data_from_step == 0 + + def test_playbook_steps_have_no_raw_prompt_text(self) -> None: + """Plans must not embed prompt text — only template IDs.""" + for plan in plan_cache.get_all_playbooks(): + for step in plan.steps: + if step.prompt_template is not None: + assert step.prompt_template.startswith("tpl_"), ( + f"prompt_template looks like raw text: {step.prompt_template!r}" + )