fix: make planner schema copilot-compatible and silence usage warning

This commit is contained in:
2026-03-12 23:17:31 +01:00
parent f7404b6f66
commit 5bc9ea6cd6
2 changed files with 101 additions and 11 deletions

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import logging import logging
import operator
from collections.abc import AsyncGenerator, Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Literal, TypedDict from typing import Any, Literal, TypedDict
@@ -116,12 +115,12 @@ WORKER_CONFIG: dict[WorkerName, dict[str, Any]] = {
_HOME_ORCHESTRATOR_SYSTEM = ( _HOME_ORCHESTRATOR_SYSTEM = (
"You are an orchestrator. Plan which workers should be invoked for the user request. " "You are an orchestrator. Plan which workers should be invoked for the user request. "
"Workers: task_agent, project_agent, note_agent, timeline_agent. " "Workers: task_agent, project_agent, note_agent, timeline_agent. "
"Return only the workers needed." "Return JSON only with keys: tasks, floating_domain, memory_updates."
) )
_FLOATING_ORCHESTRATOR_SYSTEM = ( _FLOATING_ORCHESTRATOR_SYSTEM = (
"You are an orchestrator for floating context. Pick focused workers and set floating_domain " "You are an orchestrator for floating context. Pick focused workers and set floating_domain "
"as one of: tasks, projects, notes, timelines." "as one of: tasks, projects, notes, timelines. Return JSON only with keys: tasks, floating_domain, memory_updates."
) )
_HOME_SYNTH_SYSTEM = ( _HOME_SYNTH_SYSTEM = (
@@ -178,6 +177,78 @@ def _fallback_plan(message: str, floating: bool) -> WorkerPlan:
return WorkerPlan(tasks=tasks, floating_domain=domain) return WorkerPlan(tasks=tasks, floating_domain=domain)
def _extract_json_object(text: str) -> dict[str, Any] | None:
"""Best-effort extraction of the first JSON object from model output."""
stripped = text.strip()
if not stripped:
return None
# Common case: model returns raw JSON object.
try:
payload = json.loads(stripped)
if isinstance(payload, dict):
return payload
except json.JSONDecodeError:
pass
# Fenced JSON block fallback.
if "```" in stripped:
parts = stripped.split("```")
for part in parts:
candidate = part.strip()
if candidate.startswith("json"):
candidate = candidate[4:].strip()
try:
payload = json.loads(candidate)
if isinstance(payload, dict):
return payload
except json.JSONDecodeError:
continue
return None
def _coerce_plan(payload: dict[str, Any], message: str, floating: bool) -> WorkerPlan:
    """Build a WorkerPlan from loosely structured planner JSON.

    Entries that do not have the expected shape are dropped silently. When no
    usable task remains, the plan from ``_fallback_plan(message, floating)``
    is returned instead.

    Args:
        payload: Decoded JSON object produced by the planner model.
        message: Original user message, forwarded to the fallback planner.
        floating: Whether the request is in floating context; used to derive
            a default ``floating_domain`` and forwarded to the fallback.

    Returns:
        A WorkerPlan holding the accepted tasks, the resolved floating domain
        (possibly ``None``) and the accepted memory updates.
    """
    # Keep only dict entries naming a configured worker with a string instruction.
    planned: list[WorkerTask] = []
    raw_tasks = payload.get("tasks")
    for entry in raw_tasks if isinstance(raw_tasks, list) else ():
        if not isinstance(entry, dict):
            continue
        name = entry.get("worker")
        text = entry.get("instruction")
        if not (isinstance(name, str) and isinstance(text, str)):
            continue
        if name in WORKER_CONFIG:
            planned.append(WorkerTask(worker=name, instruction=text))

    if not planned:
        return _fallback_plan(message, floating)

    # An explicit, recognised domain wins; otherwise floating requests take
    # the domain configured for the first accepted worker.
    requested = payload.get("floating_domain")
    resolved: FloatingDomain | None
    if isinstance(requested, str) and requested in {"tasks", "projects", "notes", "timelines"}:
        resolved = requested  # type: ignore[assignment]
    elif floating:
        resolved = WORKER_CONFIG[planned[0].worker]["floating_domain"]
    else:
        resolved = None

    # Memory updates need a non-empty string key and a non-empty string value.
    remembered: list[MemoryUpdate] = []
    raw_updates = payload.get("memory_updates")
    for entry in raw_updates if isinstance(raw_updates, list) else ():
        if not isinstance(entry, dict):
            continue
        key = entry.get("key")
        value = entry.get("value")
        if not (isinstance(key, str) and isinstance(value, str)):
            continue
        if key and value:
            remembered.append(MemoryUpdate(key=key, value=value))

    return WorkerPlan(
        tasks=planned,
        floating_domain=resolved,
        memory_updates=remembered,
    )
async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan: async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan:
llm = get_llm() llm = get_llm()
system = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM system = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM
@@ -189,18 +260,28 @@ async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool)
} }
messages = [ messages = [
SystemMessage(content=system), SystemMessage(content=system),
HumanMessage(content=json.dumps(prompt_payload, ensure_ascii=True)), HumanMessage(
content=(
"Create a valid JSON object with this exact structure:\n"
'{"tasks":[{"worker":"task_agent|project_agent|note_agent|timeline_agent","instruction":"..."}],'
'"floating_domain":"tasks|projects|notes|timelines|null","memory_updates":[{"key":"...","value":"..."}]}\n\n'
"Rules:\n"
"- tasks must include at least one entry when possible\n"
"- use floating_domain only when relevant\n"
"- output JSON only (no markdown, no prose)\n\n"
f"Input:\n{json.dumps(prompt_payload, ensure_ascii=True)}"
)
),
] ]
try: try:
structured_llm = llm.with_structured_output(WorkerPlan) response = await llm.ainvoke(messages)
plan = await structured_llm.ainvoke(messages) payload = _extract_json_object(_as_text(response.content))
if isinstance(plan, WorkerPlan): if payload is None:
if not plan.tasks: raise ValueError("planner returned non-JSON output")
return _fallback_plan(message, floating) return _coerce_plan(payload, message, floating)
return plan
except Exception as exc: except Exception as exc:
logger.warning("deep_agent: structured planner failed, using fallback: %s", exc) logger.warning("deep_agent: planner failed, using fallback: %s", exc)
return _fallback_plan(message, floating) return _fallback_plan(message, floating)

View File

@@ -18,6 +18,7 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
from __future__ import annotations from __future__ import annotations
import os import os
import warnings
from openai import AsyncOpenAI from openai import AsyncOpenAI
import litellm import litellm
@@ -32,6 +33,14 @@ from app.config.settings import settings
# Drop them silently instead of raising UnsupportedParamsError. # Drop them silently instead of raising UnsupportedParamsError.
litellm.drop_params = True litellm.drop_params = True
# Some provider responses include a plain dict in the `usage` field where a
# richer Pydantic model is expected. This warning is noisy but non-fatal.
#
# `warnings.filterwarnings` matches `message` against the *start* of the
# warning text. Pydantic reports serializer problems under a leading
# "Pydantic serializer warnings:" header line, so the pattern has to allow
# arbitrary leading text; `(?s)` lets `.` cross the newline after the header.
warnings.filterwarnings(
    "ignore",
    message=r"(?s).*PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
    category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None: def _api_key_for_model(model: str) -> str | None:
"""Return the most appropriate API key for the given LiteLLM model string.""" """Return the most appropriate API key for the given LiteLLM model string."""