2 Commits

Author SHA1 Message Date
5b55f1292a make a single agent 2026-03-13 07:42:36 +01:00
5bc9ea6cd6 fix: make planner schema copilot-compatible and silence usage warning 2026-03-12 23:17:31 +01:00
5 changed files with 482 additions and 23 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
@@ -9,6 +10,14 @@ from langchain_core.tools import tool
from app.core.llm import embed
from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
NOTE_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n"
"and delete Markdown notes in their workspace.\n\n"
@@ -19,6 +28,7 @@ NOTE_SYSTEM_PROMPT = (
" before appending or replacing sections\n"
" - list_notes without project_id returns all notes; scope with project_id\n"
" when the user is working within a specific project\n"
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
" - Do not fabricate note content — reflect what the user provides or what\n"
" is already in the note (retrieved via get_note)."
)
@@ -27,10 +37,11 @@ NOTE_SYSTEM_PROMPT = (
@tool
async def list_notes(project_id: str = "") -> str:
"""List notes, optionally scoped to a project by project_id."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="notes",
filters={"projectId": project_id or None},
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:

View File

@@ -3,12 +3,21 @@
from __future__ import annotations
from datetime import datetime, timezone
import re
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TASK_SYSTEM_PROMPT = (
"You are a task management assistant for a project workspace.\n"
"You create, update, list, and track tasks and their comments.\n\n"
@@ -39,11 +48,12 @@ async def list_tasks(
) -> str:
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
a search string, or an order_by field name (dueDate|priority|createdAt)."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="tasks",
filters={
"projectId": project_id or None,
"projectId": normalized_project_id or None,
"status": status or None,
"search": search or None,
"orderBy": order_by or None,
@@ -205,8 +215,12 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
table="taskComments",
data={"taskId": task_id, "author": author, "content": content},
)
row = result["row"]
return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
row = result.get("row", {})
row_author = row.get("author", author)
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
row_task_id = row.get("taskId") or row.get("task_id") or task_id
row_comment_id = row.get("id", "unknown")
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
@tool

View File

@@ -2,17 +2,27 @@
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TIMELINE_SYSTEM_PROMPT = (
"You are a project timeline assistant. Timelines are milestone dates that\n"
"track progress on a project — they are not calendar events.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - is_approved: 0 until the user explicitly confirms; then 1\n"
@@ -25,10 +35,11 @@ TIMELINE_SYSTEM_PROMPT = (
@tool
async def list_timelines(project_id: str = "") -> str:
"""List timelines. Provide project_id to scope to a specific project."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="timelines",
filters={"projectId": project_id or None},
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:

View File

@@ -5,9 +5,10 @@ from __future__ import annotations
import asyncio
import json
import logging
import operator
import re
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Literal, TypedDict
import operator
from typing import Annotated, Any, Literal, TypedDict
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
@@ -22,11 +23,14 @@ from app.agents.task_agent import TASK_SYSTEM_PROMPT, TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_SYSTEM_PROMPT, TIMELINE_TOOLS
from app.core.llm import get_llm
from app.core.memory_middleware import MemoryMiddleware
from app.core.ws_context import clear_tool_result_collector, set_tool_result_collector
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
from app.db import async_session
logger = logging.getLogger(__name__)
# Quick test switch: home requests run as one agent with all tools.
HOME_SINGLE_AGENT_TEST_MODE = True
WorkerName = Literal["task_agent", "project_agent", "note_agent", "timeline_agent"]
FloatingDomain = Literal["tasks", "projects", "notes", "timelines"]
@@ -56,6 +60,7 @@ class WorkerResult(TypedDict):
instruction: str
response: str
entity_ids: dict[str, list[str]]
facts: dict[str, Any]
class OrchestratorState(TypedDict, total=False):
@@ -71,7 +76,7 @@ class OrchestratorState(TypedDict, total=False):
class GraphState(OrchestratorState):
worker_results: list[WorkerResult]
worker_results: Annotated[list[WorkerResult], operator.add]
class ReducerState(OrchestratorState):
@@ -116,19 +121,21 @@ WORKER_CONFIG: dict[WorkerName, dict[str, Any]] = {
_HOME_ORCHESTRATOR_SYSTEM = (
"You are an orchestrator. Plan which workers should be invoked for the user request. "
"Workers: task_agent, project_agent, note_agent, timeline_agent. "
"Return only the workers needed."
"Return JSON only with keys: tasks, floating_domain, memory_updates."
)
_FLOATING_ORCHESTRATOR_SYSTEM = (
"You are an orchestrator for floating context. Pick focused workers and set floating_domain "
"as one of: tasks, projects, notes, timelines."
"as one of: tasks, projects, notes, timelines. Return JSON only with keys: tasks, floating_domain, memory_updates."
)
_HOME_SYNTH_SYSTEM = (
"You are the final response synthesizer. Return markdown only. "
"Embed inline component tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
"<note>[ids]</note>, <timeline>[ids]</timeline>, and <chart>{json}</chart>. "
"Only include IDs that are truly relevant to the request."
"Only include IDs that are truly relevant to the request. "
"Never invent missing values. If facts include a non-null clientId for a project, "
"do not claim that the project has no owner/client."
)
_FLOATING_SYNTH_SYSTEM = (
@@ -136,6 +143,14 @@ _FLOATING_SYNTH_SYSTEM = (
"Return concise markdown and stay focused on the requested scope."
)
_HOME_SINGLE_AGENT_SYSTEM = (
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines. "
"Always use tools for factual data retrieval before answering. "
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
"Return markdown and embed inline tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>."
)
def _as_text(content: Any) -> str:
if content is None:
@@ -178,7 +193,243 @@ def _fallback_plan(message: str, floating: bool) -> WorkerPlan:
return WorkerPlan(tasks=tasks, floating_domain=domain)
def _extract_json_object(text: str) -> dict[str, Any] | None:
"""Best-effort extraction of the first JSON object from model output."""
stripped = text.strip()
if not stripped:
return None
# Common case: model returns raw JSON object.
try:
payload = json.loads(stripped)
if isinstance(payload, dict):
return payload
except json.JSONDecodeError:
pass
# Fenced JSON block fallback.
if "```" in stripped:
parts = stripped.split("```")
for part in parts:
candidate = part.strip()
if candidate.startswith("json"):
candidate = candidate[4:].strip()
try:
payload = json.loads(candidate)
if isinstance(payload, dict):
return payload
except json.JSONDecodeError:
continue
return None
def _coerce_plan(payload: dict[str, Any], message: str, floating: bool) -> WorkerPlan:
"""Normalize loose model JSON into a validated WorkerPlan."""
tasks_raw = payload.get("tasks")
tasks: list[WorkerTask] = []
if isinstance(tasks_raw, list):
for item in tasks_raw:
if not isinstance(item, dict):
continue
worker = item.get("worker")
instruction = item.get("instruction")
if isinstance(worker, str) and worker in WORKER_CONFIG and isinstance(instruction, str):
tasks.append(WorkerTask(worker=worker, instruction=instruction))
if not tasks:
return _fallback_plan(message, floating)
domain = payload.get("floating_domain")
floating_domain: FloatingDomain | None = None
if isinstance(domain, str) and domain in {"tasks", "projects", "notes", "timelines"}:
floating_domain = domain # type: ignore[assignment]
elif floating:
floating_domain = WORKER_CONFIG[tasks[0].worker]["floating_domain"]
memory_updates: list[MemoryUpdate] = []
updates_raw = payload.get("memory_updates")
if isinstance(updates_raw, list):
for item in updates_raw:
if isinstance(item, dict):
key = item.get("key")
value = item.get("value")
if isinstance(key, str) and isinstance(value, str) and key and value:
memory_updates.append(MemoryUpdate(key=key, value=value))
return WorkerPlan(
tasks=tasks,
floating_domain=floating_domain,
memory_updates=memory_updates,
)
def _needs_full_project_snapshot(message: str, floating: bool) -> bool:
"""Detect project status/update requests that should query all workers."""
if floating:
return False
lowered = message.lower()
has_project = any(k in lowered for k in ["project", "progetto", "progetto", "progetti", "progetto", "whitelist"])
has_status_intent = any(k in lowered for k in ["status", "stato", "aggiorn", "update", "situazione", "riepilogo", "summary"])
return has_project and has_status_intent
def _build_full_project_snapshot_plan(message: str) -> WorkerPlan:
"""Build a deterministic all-workers plan for project status snapshots."""
project_hint = (
"Use context.context.resolved_project_id when present as project_id. "
"Do not pass project names as project_id."
)
return WorkerPlan(
tasks=[
WorkerTask(worker="project_agent", instruction=f"Resolve the target project from this request and return core fields including id, name, status, clientId. {project_hint} Request: {message}"),
WorkerTask(worker="task_agent", instruction=f"Collect tasks relevant to the project in this request; include pending/blocked highlights and IDs. {project_hint} Request: {message}"),
WorkerTask(worker="timeline_agent", instruction=f"Collect timeline/milestone items relevant to the project in this request; include upcoming items and IDs. {project_hint} Request: {message}"),
WorkerTask(worker="note_agent", instruction=f"Collect notes relevant to the project in this request; include latest useful notes and IDs. {project_hint} Request: {message}"),
]
)
def _candidate_tokens(message: str) -> list[str]:
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
return [t for t in tokens if len(t) >= 3]
async def _resolve_project_id_from_message(message: str) -> str | None:
"""Resolve likely project UUID from user message using client project list."""
try:
result = await execute_on_client(action="select", table="projects")
except Exception as exc:
logger.warning("deep_agent: project resolve select failed: %s", exc)
return None
rows = result.get("rows", [])
if not isinstance(rows, list) or not rows:
return None
tokens = _candidate_tokens(message)
scored: list[tuple[int, dict[str, Any]]] = []
for row in rows:
if not isinstance(row, dict):
continue
name = str(row.get("name", "")).lower()
score = sum(1 for token in tokens if token in name)
if score > 0:
scored.append((score, row))
if not scored:
return None
scored.sort(key=lambda item: item[0], reverse=True)
top_score = scored[0][0]
top_rows = [row for score, row in scored if score == top_score]
if len(top_rows) != 1:
return None
project_id = top_rows[0].get("id")
return project_id if isinstance(project_id, str) else None
async def _prepare_home_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
"""Resolve and inject project_id hints for home flows."""
prepared = dict(context)
if _needs_full_project_snapshot(message, floating=False):
resolved_project_id = await _resolve_project_id_from_message(message)
if resolved_project_id:
prepared["resolved_project_id"] = resolved_project_id
logger.info("deep_agent: resolved_project_id=%s for message=%s", resolved_project_id, message[:200])
return prepared
def _all_tools() -> list[Any]:
tools: list[Any] = []
for config in WORKER_CONFIG.values():
tools.extend(config["tools"])
return tools
async def _run_home_single_agent(
user_id: str,
message: str,
context: dict[str, Any],
) -> str:
"""Single-agent test mode: one loop with all tools."""
prepared_context = await _prepare_home_context(message, context)
llm = get_llm()
tools = _all_tools()
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=_HOME_SINGLE_AGENT_SYSTEM),
HumanMessage(content=f"User message:\n{message}\n\nContext:\n{json.dumps({'context': prepared_context}, ensure_ascii=True)[:3500]}"),
]
for _ in range(6):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return _as_text(response.content)
tool_map = {t.name: t for t in tools}
for call in response.tool_calls:
tool_fn = tool_map.get(call["name"])
if tool_fn is None:
tool_output = f"Unknown tool: {call['name']}"
else:
tool_output = await tool_fn.ainvoke(call.get("args", {}))
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
final = await llm.ainvoke(messages)
return _as_text(final.content)
async def _run_home_single_agent_stream(
user_id: str,
message: str,
context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
"""Streaming variant for single-agent home test mode."""
prepared_context = await _prepare_home_context(message, context)
llm = get_llm()
tools = _all_tools()
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=_HOME_SINGLE_AGENT_SYSTEM),
HumanMessage(content=f"User message:\n{message}\n\nContext:\n{json.dumps({'context': prepared_context}, ensure_ascii=True)[:3500]}"),
]
for _ in range(6):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
async for chunk in llm.astream(messages):
token = _as_text(getattr(chunk, "content", ""))
if token:
yield "token", token
return
tool_map = {t.name: t for t in tools}
for call in response.tool_calls:
tool_fn = tool_map.get(call["name"])
if tool_fn is None:
tool_output = f"Unknown tool: {call['name']}"
else:
tool_output = await tool_fn.ainvoke(call.get("args", {}))
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
async for chunk in llm.astream(messages):
token = _as_text(getattr(chunk, "content", ""))
if token:
yield "token", token
async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool) -> WorkerPlan:
if _needs_full_project_snapshot(message, floating):
logger.info("deep_agent: forcing full project snapshot plan for message=%s", message[:200])
return _build_full_project_snapshot_plan(message)
llm = get_llm()
system = _FLOATING_ORCHESTRATOR_SYSTEM if floating else _HOME_ORCHESTRATOR_SYSTEM
@@ -189,18 +440,34 @@ async def _plan_with_llm(message: str, context: dict[str, Any], floating: bool)
}
messages = [
SystemMessage(content=system),
HumanMessage(content=json.dumps(prompt_payload, ensure_ascii=True)),
HumanMessage(
content=(
"Create a valid JSON object with this exact structure:\n"
'{"tasks":[{"worker":"task_agent|project_agent|note_agent|timeline_agent","instruction":"..."}],'
'"floating_domain":"tasks|projects|notes|timelines|null","memory_updates":[{"key":"...","value":"..."}]}\n\n'
"Rules:\n"
"- tasks must include at least one entry when possible\n"
"- use floating_domain only when relevant\n"
"- output JSON only (no markdown, no prose)\n\n"
f"Input:\n{json.dumps(prompt_payload, ensure_ascii=True)}"
)
),
]
try:
structured_llm = llm.with_structured_output(WorkerPlan)
plan = await structured_llm.ainvoke(messages)
if isinstance(plan, WorkerPlan):
if not plan.tasks:
return _fallback_plan(message, floating)
return plan
response = await llm.ainvoke(messages)
payload = _extract_json_object(_as_text(response.content))
if payload is None:
raise ValueError("planner returned non-JSON output")
plan = _coerce_plan(payload, message, floating)
logger.info(
"deep_agent: planner produced tasks=%s floating=%s",
[t.worker for t in plan.tasks],
plan.floating_domain,
)
return plan
except Exception as exc:
logger.warning("deep_agent: structured planner failed, using fallback: %s", exc)
logger.warning("deep_agent: planner failed, using fallback: %s", exc)
return _fallback_plan(message, floating)
@@ -243,6 +510,64 @@ def _extract_entity_ids(tool_results: list[dict[str, Any]]) -> dict[str, list[st
return out
def _extract_facts(tool_results: list[dict[str, Any]]) -> dict[str, Any]:
"""Extract small, structured facts for the synthesizer to avoid hallucinations."""
facts: dict[str, Any] = {"projects": [], "tasks": [], "notes": [], "timelines": []}
for item in tool_results:
table = item.get("table")
payload = item.get("data") or {}
rows: list[dict[str, Any]] = []
row = payload.get("row")
if isinstance(row, dict):
rows.append(row)
if isinstance(payload.get("rows"), list):
rows.extend([r for r in payload["rows"] if isinstance(r, dict)])
if table == "projects":
for r in rows:
facts["projects"].append(
{
"id": r.get("id"),
"name": r.get("name"),
"status": r.get("status"),
"clientId": r.get("clientId"),
}
)
elif table == "tasks":
for r in rows:
facts["tasks"].append(
{
"id": r.get("id"),
"title": r.get("title"),
"status": r.get("status"),
"projectId": r.get("projectId"),
}
)
elif table == "notes":
for r in rows:
facts["notes"].append(
{
"id": r.get("id"),
"title": r.get("title"),
"projectId": r.get("projectId"),
}
)
elif table == "timelines":
for r in rows:
facts["timelines"].append(
{
"id": r.get("id"),
"title": r.get("title"),
"date": r.get("date"),
"projectId": r.get("projectId"),
}
)
return facts
async def _run_tool_loop(
worker: WorkerName,
instruction: str,
@@ -254,10 +579,45 @@ async def _run_tool_loop(
llm = get_llm()
llm_with_tools = llm.bind_tools(tools) if tools else llm
resolved_project_id = None
ctx = context.get("context", {}) if isinstance(context, dict) else {}
if isinstance(ctx, dict):
rpid = ctx.get("resolved_project_id")
if isinstance(rpid, str) and rpid:
resolved_project_id = rpid
mandatory_tool_policy = ""
if resolved_project_id:
if worker == "project_agent":
mandatory_tool_policy = (
"MANDATORY TOOL POLICY:\n"
f"- You MUST call get_project(project_id=\"{resolved_project_id}\") before final answer.\n"
"- Optionally call list_projects afterward only if needed for disambiguation.\n\n"
)
elif worker == "task_agent":
mandatory_tool_policy = (
"MANDATORY TOOL POLICY:\n"
f"- You MUST call list_tasks(project_id=\"{resolved_project_id}\") before final answer.\n"
"- Do not use project name as project_id.\n\n"
)
elif worker == "timeline_agent":
mandatory_tool_policy = (
"MANDATORY TOOL POLICY:\n"
f"- You MUST call list_timelines(project_id=\"{resolved_project_id}\") before final answer.\n"
"- Do not use project name as project_id.\n\n"
)
elif worker == "note_agent":
mandatory_tool_policy = (
"MANDATORY TOOL POLICY:\n"
f"- You MUST call list_notes(project_id=\"{resolved_project_id}\") before final answer.\n"
"- Do not use project name as project_id.\n\n"
)
messages: list[Any] = [
SystemMessage(content=worker_prompt),
HumanMessage(
content=(
mandatory_tool_policy +
"Worker instruction:\n"
f"{instruction}\n\n"
"Conversation context:\n"
@@ -278,12 +638,38 @@ async def _run_tool_loop(
tool_map = {t.name: t for t in tools}
for call in response.tool_calls:
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"deep_agent: worker=%s AI->Tool tool_call_id=%s tool=%s args=%s",
worker,
call_id,
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call["name"])
if tool_fn is None:
tool_output = f"Unknown tool: {call['name']}"
else:
tool_output = await tool_fn.ainvoke(call.get("args", {}))
tool_output_text = str(tool_output)
logger.info(
"deep_agent: worker=%s Tool->AI tool_call_id=%s tool=%s output=%s",
worker,
call_id,
call_name,
tool_output_text[:1200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
logger.info(
"deep_agent: worker=%s appended ToolMessage tool_call_id=%s",
worker,
call_id,
)
structured_llm = llm.with_structured_output(WorkerSummary)
messages.append(SystemMessage(content="You have finished using tools. Summarize findings in max 3 sentences."))
@@ -303,11 +689,18 @@ def _worker_node(worker: WorkerName):
return {"worker_results": []}
instruction = str(task_payload.get("instruction") or state.get("user_message") or "")
logger.info("deep_agent: worker=%s start instruction=%s", worker, instruction[:240])
worker_context = {
"memory": state.get("memory_context", {}),
"context": state.get("context", {}),
}
response, tool_results = await _run_tool_loop(worker, instruction, worker_context)
logger.info(
"deep_agent: worker=%s complete tool_calls=%d entity_counts=%s",
worker,
len(tool_results),
{k: len(v) for k, v in _extract_entity_ids(tool_results).items()},
)
return {
"worker_results": [
@@ -316,6 +709,7 @@ def _worker_node(worker: WorkerName):
"instruction": instruction,
"response": response,
"entity_ids": _extract_entity_ids(tool_results),
"facts": _extract_facts(tool_results),
}
]
}
@@ -333,6 +727,7 @@ def _build_synthesis_prompt(state: GraphState, floating: bool) -> str:
"instruction": result.get("instruction"),
"response": result.get("response"),
"entity_ids": result.get("entity_ids", {}),
"facts": result.get("facts", {}),
}
)
@@ -399,14 +794,25 @@ async def _orchestrator_node_home(state: GraphState) -> GraphState:
if state.get("plan"):
return {}
context = {**state.get("context", {}), **state.get("memory_context", {})}
plan = await _plan_with_llm(str(state.get("user_message", "")), context, floating=False)
user_message = str(state.get("user_message", ""))
base_context = dict(state.get("context", {}))
context = {**base_context, **state.get("memory_context", {})}
if _needs_full_project_snapshot(user_message, floating=False):
resolved_project_id = await _resolve_project_id_from_message(user_message)
if resolved_project_id:
base_context["resolved_project_id"] = resolved_project_id
logger.info("deep_agent: resolved_project_id=%s for message=%s", resolved_project_id, user_message[:200])
plan = _build_full_project_snapshot_plan(user_message)
else:
plan = await _plan_with_llm(user_message, context, floating=False)
new_memory = await _apply_memory_updates(str(state.get("user_id", "")), plan.memory_updates, state.get("memory_context", {}))
return {
"plan": [task.model_dump() for task in plan.tasks],
"memory_context": new_memory
"memory_context": new_memory,
"context": base_context,
}
@@ -470,6 +876,9 @@ FLOATING_GRAPH = _build_graph(floating=True)
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
if HOME_SINGLE_AGENT_TEST_MODE:
return await _run_home_single_agent(user_id, message, context)
state = await HOME_GRAPH.ainvoke(
{
"user_id": user_id,
@@ -505,6 +914,11 @@ async def run_home_stream(
message: str,
context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
if HOME_SINGLE_AGENT_TEST_MODE:
async for event in _run_home_single_agent_stream(user_id, message, context):
yield event
return
state_input = {
"user_id": user_id,
"user_message": message,

View File

@@ -18,6 +18,7 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
from __future__ import annotations
import os
import warnings
from openai import AsyncOpenAI
import litellm
@@ -32,6 +33,14 @@ from app.config.settings import settings
# Drop them silently instead of raising UnsupportedParamsError.
litellm.drop_params = True
# Some provider responses include a plain dict in the `usage` field where a
# richer Pydantic model is expected. This warning is noisy but non-fatal.
warnings.filterwarnings(
"ignore",
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None:
"""Return the most appropriate API key for the given LiteLLM model string."""