From 5bc9ea6cd6aac41a09b9328bb2905858196396e7 Mon Sep 17 00:00:00 2001 From: roberto Date: Thu, 12 Mar 2026 23:17:31 +0100 Subject: [PATCH] fix: make planner schema copilot-compatible and silence usage warning --- app/core/deep_agent.py | 103 ++++++++++++++++++++++++++++++++++++----- app/core/llm.py | 9 ++++ 2 files changed, 101 insertions(+), 11 deletions(-) diff --git a/app/core/deep_agent.py b/app/core/deep_agent.py index 9d8f70d..b64624c 100644 --- a/app/core/deep_agent.py +++ b/app/core/deep_agent.py @@ -5,7 +5,6 @@ from __future__ import annotations import asyncio import json import logging -import operator from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any, Literal, TypedDict @@ -116,12 +115,12 @@ WORKER_CONFIG: dict[WorkerName, dict[str, Any]] = { _HOME_ORCHESTRATOR_SYSTEM = ( "You are an orchestrator. Plan which workers should be invoked for the user request. " "Workers: task_agent, project_agent, note_agent, timeline_agent. " - "Return only the workers needed." + "Return JSON only with keys: tasks, floating_domain, memory_updates." ) _FLOATING_ORCHESTRATOR_SYSTEM = ( "You are an orchestrator for floating context. Pick focused workers and set floating_domain " - "as one of: tasks, projects, notes, timelines." + "as one of: tasks, projects, notes, timelines. Return JSON only with keys: tasks, floating_domain, memory_updates." ) _HOME_SYNTH_SYSTEM = ( @@ -178,6 +177,78 @@ def _fallback_plan(message: str, floating: bool) -> WorkerPlan: return WorkerPlan(tasks=tasks, floating_domain=domain) +def _extract_json_object(text: str) -> dict[str, Any] | None: + """Best-effort extraction of the first JSON object from model output.""" + stripped = text.strip() + if not stripped: + return None + + # Common case: model returns raw JSON object. + try: + payload = json.loads(stripped) + if isinstance(payload, dict): + return payload + except json.JSONDecodeError: + pass + + # Fenced JSON block fallback. + if "```" in stripped: + parts = stripped.split("```") + for part in parts: + candidate = part.strip() + if candidate.startswith("json"): + candidate = candidate[4:].strip() + try: + payload = json.loads(candidate) + if isinstance(payload, dict): + return payload + except json.JSONDecodeError: + continue + + return None + + +def _coerce_plan(payload: dict[str, Any], message: str, floating: bool) -> WorkerPlan: + """Normalize loose model JSON into a validated WorkerPlan.""" + tasks_raw = payload.get("tasks") + tasks: list[WorkerTask] = [] + + if isinstance(tasks_raw, list): + for item in tasks_raw: + if not isinstance(item, dict): + continue + worker = item.get("worker") + instruction = item.get("instruction") + if isinstance(worker, str) and worker in WORKER_CONFIG and isinstance(instruction, str): + tasks.append(WorkerTask(worker=worker, instruction=instruction)) + + if not tasks: + return _fallback_plan(message, floating) + + domain = payload.get("floating_domain") + floating_domain: FloatingDomain | None = None + if isinstance(domain, str) and domain in {"tasks", "projects", "notes", "timelines"}: + floating_domain = domain # type: ignore[assignment] + elif floating: + floating_domain = WORKER_CONFIG[tasks[0].worker]["floating_domain"] + + memory_updates: list[MemoryUpdate] = [] + updates_raw = payload.get("memory_updates") + if isinstance(updates_raw, list): + for item in updates_raw: + if isinstance(item, dict): + key = item.get("key") + value = item.get("value") + if isinstance(key, str) and isinstance(value, str) and key and value: + memory_updates.append(MemoryUpdate(key=key, value=value)) + + return WorkerPlan( + tasks=tasks, + floating_domain=floating_domain, + memory_updates=memory_updates, + ) + + async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan: llm = get_llm() system = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM @@ -189,18 +260,28 @@ async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) } messages = [ SystemMessage(content=system), - HumanMessage(content=json.dumps(prompt_payload, ensure_ascii=True)), + HumanMessage( + content=( + "Create a valid JSON object with this exact structure:\n" + '{"tasks":[{"worker":"task_agent|project_agent|note_agent|timeline_agent","instruction":"..."}],' + '"floating_domain":"tasks|projects|notes|timelines|null","memory_updates":[{"key":"...","value":"..."}]}\n\n' + "Rules:\n" + "- tasks must include at least one entry when possible\n" + "- use floating_domain only when relevant\n" + "- output JSON only (no markdown, no prose)\n\n" + f"Input:\n{json.dumps(prompt_payload, ensure_ascii=True)}" + ) + ), ] try: - structured_llm = llm.with_structured_output(WorkerPlan) - plan = await structured_llm.ainvoke(messages) - if isinstance(plan, WorkerPlan): - if not plan.tasks: - return _fallback_plan(message, floating) - return plan + response = await llm.ainvoke(messages) + payload = _extract_json_object(_as_text(response.content)) + if payload is None: + raise ValueError("planner returned non-JSON output") + return _coerce_plan(payload, message, floating) except Exception as exc: - logger.warning("deep_agent: structured planner failed, using fallback: %s", exc) + logger.warning("deep_agent: planner failed, using fallback: %s", exc) return _fallback_plan(message, floating) diff --git a/app/core/llm.py b/app/core/llm.py index 3d985af..3415921 100644 --- a/app/core/llm.py +++ b/app/core/llm.py @@ -18,6 +18,7 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env`` from __future__ import annotations import os +import warnings from openai import AsyncOpenAI import litellm @@ -32,6 +33,14 @@ from app.config.settings import settings # Drop them silently instead of raising UnsupportedParamsError. litellm.drop_params = True +# Some provider responses include a plain dict in the `usage` field where a +# richer Pydantic model is expected. This warning is noisy but non-fatal. +warnings.filterwarnings( + "ignore", + message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`", + category=UserWarning, +) + def _api_key_for_model(model: str) -> str | None: """Return the most appropriate API key for the given LiteLLM model string."""