fix: make planner schema copilot-compatible and silence usage warning
This commit is contained in:
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import operator
|
|
||||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||||
from typing import Any, Literal, TypedDict
|
from typing import Any, Literal, TypedDict
|
||||||
|
|
||||||
@@ -116,12 +115,12 @@ WORKER_CONFIG: dict[WorkerName, dict[str, Any]] = {
|
|||||||
_HOME_ORCHESTRATOR_SYSTEM = (
|
_HOME_ORCHESTRATOR_SYSTEM = (
|
||||||
"You are an orchestrator. Plan which workers should be invoked for the user request. "
|
"You are an orchestrator. Plan which workers should be invoked for the user request. "
|
||||||
"Workers: task_agent, project_agent, note_agent, timeline_agent. "
|
"Workers: task_agent, project_agent, note_agent, timeline_agent. "
|
||||||
"Return only the workers needed."
|
"Return JSON only with keys: tasks, floating_domain, memory_updates."
|
||||||
)
|
)
|
||||||
|
|
||||||
_FLOATING_ORCHESTRATOR_SYSTEM = (
|
_FLOATING_ORCHESTRATOR_SYSTEM = (
|
||||||
"You are an orchestrator for floating context. Pick focused workers and set floating_domain "
|
"You are an orchestrator for floating context. Pick focused workers and set floating_domain "
|
||||||
"as one of: tasks, projects, notes, timelines."
|
"as one of: tasks, projects, notes, timelines. Return JSON only with keys: tasks, floating_domain, memory_updates."
|
||||||
)
|
)
|
||||||
|
|
||||||
_HOME_SYNTH_SYSTEM = (
|
_HOME_SYNTH_SYSTEM = (
|
||||||
@@ -178,6 +177,78 @@ def _fallback_plan(message: str, floating: bool) -> WorkerPlan:
|
|||||||
return WorkerPlan(tasks=tasks, floating_domain=domain)
|
return WorkerPlan(tasks=tasks, floating_domain=domain)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_object(text: str) -> dict[str, Any] | None:
|
||||||
|
"""Best-effort extraction of the first JSON object from model output."""
|
||||||
|
stripped = text.strip()
|
||||||
|
if not stripped:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Common case: model returns raw JSON object.
|
||||||
|
try:
|
||||||
|
payload = json.loads(stripped)
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
return payload
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fenced JSON block fallback.
|
||||||
|
if "```" in stripped:
|
||||||
|
parts = stripped.split("```")
|
||||||
|
for part in parts:
|
||||||
|
candidate = part.strip()
|
||||||
|
if candidate.startswith("json"):
|
||||||
|
candidate = candidate[4:].strip()
|
||||||
|
try:
|
||||||
|
payload = json.loads(candidate)
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
return payload
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_plan(payload: dict[str, Any], message: str, floating: bool) -> WorkerPlan:
|
||||||
|
"""Normalize loose model JSON into a validated WorkerPlan."""
|
||||||
|
tasks_raw = payload.get("tasks")
|
||||||
|
tasks: list[WorkerTask] = []
|
||||||
|
|
||||||
|
if isinstance(tasks_raw, list):
|
||||||
|
for item in tasks_raw:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
worker = item.get("worker")
|
||||||
|
instruction = item.get("instruction")
|
||||||
|
if isinstance(worker, str) and worker in WORKER_CONFIG and isinstance(instruction, str):
|
||||||
|
tasks.append(WorkerTask(worker=worker, instruction=instruction))
|
||||||
|
|
||||||
|
if not tasks:
|
||||||
|
return _fallback_plan(message, floating)
|
||||||
|
|
||||||
|
domain = payload.get("floating_domain")
|
||||||
|
floating_domain: FloatingDomain | None = None
|
||||||
|
if isinstance(domain, str) and domain in {"tasks", "projects", "notes", "timelines"}:
|
||||||
|
floating_domain = domain # type: ignore[assignment]
|
||||||
|
elif floating:
|
||||||
|
floating_domain = WORKER_CONFIG[tasks[0].worker]["floating_domain"]
|
||||||
|
|
||||||
|
memory_updates: list[MemoryUpdate] = []
|
||||||
|
updates_raw = payload.get("memory_updates")
|
||||||
|
if isinstance(updates_raw, list):
|
||||||
|
for item in updates_raw:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
key = item.get("key")
|
||||||
|
value = item.get("value")
|
||||||
|
if isinstance(key, str) and isinstance(value, str) and key and value:
|
||||||
|
memory_updates.append(MemoryUpdate(key=key, value=value))
|
||||||
|
|
||||||
|
return WorkerPlan(
|
||||||
|
tasks=tasks,
|
||||||
|
floating_domain=floating_domain,
|
||||||
|
memory_updates=memory_updates,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan:
|
async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan:
|
||||||
llm = get_llm()
|
llm = get_llm()
|
||||||
system = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM
|
system = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM
|
||||||
@@ -189,18 +260,28 @@ async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool)
|
|||||||
}
|
}
|
||||||
messages = [
|
messages = [
|
||||||
SystemMessage(content=system),
|
SystemMessage(content=system),
|
||||||
HumanMessage(content=json.dumps(prompt_payload, ensure_ascii=True)),
|
HumanMessage(
|
||||||
|
content=(
|
||||||
|
"Create a valid JSON object with this exact structure:\n"
|
||||||
|
'{"tasks":[{"worker":"task_agent|project_agent|note_agent|timeline_agent","instruction":"..."}],'
|
||||||
|
'"floating_domain":"tasks|projects|notes|timelines|null","memory_updates":[{"key":"...","value":"..."}]}\n\n'
|
||||||
|
"Rules:\n"
|
||||||
|
"- tasks must include at least one entry when possible\n"
|
||||||
|
"- use floating_domain only when relevant\n"
|
||||||
|
"- output JSON only (no markdown, no prose)\n\n"
|
||||||
|
f"Input:\n{json.dumps(prompt_payload, ensure_ascii=True)}"
|
||||||
|
)
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
structured_llm = llm.with_structured_output(WorkerPlan)
|
response = await llm.ainvoke(messages)
|
||||||
plan = await structured_llm.ainvoke(messages)
|
payload = _extract_json_object(_as_text(response.content))
|
||||||
if isinstance(plan, WorkerPlan):
|
if payload is None:
|
||||||
if not plan.tasks:
|
raise ValueError("planner returned non-JSON output")
|
||||||
return _fallback_plan(message, floating)
|
return _coerce_plan(payload, message, floating)
|
||||||
return plan
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("deep_agent: structured planner failed, using fallback: %s", exc)
|
logger.warning("deep_agent: planner failed, using fallback: %s", exc)
|
||||||
|
|
||||||
return _fallback_plan(message, floating)
|
return _fallback_plan(message, floating)
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import litellm
|
import litellm
|
||||||
@@ -32,6 +33,14 @@ from app.config.settings import settings
|
|||||||
# Drop them silently instead of raising UnsupportedParamsError.
|
# Drop them silently instead of raising UnsupportedParamsError.
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
# Some provider responses include a plain dict in the `usage` field where a
|
||||||
|
# richer Pydantic model is expected. This warning is noisy but non-fatal.
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _api_key_for_model(model: str) -> str | None:
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
"""Return the most appropriate API key for the given LiteLLM model string."""
|
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||||
|
|||||||
Reference in New Issue
Block a user