1137 lines
43 KiB
Python
1137 lines
43 KiB
Python
"""Agent run orchestrator.

Drives two agent types:

* **Local directory agent** — two-step execution per file:
  Step 1 (Classification) uses code to fetch all projects and asks the LLM
  to identify which project the file belongs to and which domains are relevant.
  Step 2 (Processing) fetches existing entities for that project/domains via
  code and runs an LLM with tools — existing data in context enforces
  update-first naturally.

* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
  Teams, Outlook) and pushes extracted items to Electron.

Usage
-----
Background tasks are spawned with ``asyncio.create_task()``::

    asyncio.create_task(run_local_agent(user_id, config, run_log, device_manager))
    asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))

The ``trigger_pending_runs`` function is called by the device WS endpoint
when Electron sends ``device_hello``, so any overdue runs fire immediately
when the device reconnects.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import uuid
|
||
from datetime import datetime, timedelta, timezone
|
||
from typing import Any
|
||
|
||
from croniter import croniter
|
||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||
from sqlalchemy import select
|
||
|
||
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||
from app.agents.note_agent import NOTE_TOOLS
|
||
from app.agents.project_agent import PROJECT_TOOLS
|
||
from app.agents.task_agent import TASK_TOOLS
|
||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||
from app.config.settings import settings
|
||
from app.core.device_manager import DeviceConnectionManager
|
||
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback
|
||
from app.core.llm import get_llm
|
||
from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor
|
||
from app.db import async_session
|
||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ── Concurrency guard ─────────────────────────────────────────────────────
# IDs of agents that have a run in flight right now. Membership in this set
# is what blocks a second simultaneous run of the same agent inside a single
# process.
_running_agents: set[str] = set()


def is_agent_running(agent_id: str) -> bool:
    """Report whether *agent_id* is currently registered as having a run in progress."""
    currently_active = _running_agents
    return agent_id in currently_active
|
||
|
||
# ── Timeouts ───────────────────────────────────────────────────────────────
|
||
|
||
# Max seconds to wait for a single tool-call round-trip (FE → BE).
|
||
_TOOL_CALL_TIMEOUT: int = 30
|
||
# Max LLM reasoning steps for Step 2 processing.
|
||
_MAX_PROCESSING_STEPS: int = 12
|
||
# Max directory recursion depth during scan.
|
||
_MAX_SCAN_DEPTH: int = 5
|
||
|
||
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
||
# NOTE: "projects" is intentionally excluded — project creation/assignment is
|
||
# handled in code by the runner, never delegated to the Step 2 LLM.
|
||
|
||
# Maps each extractable data type to the tool list offered to the Step 2 LLM.
_DATA_TYPE_TOOLS: dict[str, list[Any]] = dict(
    tasks=TASK_TOOLS,
    notes=NOTE_TOOLS,
    timelines=TIMELINE_TOOLS,
)
|
||
|
||
# ── Step 1: Classification prompt ─────────────────────────────────────────
|
||
|
||
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
|
||
"tasks": (
|
||
"Action items, to-dos, deliverables — anything that describes work to be done, "
|
||
"assigned to someone, or tracked with a due date or status."
|
||
),
|
||
"notes": (
|
||
"Documentation, meeting notes, summaries, reference material — "
|
||
"written content meant to be read and referenced rather than acted on."
|
||
),
|
||
"timelines": (
|
||
"Project milestones, deadlines, scheduled events — "
|
||
"specific dates that mark a point in the progress of a project."
|
||
),
|
||
"projects": (
|
||
"High-level project entities — only relevant if the file clearly introduces "
|
||
"a new project or updates the scope of an existing one."
|
||
),
|
||
}
|
||
|
||
_BATCH_FILE_CLASSIFIER_PROMPT = """\
|
||
You are a file classifier for a freelance project management tool.
|
||
|
||
Your job is to match a file to an existing project and identify which data domains to extract.
|
||
|
||
## Project matching rules (STRICT — follow in order)
|
||
|
||
1. Search the file content for any mention of a project name, client name, acronym, or topic
|
||
that overlaps with the existing projects listed below.
|
||
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
|
||
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
|
||
when the file has zero meaningful connection to any listed project.
|
||
4. When in doubt, pick the closest match from the list.
|
||
|
||
## Response format
|
||
|
||
Respond ONLY with a JSON object — no markdown, no explanation:
|
||
|
||
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
|
||
|
||
## Domain definitions (only consider domains in the allowed list)
|
||
|
||
{domain_definitions}
|
||
|
||
## Existing projects
|
||
|
||
{projects_list}
|
||
"""
|
||
|
||
# ── Step 2: Processing prompt ─────────────────────────────────────────────
|
||
|
||
_BATCH_PROCESSING_PROMPT = """\
|
||
You are a data extraction assistant for a freelance project management tool.
|
||
|
||
Your task: extract structured data from the file content and persist it using the available tools.
|
||
|
||
## Mandatory process — follow this order for EVERY item you extract
|
||
|
||
1. READ the existing records listed below for the relevant domain.
|
||
2. SEARCH for a match by title, topic, or semantic similarity.
|
||
3. If a match exists → call the update_* tool with the existing record's id.
|
||
4. If no match exists → call the create_* tool and set isAiSuggested=1.
|
||
|
||
NEVER call create_* without first checking the existing records.
|
||
NEVER duplicate a record that already exists under a different wording.
|
||
|
||
## Existing records (source of truth)
|
||
|
||
{existing_context}
|
||
|
||
## Context
|
||
|
||
Project: {project_context}
|
||
Domains to extract: {data_types}
|
||
|
||
{custom_prompt_section}
|
||
"""
|
||
|
||
# ── Cloud processing prompt (kept separate for cloud agent) ───────────────
|
||
|
||
_BATCH_CLOUD_PROCESSING_PROMPT = """\
|
||
You are a data extraction and management assistant for a freelance project
|
||
management tool.
|
||
|
||
Available tools:
|
||
Filesystem : read_file_content, list_directory, get_file_metadata
|
||
Tasks : list_tasks, create_task, update_task, add_task_comment
|
||
Notes : list_notes, get_note, create_note, update_note
|
||
Timelines : list_timelines, create_timeline, update_timeline
|
||
Projects : list_all_projects, get_project, create_project, update_project
|
||
|
||
Your task:
|
||
1. Read the full content of each file below using read_file_content.
|
||
2. For each piece of information found, ALWAYS try to match and update an
|
||
existing record before creating a new one.
|
||
3. ONLY act on these entity types: {data_types}.
|
||
4. Do NOT invent data. Only extract what is clearly present in the files.
|
||
5. If a file contains no relevant data for the target entity types, skip it.
|
||
|
||
{project_context}
|
||
|
||
Files to process:
|
||
{file_list}
|
||
|
||
{custom_prompt_section}
|
||
|
||
After processing all files, respond with a brief summary of what you updated
|
||
and what you created.
|
||
"""
|
||
|
||
|
||
# ── Cron helper ────────────────────────────────────────────────────────────
|
||
|
||
|
||
def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool:
    """Tell whether the schedule's next occurrence after *last_run_at* is already due.

    The cron expression is always parsed. If parsing (or anything else in the
    computation) fails, a warning is logged and ``False`` is returned, so an
    unparseable schedule never triggers a run.

    An agent that has never run (``last_run_at is None``) is overdue as long as
    its expression parses. A timezone-naive ``last_run_at`` is interpreted as
    UTC.
    """
    try:
        current = datetime.now(timezone.utc)

        if last_run_at is not None:
            anchor = (
                last_run_at
                if last_run_at.tzinfo is not None
                else last_run_at.replace(tzinfo=timezone.utc)
            )
            due_at: datetime = croniter(schedule_cron, anchor).get_next(datetime)
            return due_at <= current

        # Never ran before: construct the iterator purely to validate the
        # expression, then report the run as due.
        croniter(schedule_cron, current)
        return True
    except Exception as exc:
        logger.warning("agent_runner: cannot parse cron %r: %s", schedule_cron, exc)
        return False
|
||
|
||
|
||
# ── WS executor for agent context ─────────────────────────────────────────
|
||
|
||
|
||
def _make_agent_executor(
|
||
user_id: str,
|
||
device_mgr: DeviceConnectionManager,
|
||
run_context: dict | None = None,
|
||
) -> Any:
|
||
"""Create a WS callback for ``set_client_executor()`` so that all tools
|
||
can use ``execute_on_client()`` during an agent run.
|
||
|
||
If *run_context* is provided it is attached to every ``tool_call`` frame
|
||
so the Electron client can attribute actions to the correct agent run.
|
||
"""
|
||
async def _executor(payload: dict) -> dict:
|
||
payload["type"] = "tool_call"
|
||
if run_context:
|
||
payload["run_context"] = run_context
|
||
call_id = payload["id"]
|
||
fut = device_mgr.create_pending_call(user_id, call_id)
|
||
await device_mgr.send_frame(user_id, payload)
|
||
return await asyncio.wait_for(fut, timeout=_TOOL_CALL_TIMEOUT)
|
||
return _executor
|
||
|
||
|
||
# ── LLM tool-calling loop ─────────────────────────────────────────────────
|
||
|
||
|
||
def _as_text(content: Any) -> str:
|
||
if content is None:
|
||
return ""
|
||
if isinstance(content, str):
|
||
return content
|
||
if isinstance(content, list):
|
||
parts: list[str] = []
|
||
for item in content:
|
||
if isinstance(item, str):
|
||
parts.append(item)
|
||
elif isinstance(item, dict):
|
||
text = item.get("text")
|
||
if isinstance(text, str):
|
||
parts.append(text)
|
||
return "".join(parts)
|
||
return str(content)
|
||
|
||
|
||
async def _run_agent_with_tools(
    *,
    system_prompt: str,
    user_message: str,
    tools: list[Any],
    max_steps: int,
    user_id: str = "",
    langfuse_prompt: Any = None,
    agent_name: str = "batch-agent",
) -> str:
    """Run an LLM agent with tool-calling, returning the final text response.

    The model is invoked up to *max_steps* times with *tools* bound. After each
    reply, every requested tool call is executed and its output is appended to
    the conversation as a ``ToolMessage``. The loop ends as soon as a reply
    carries no tool calls. If the step budget runs out first, one last call is
    made without tools bound and its text is returned.

    When a Langfuse client is available the run is recorded as a span named
    *agent_name*, and each tool-enabled model call as a generation named
    ``"<agent_name>-llm"``. Both are opened with ``with`` so they are closed
    even when the model or a tool raises. The client is flushed on exit.

    Args:
        system_prompt: Content of the leading ``SystemMessage``.
        user_message: Content of the ``HumanMessage``; also logged as span input.
        tools: Tool objects exposing ``name`` and ``ainvoke``.
        max_steps: Maximum number of tool-enabled model calls.
        user_id: Passed to the span (``None`` when empty).
        langfuse_prompt: Prompt object linked to each generation, if any.
        agent_name: Name used for the span and generations.

    Returns:
        The text of the final model reply.

    Raises:
        Exception: Errors raised by the model or by a tool are not caught here.
    """
    # Function-scope import keeps this block self-contained.
    from contextlib import nullcontext

    lf = get_langfuse()
    llm = get_llm()
    llm_with_tools = llm.bind_tools(tools)
    messages: list[Any] = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_message),
    ]

    tool_map = {tool_def.name: tool_def for tool_def in tools}

    span_cm = (
        lf.start_as_current_observation(
            as_type="span",
            name=agent_name,
            user_id=user_id or None,
            input=user_message,
        )
        if lf
        else nullcontext()
    )

    try:
        with span_cm as span:
            for _ in range(max_steps):
                gen_cm = (
                    lf.start_as_current_observation(
                        as_type="generation",
                        name=f"{agent_name}-llm",
                        model=settings.LLM_MODEL,
                        prompt=langfuse_prompt,
                        input=messages,
                    )
                    if lf
                    else nullcontext()
                )
                # ``with`` guarantees the generation is closed if ainvoke raises.
                with gen_cm as gen:
                    response: AIMessage = await llm_with_tools.ainvoke(messages)
                    if gen is not None:
                        gen.update(
                            output=_as_text(response.content),
                            usage=extract_usage(response),
                        )

                messages.append(response)

                if not response.tool_calls:
                    final_text = _as_text(response.content)
                    if span is not None:
                        span.update(output=final_text)
                    return final_text

                for call in response.tool_calls:
                    call_id = str(call.get("id") or "")
                    call_name = str(call.get("name", ""))
                    call_args = call.get("args", {})
                    logger.info(
                        "agent_runner: tool_call name=%s args=%s",
                        call_name,
                        # default=str: a non-JSON-serialisable argument must not
                        # abort the run from inside a log statement.
                        json.dumps(call_args, ensure_ascii=True, default=str)[:800],
                    )

                    tool_fn = tool_map.get(call_name)
                    if tool_fn is None:
                        tool_output = f"Unknown tool: {call_name}"
                    else:
                        tool_output = await tool_fn.ainvoke(call_args)

                    logger.info(
                        "agent_runner: tool_result name=%s output=%s",
                        call_name,
                        str(tool_output)[:200],
                    )
                    messages.append(
                        ToolMessage(content=str(tool_output), tool_call_id=call_id)
                    )

            # Step budget exhausted: ask once more, without tools, for a final answer.
            final = await llm.ainvoke(messages)
            final_text = _as_text(final.content)
            if span is not None:
                span.update(output=final_text)
            return final_text
    finally:
        if lf:
            lf.flush()
|
||
|
||
|
||
# ── Tool list builder ─────────────────────────────────────────────────────
|
||
|
||
|
||
def _build_processing_tools(data_types: list[str]) -> list[Any]:
    """Assemble the Step 2 tool list for the selected *data_types*.

    The result is a fresh list that starts with every filesystem tool and then
    appends, in the order given, the tools registered for each data type in
    ``_DATA_TYPE_TOOLS``. Data types without an entry contribute nothing.
    """
    selected: list[Any] = [*FILESYSTEM_TOOLS]
    for data_type in data_types:
        selected.extend(_DATA_TYPE_TOOLS.get(data_type) or [])
    return selected
|
||
|
||
|
||
# ── Code-based directory scanner ─────────────────────────────────────────
|
||
|
||
|
||
async def _scan_directories(
    paths: list[str],
    extensions: list[str],
    last_run_at: datetime | None,
) -> list[str]:
    """Walk directories via WS tool calls and return filtered file paths.

    Each root in *paths* is listed recursively through the client, with
    recursion capped at ``_MAX_SCAN_DEPTH``. A file is kept when its extension
    is in *extensions* (compared case-insensitively, leading dots ignored), or
    unconditionally when *extensions* is empty. A path reached through more
    than one root is reported once.

    When *last_run_at* is set, only files modified after it are returned. The
    check fails open: a file whose metadata cannot be read or parsed is
    included.

    Args:
        paths: Root directories to scan.
        extensions: Allowed file extensions; empty means "all files".
        last_run_at: Cut-off for the modification filter. A timezone-naive
            value is interpreted as UTC, matching ``_is_overdue``.

    Returns:
        Matching file paths in discovery order.
    """
    all_files: list[str] = []
    seen: set[str] = set()
    ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set()

    async def _walk(path: str, depth: int) -> None:
        if depth > _MAX_SCAN_DEPTH:
            return
        try:
            result = await execute_on_client(action="list_directory", data={"path": path})
            # Read the entries inside the try so a malformed response skips
            # this directory instead of aborting the whole scan.
            entries = result.get("entries") or []
        except Exception as exc:
            logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
            return
        for entry in entries:
            entry_path = entry.get("path", "")
            if not entry_path:
                continue
            entry_type = entry.get("type")
            if entry_type == "directory":
                await _walk(entry_path, depth + 1)
            elif entry_type == "file":
                if ext_set:
                    dot_pos = entry_path.rfind(".")
                    file_ext = entry_path[dot_pos + 1:].lower() if dot_pos != -1 else ""
                    if file_ext not in ext_set:
                        continue
                # Overlapping roots would otherwise queue the same file twice.
                if entry_path not in seen:
                    seen.add(entry_path)
                    all_files.append(entry_path)

    for root in paths:
        await _walk(root, depth=0)

    if last_run_at is None:
        return all_files

    # Filter by modification date. A naive timestamp is pinned to UTC first;
    # datetime.timestamp() would otherwise read it as server-local time.
    cutoff = last_run_at
    if cutoff.tzinfo is None:
        cutoff = cutoff.replace(tzinfo=timezone.utc)
    last_run_ms = int(cutoff.timestamp() * 1000)

    def _to_epoch_ms(value: Any) -> int:
        # Numbers are compared as-is against a millisecond cut-off.
        # NOTE(review): assumes the client reports numeric modifiedAt in
        # epoch milliseconds — confirm against the Electron side.
        if isinstance(value, (int, float)):
            return int(value)
        text = str(value).strip()
        # datetime.fromisoformat() only accepts a trailing "Z" from Python 3.11.
        if text[-1:] in ("Z", "z"):
            text = text[:-1] + "+00:00"
        return int(datetime.fromisoformat(text).timestamp() * 1000)

    filtered: list[str] = []
    for file_path in all_files:
        try:
            meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
            modified_at = meta.get("modifiedAt")
            if modified_at is None or _to_epoch_ms(modified_at) > last_run_ms:
                filtered.append(file_path)
        except Exception as exc:
            logger.debug(
                "agent_runner: metadata check failed for %r, including file: %s",
                file_path,
                exc,
            )
            filtered.append(file_path)  # fail-open

    return filtered
|
||
|
||
|
||
# ── Code-based entity fetchers ────────────────────────────────────────────
|
||
|
||
|
||
async def _fetch_projects() -> list[dict]:
    """Ask the Electron client (over WS) for every row of the ``projects`` table.

    Returns the ``"rows"`` list of the response, or an empty list when the
    response has no such key. Any failure is logged as a warning and also
    yields an empty list.
    """
    try:
        response = await execute_on_client(action="select", table="projects")
        project_rows = response.get("rows", [])
    except Exception as exc:
        logger.warning("agent_runner: failed to fetch projects: %s", exc)
        project_rows = []
    return project_rows
|
||
|
||
|
||
_DOMAIN_TABLE: dict[str, str] = {
|
||
"tasks": "tasks",
|
||
"notes": "notes",
|
||
"timelines": "timelines",
|
||
"projects": "projects",
|
||
}
|
||
|
||
|
||
async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
|
||
"""Fetch existing rows for a domain, scoped to a project where applicable."""
|
||
table = _DOMAIN_TABLE.get(domain)
|
||
if not table:
|
||
return []
|
||
filters: dict[str, Any] = {}
|
||
if project_id != "standalone" and domain != "projects":
|
||
filters["projectId"] = project_id
|
||
try:
|
||
result = await execute_on_client(
|
||
action="select",
|
||
table=table,
|
||
filters=filters if filters else None,
|
||
)
|
||
return result.get("rows", [])
|
||
except Exception as exc:
|
||
logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
|
||
return []
|
||
|
||
|
||
def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
|
||
"""Format existing entity rows as a readable context block for the LLM.
|
||
|
||
Includes enough detail per record for the LLM to make a confident
|
||
update-vs-create decision without overwhelming the context.
|
||
Note content is truncated to 200 chars to stay within token budget.
|
||
"""
|
||
if not rows:
|
||
return f"No existing {domain}."
|
||
lines: list[str] = []
|
||
for r in rows:
|
||
if domain == "tasks":
|
||
desc = r.get("description") or ""
|
||
desc_part = f" — {desc[:120]}" if desc else ""
|
||
assignee = r.get("assignee") or r.get("assignees") or ""
|
||
due = r.get("dueDate") or r.get("due_date") or ""
|
||
meta = ", ".join(filter(None, [
|
||
f"priority: {r.get('priority', '')}" if r.get("priority") else "",
|
||
f"assignee: {assignee}" if assignee else "",
|
||
f"due: {due}" if due else "",
|
||
]))
|
||
lines.append(
|
||
f" - [{r.get('status', '?')}] {r.get('title', '')}{desc_part}"
|
||
f" ({meta}, id: {r['id']})"
|
||
)
|
||
elif domain == "notes":
|
||
snippet = (r.get("content") or "")[:200].replace("\n", " ")
|
||
snippet_part = f"\n Preview: {snippet}" if snippet else ""
|
||
lines.append(
|
||
f" - {r.get('title', '')} (id: {r['id']}){snippet_part}"
|
||
)
|
||
elif domain == "timelines":
|
||
lines.append(
|
||
f" - {r.get('title', '')} date={r.get('date', '')} (id: {r['id']})"
|
||
)
|
||
elif domain == "projects":
|
||
summary = (r.get("aiSummary") or r.get("ai_summary") or "")[:120]
|
||
summary_part = f" — {summary}" if summary else ""
|
||
lines.append(
|
||
f" - {r.get('name', '')} [{r.get('status', '')}]{summary_part}"
|
||
f" (id: {r['id']})"
|
||
)
|
||
return f"Existing {domain}:\n" + "\n".join(lines)
|
||
|
||
|
||
# ── Step 1: LLM file classifier ───────────────────────────────────────────
|
||
|
||
|
||
async def _classify_file(
|
||
file_path: str,
|
||
file_content: str,
|
||
projects: list[dict],
|
||
config_data_types: list[str],
|
||
) -> tuple[str, list[str], str | None]:
|
||
"""Call the LLM to classify a file by project and relevant domains.
|
||
|
||
Returns ``(project_id_or_"new", domains, new_project_name_or_None)``.
|
||
- ``project_id`` is an existing project UUID, or ``"new"`` when no match found.
|
||
- ``new_project_name`` is only set when ``project_id == "new"``.
|
||
Falls back to ``("new", config_data_types, None)`` on any error.
|
||
"""
|
||
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
|
||
|
||
if not file_content.strip():
|
||
return fallback
|
||
|
||
valid_project_ids = {p["id"] for p in projects}
|
||
|
||
def _fmt_project(p: dict) -> str:
|
||
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
|
||
summary_part = f" — {summary[:100]}" if summary else ""
|
||
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
|
||
|
||
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
|
||
|
||
domain_definitions = "\n".join(
|
||
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
|
||
for d in config_data_types
|
||
if d in _DOMAIN_DESCRIPTIONS
|
||
)
|
||
|
||
step1_template, step1_prompt_obj = get_prompt_or_fallback(
|
||
"batch_file_classifier", _BATCH_FILE_CLASSIFIER_PROMPT
|
||
)
|
||
system = step1_template.format(
|
||
domain_definitions=domain_definitions,
|
||
projects_list=projects_list,
|
||
)
|
||
|
||
lf = get_langfuse()
|
||
llm = get_llm()
|
||
classifier_messages = [
|
||
SystemMessage(content=system),
|
||
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
|
||
]
|
||
try:
|
||
if lf:
|
||
with lf.start_as_current_observation(
|
||
as_type="generation",
|
||
name="step1-classifier",
|
||
model=settings.LLM_ROUTER_MODEL,
|
||
prompt=step1_prompt_obj,
|
||
input=classifier_messages,
|
||
) as gen:
|
||
response = await llm.ainvoke(classifier_messages)
|
||
gen.update(output=_as_text(response.content), usage=extract_usage(response))
|
||
else:
|
||
response = await llm.ainvoke(classifier_messages)
|
||
raw = _as_text(response.content).strip()
|
||
# Strip markdown fences if the model wraps the JSON.
|
||
if raw.startswith("```"):
|
||
raw = raw.split("```")[1]
|
||
if raw.startswith("json"):
|
||
raw = raw[4:]
|
||
parsed = json.loads(raw.strip())
|
||
raw_project_id: str = str(parsed.get("project_id") or "new")
|
||
# Reject hallucinated UUIDs — only accept ids that exist in the fetched list.
|
||
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
|
||
new_project_name: str | None = (
|
||
str(parsed["new_project_name"]).strip() or None
|
||
if project_id == "new" and parsed.get("new_project_name")
|
||
else None
|
||
)
|
||
domains: list[str] = [
|
||
d for d in parsed.get("domains", [])
|
||
if d in config_data_types
|
||
]
|
||
if not domains:
|
||
domains = list(config_data_types)
|
||
return project_id, domains, new_project_name
|
||
except Exception as exc:
|
||
logger.warning(
|
||
"agent_runner: step1 classification failed for %r: %s", file_path, exc
|
||
)
|
||
return fallback
|
||
|
||
|
||
# ── Local agent runner (two-step per file) ────────────────────────────────
|
||
|
||
|
||
async def run_local_agent(
    user_id: str,
    config: LocalAgentConfig,
    run_log: AgentRunLog,
    device_mgr: DeviceConnectionManager,
    run_context: dict | None = None,
) -> None:
    """Execute a local directory agent run using a two-step approach per file.

    Step 1 — Classification (code + 1 LLM call per file, no tools):
        Code scans directories and fetches all projects via WS.
        For each file, LLM identifies the project and relevant domains.

    Step 2 — Processing (code + 1 LLM call per file, with tools):
        Code fetches existing entities for the identified project/domains.
        LLM receives file content + existing entities in context and uses
        tools to update existing records or create new ones.

    The agent is registered in ``_running_agents`` only once the run is known
    to start (target device online) and is always unregistered when processing
    ends, so an aborted run can never leave the agent marked as running.

    The outcome is persisted via ``_finalize_run`` with status ``"error"``
    (errors and nothing processed, or device offline), ``"partial"`` (some
    errors) or ``"success"``. When *run_context* is given and a device is
    online, a ``run_complete`` frame is sent afterwards. That frame is not sent
    on the two early exits (device offline, no files to process).

    Args:
        user_id: Owner of the agent and of the device connection.
        config: Agent configuration (directories, extensions, data types,
            optional prompt template and target device).
        run_log: Log row for this run; its ``id`` is used in log messages.
        device_mgr: Connection manager used to reach the Electron client.
        run_context: Optional attribution data attached to every tool call;
            its ``"agent_id"``, when present, overrides ``config.id``.
    """
    run_id = run_log.id
    agent_id = (run_context or {}).get("agent_id") or config.id

    # ── Device online check ─────────────────────────────────────────
    target_device_id = config.device_id.strip() if isinstance(config.device_id, str) else ""
    is_online = (
        device_mgr.is_online(user_id, target_device_id)
        if target_device_id
        else device_mgr.is_online(user_id)
    )

    if not is_online:
        # Nothing has been registered in _running_agents yet, so this early
        # exit cannot leave the agent stuck in the "running" state.
        logger.info(
            "agent_runner: skip run=%s — device %r offline for user=%s",
            run_id,
            target_device_id or "<any>",
            user_id,
        )
        await _finalize_run(
            run_log,
            status="error",
            errors=[f"Device {target_device_id or '<any>'!r} is not connected"],
        )
        return

    executor = _make_agent_executor(user_id, device_mgr, run_context)

    errors: list[str] = []
    items_processed = 0
    # NOTE(review): never incremented below — created records are not counted.
    items_created = 0

    custom_section = (
        f"User instructions:\n{config.prompt_template}"
        if config.prompt_template
        else ""
    )

    # Register the run and install the WS executor immediately before the
    # try/finally that undoes both.
    _running_agents.add(agent_id)
    set_client_executor(executor)

    try:
        # ── Code: scan directories ───────────────────────────────────
        logger.info("agent_runner: run=%s scanning directories user=%s", run_id, user_id)
        file_paths = await _scan_directories(
            paths=config.directory_paths,
            extensions=config.file_extensions or [],
            last_run_at=config.last_run_at,
        )
        logger.info(
            "agent_runner: run=%s found %d file(s) after filtering", run_id, len(file_paths)
        )

        if not file_paths:
            await _finalize_run(run_log, status="success", items_processed=0, items_created=0)
            return

        # ── Code: fetch all projects once ────────────────────────────
        projects = await _fetch_projects()

        for file_path in file_paths:
            try:
                # Read file content via code.
                file_result = await execute_on_client(
                    action="read_file_content", data={"path": file_path}
                )
                file_content: str = file_result.get("content", "")
                if not file_content:
                    logger.debug("agent_runner: run=%s skipping empty file %r", run_id, file_path)
                    continue

                items_processed += 1

                # Step 1 — classify file.
                project_id, domains, new_project_name = await _classify_file(
                    file_path=file_path,
                    file_content=file_content,
                    projects=projects,
                    config_data_types=config.data_types,
                )
                logger.info(
                    "agent_runner: run=%s file=%r → project=%s new_name=%r domains=%s",
                    run_id,
                    file_path,
                    project_id,
                    new_project_name,
                    domains,
                )

                # Step 2 — resolve project_id via CODE, then fetch entities.
                # Project creation is NEVER delegated to the Step 2 LLM.
                if project_id == "new":
                    proj_name = new_project_name or "Untitled Project"
                    try:
                        proj_result = await execute_on_client(
                            action="insert",
                            table="projects",
                            data={"name": proj_name, "clientId": None},
                        )
                        created = proj_result.get("row", {})
                        effective_project_id = created.get("id", "standalone")
                        # Add to local list so subsequent files can match it.
                        if "id" in created:
                            projects.append(created)
                        logger.info(
                            "agent_runner: run=%s created project %r id=%s",
                            run_id, proj_name, effective_project_id,
                        )
                    except Exception as exc:
                        logger.warning(
                            "agent_runner: run=%s failed to create project %r: %s",
                            run_id, proj_name, exc,
                        )
                        effective_project_id = "standalone"
                        proj_name = "unknown"
                    project_context = (
                        f"Project: {proj_name} (id: {effective_project_id}). "
                        "Always set projectId to this id on every record you create."
                    )
                else:
                    effective_project_id = project_id
                    proj = next((p for p in projects if p["id"] == project_id), None)
                    proj_name = proj.get("name", project_id) if proj else project_id
                    project_context = (
                        f"Project: {proj_name} (id: {project_id}). "
                        "Always set projectId to this id on every record you create."
                    )

                # "projects" domain is never passed to Step 2 — handled above in code.
                domains = [d for d in domains if d != "projects"]

                existing_blocks: list[str] = []
                for domain in domains:
                    rows = await _fetch_domain_entities(domain, effective_project_id)
                    existing_blocks.append(_format_entities_for_context(domain, rows))

                existing_context = "\n\n".join(existing_blocks)

                step2_template, step2_prompt_obj = get_prompt_or_fallback(
                    "batch_processing", _BATCH_PROCESSING_PROMPT
                )
                system_prompt = step2_template.format(
                    existing_context=existing_context,
                    project_context=project_context,
                    data_types=", ".join(domains),
                    custom_prompt_section=custom_section,
                )

                processing_tools = _build_processing_tools(domains)

                result_text = await _run_agent_with_tools(
                    system_prompt=system_prompt,
                    user_message=(
                        f"Process this file and extract relevant information.\n\n"
                        f"File: {file_path}\n\nContent:\n{file_content}"
                    ),
                    tools=processing_tools,
                    max_steps=_MAX_PROCESSING_STEPS,
                    user_id=user_id,
                    langfuse_prompt=step2_prompt_obj,
                    agent_name="step2-processor",
                )
                logger.info(
                    "agent_runner: run=%s file=%r result=%s",
                    run_id,
                    file_path,
                    result_text[:200],
                )

            except Exception as exc:
                # One bad file must not stop the remaining files.
                errors.append(f"Error processing '{file_path}': {exc}")
                logger.error(
                    "agent_runner: run=%s file=%r failed: %s", run_id, file_path, exc
                )

    except Exception as exc:
        errors.append(f"Agent run failed: {exc}")
        logger.error("agent_runner: run=%s failed: %s", run_id, exc)
    finally:
        _running_agents.discard(agent_id)
        clear_client_executor()

    # ── Finalise ────────────────────────────────────────────────────
    if errors and items_processed == 0:
        final_status = "error"
    elif errors:
        final_status = "partial"
    else:
        final_status = "success"

    await _finalize_run(
        run_log,
        status=final_status,
        items_processed=items_processed,
        items_created=items_created,
        errors=errors,
    )
    logger.info(
        "agent_runner: run=%s done status=%s processed=%d errors=%d",
        run_id,
        final_status,
        items_processed,
        len(errors),
    )

    # Notify Electron that the run is complete.
    if run_context and device_mgr.is_online(user_id):
        try:
            await device_mgr.send_frame(user_id, {
                "type": "run_complete",
                "run_context": run_context,
                "status": final_status,
            })
        except Exception as exc:
            logger.warning(
                "agent_runner: run=%s failed to send run_complete: %s", run_id, exc
            )
|
||
|
||
|
||
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
||
|
||
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7


async def _persist_refreshed_cloud_token(provider: Any, config_id: Any) -> None:
    """Write back OAuth credentials the provider refreshed during a run.

    Reads ``provider.refreshed_credentials``; when it is set, the value is
    re-encrypted and stored on the matching ``CloudAgentConfig`` row.

    Best-effort: any failure is logged as a warning and swallowed so that a
    token write-back problem never changes the outcome of the run itself.
    """
    refreshed = getattr(provider, "refreshed_credentials", None)
    if not refreshed:
        return

    try:
        # Function-scope import, matching how run_cloud_agent pulls in the
        # integrations helpers.
        from app.integrations import encrypt_token

        new_encrypted = encrypt_token(refreshed)
        async with async_session() as db:
            cfg_result = await db.execute(
                select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
            )
            cfg_row = cfg_result.scalar_one_or_none()
            if cfg_row:
                cfg_row.oauth_token_encrypted = new_encrypted
                await db.commit()
                logger.debug("agent_runner: refreshed OAuth token persisted for agent %s", config_id)
    except Exception as exc:
        logger.warning(
            "agent_runner: failed to persist refreshed token for agent %s: %s",
            config_id,
            exc,
        )


async def run_cloud_agent(
    user_id: str,
    config: CloudAgentConfig,
    run_log: AgentRunLog,
    device_mgr: DeviceConnectionManager,
) -> None:
    """Execute a cloud connector agent run end-to-end.

    Steps:

    1. Verify the user's device is online.
    2. Decrypt the stored OAuth token from ``config.oauth_token_encrypted``.
    3. Instantiate the provider client via ``get_provider``.
    4. Fetch messages/emails since ``config.last_run_at`` (or
       ``_CLOUD_DEFAULT_LOOKBACK_DAYS`` days ago for the first run) applying
       ``config.filter_config`` filters.
    5. For each message with non-empty text, run the LLM with the processing
       tools built from ``config.data_types``.
    6. If the provider refreshed its access token, re-encrypt and write it
       back to ``config.oauth_token_encrypted`` (also done when the fetch
       itself fails, so a refreshed token is not lost).
    7. Persist the run outcome via ``_finalize_run``.

    The function never raises for provider or per-message failures; they are
    recorded on the run log instead.  The final status is ``"success"`` when
    no message failed, ``"partial"`` when some messages failed and at least
    one succeeded, and ``"error"`` when every processed message failed.
    """
    run_id = run_log.id

    # ── 1. Device online check ─────────────────────────────────────────
    if not device_mgr.is_online(user_id):
        logger.info(
            "agent_runner: skip cloud run=%s — no device online for user=%s",
            run_id,
            user_id,
        )
        await _finalize_run(
            run_log,
            status="error",
            errors=["No connected device — cloud agent results cannot be delivered"],
        )
        return

    # ── 2. Decrypt OAuth token ─────────────────────────────────────────
    from app.integrations import decrypt_token, get_provider

    if not config.oauth_token_encrypted:
        await _finalize_run(
            run_log,
            status="error",
            errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
        )
        return

    try:
        credentials_info = decrypt_token(config.oauth_token_encrypted)
    except ValueError as exc:
        logger.error("agent_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc)
        await _finalize_run(
            run_log,
            status="error",
            errors=[f"Failed to decrypt OAuth token: {exc}"],
        )
        return

    # ── 3. Instantiate provider client ────────────────────────────────
    try:
        provider = get_provider(config.provider, credentials_info)
    except ValueError as exc:
        await _finalize_run(run_log, status="error", errors=[str(exc)])
        return

    # ── 4. Fetch messages ─────────────────────────────────────────────
    since: datetime | None = config.last_run_at
    if since is None:
        since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
    if since.tzinfo is None:
        # Treat naive timestamps as UTC so the provider gets an aware value.
        since = since.replace(tzinfo=timezone.utc)

    errors: list[str] = []
    items_processed = 0
    # NOTE(review): nothing in this function increments items_created — the
    # tool calls made inside _run_agent_with_tools are not surfaced here — so
    # it is always reported as 0.  Status is derived from messages_succeeded.
    items_created = 0
    messages_succeeded = 0

    try:
        if config.provider in ("gmail", "teams"):
            raw_messages = await provider.fetch_messages(  # type: ignore[union-attr]
                filter_config=config.filter_config,
                since=since,
            )
        elif config.provider == "outlook":
            raw_messages = await provider.fetch_emails(  # type: ignore[union-attr]
                filter_config=config.filter_config,
                since=since,
            )
        else:
            raw_messages = []
    except Exception as exc:
        # Top-level boundary of a background task: catch everything (not just
        # RuntimeError) so the run log is always finalized instead of being
        # left without an outcome when the provider raises something else.
        logger.exception(
            "agent_runner: provider fetch failed for cloud agent %s: %s", config.id, exc
        )
        # The provider may have refreshed its token before the fetch failed.
        await _persist_refreshed_cloud_token(provider, config.id)
        # NOTE(review): this still advances last_run_at on a failed fetch, so
        # the next run's ``since`` starts after the failed window — confirm
        # that skipping those messages is intended.
        await _finalize_run(
            run_log,
            status="error",
            errors=[f"Provider fetch failed: {exc}"],
            update_config_last_run=True,
            config_id=config.id,
            config_type="cloud",
        )
        return

    logger.info(
        "agent_runner: cloud agent %s fetched %d item(s) from %s for user=%s",
        config.id,
        len(raw_messages),
        config.provider,
        user_id,
    )

    # ── 5. Extract + insert via LLM with tools ────────────────────────
    executor = _make_agent_executor(user_id, device_mgr)
    set_client_executor(executor)

    try:
        processing_tools = _build_processing_tools(config.data_types)
        custom_section = (
            f"User instructions:\n{config.prompt_template}"
            if config.prompt_template
            else ""
        )

        for msg in raw_messages:
            content_text = msg.as_text
            if not content_text:
                continue
            items_processed += 1

            try:
                # Prompt lookup/formatting lives inside the per-message guard
                # so a malformed template is recorded as an error rather than
                # aborting the whole run without finalizing it.
                cloud_template, cloud_prompt_obj = get_prompt_or_fallback(
                    "batch_cloud_processing", _BATCH_CLOUD_PROCESSING_PROMPT
                )
                processing_prompt = cloud_template.format(
                    data_types=", ".join(config.data_types),
                    project_context="Determine the appropriate project from the message context.",
                    file_list=f"Message from {config.provider} (id: {msg.id})",
                    custom_prompt_section=custom_section,
                )

                await _run_agent_with_tools(
                    system_prompt=processing_prompt,
                    user_message=f"Process this message content:\n\n{content_text[:8000]}",
                    tools=processing_tools,
                    max_steps=_MAX_PROCESSING_STEPS,
                    user_id=user_id,
                    langfuse_prompt=cloud_prompt_obj,
                    agent_name="cloud-processor",
                )
            except Exception as exc:
                # One bad message must not stop the rest of the batch.
                errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
            else:
                messages_succeeded += 1
    finally:
        clear_client_executor()

    # ── 6. Persist refreshed token (if any) ───────────────────────────
    await _persist_refreshed_cloud_token(provider, config.id)

    # ── 7. Finalise ────────────────────────────────────────────────────
    if errors and messages_succeeded == 0:
        final_status = "error"
    elif errors:
        final_status = "partial"
    else:
        final_status = "success"

    await _finalize_run(
        run_log,
        status=final_status,
        items_processed=items_processed,
        items_created=items_created,
        errors=errors,
        update_config_last_run=True,
        config_id=config.id,
        config_type="cloud",
    )
    logger.info(
        "agent_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d",
        run_id,
        final_status,
        items_processed,
        items_created,
        len(errors),
    )
|
||
|
||
|
||
# ── Pending-run trigger ─────────────────────────────────────────────────────
|
||
|
||
|
||
async def trigger_pending_runs(
    user_id: str,
    device_id: str,
    device_mgr: DeviceConnectionManager,
) -> None:
    """Hook invoked after an Electron device connects; currently a no-op.

    Called as a background task from the device WS endpoint on
    ``device_hello``.  No runs are dispatched here: the function only logs
    that the pending-run scan was skipped (the log message gives the reason
    as "client-owned agent config") and returns.  ``device_mgr`` is accepted
    for signature compatibility and is not used.
    """
    logger.info(
        "agent_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)",
        user_id,
        device_id,
    )
|
||
|
||
|
||
# ── Internal helper ─────────────────────────────────────────────────────────
|
||
|
||
|
||
async def _finalize_run(
    run_log: AgentRunLog,
    *,
    status: str,
    items_processed: int = 0,
    items_created: int = 0,
    errors: list[str] | None = None,
    update_config_last_run: bool = False,
    config_id: str | None = None,
    config_type: str | None = None,
) -> None:
    """Record the outcome of a run and, on request, stamp the owning config.

    The run log is merged into a fresh session and updated with ``status``,
    the item counters, the error list (``[]`` when none is given) and a
    completion timestamp.  When ``update_config_last_run`` is true and a
    ``config_id`` is supplied, the matching ``LocalAgentConfig``
    (``config_type == "local"``) or ``CloudAgentConfig``
    (``config_type == "cloud"``) row gets the same timestamp as its
    ``last_run_at``; any other ``config_type`` leaves configs untouched.

    Never raises: any failure is logged at error level and swallowed.
    """
    finished_at = datetime.now(timezone.utc)
    try:
        async with async_session() as db:
            row = await db.merge(run_log)
            row.status = status
            row.items_processed = items_processed
            row.items_created = items_created
            row.errors = errors or []
            row.completed_at = finished_at

            if update_config_last_run and config_id:
                # Pick the lookup for the config table named by config_type;
                # model names are resolved only in the branch that needs them.
                if config_type == "local":
                    lookup = select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
                elif config_type == "cloud":
                    lookup = select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
                else:
                    lookup = None

                if lookup is not None:
                    cfg = (await db.execute(lookup)).scalar_one_or_none()
                    if cfg:
                        cfg.last_run_at = finished_at

            await db.commit()
    except Exception as exc:
        logger.error(
            "agent_runner: failed to finalize run_log=%s: %s", run_log.id, exc
        )
|