"""Agent run orchestrator — adapted for Batch Agent Service.

Key changes from monolith app/core/agent_runner.py:
- No DeviceConnectionManager — tool calls go through Redis ws_context.
- set_current_user / clear_current_user replace set_client_executor.
- run_local_agent accepts a serialized dict (from Redis / REST) instead
  of SQLAlchemy model objects.
- _finalize_run writes to PostgreSQL via shared.db.async_session.
- Cloud agent import path changed to app.integrations.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any
|
|
|
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
|
from sqlalchemy import select
|
|
|
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
|
from shared.agents.note_agent import NOTE_TOOLS
|
|
from shared.agents.project_agent import PROJECT_TOOLS
|
|
from shared.agents.task_agent import TASK_TOOLS
|
|
from shared.agents.timeline_agent import TIMELINE_TOOLS
|
|
from shared.llm import get_llm
|
|
from shared.ws_context import execute_on_client, set_current_user, clear_current_user
|
|
import app.tracing as tracing
|
|
from shared.db import async_session
|
|
from shared.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
|
from shared.redis import redis_client, ws_out_channel
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# ── Concurrency guard ─────────────────────────────────────────────────────
# Ids of agents with a run in progress; run_local_agent adds its agent id on
# entry and discards it when the run ends.
_running_agents: set[str] = set()


def is_agent_running(agent_id: str) -> bool:
    """Tell whether *agent_id* is currently registered in ``_running_agents``."""
    in_flight = _running_agents
    return agent_id in in_flight
|
|
|
|
|
|
# ── Timeouts ───────────────────────────────────────────────────────────────
# NOTE(review): not referenced in the visible part of this module — presumably
# a per-tool-call timeout in seconds; confirm where (or whether) it is applied.
_TOOL_CALL_TIMEOUT: int = 30
# Passed as ``max_steps`` to _run_agent_with_tools: the maximum number of
# LLM rounds that may request tool calls for one file / message.
_MAX_PROCESSING_STEPS: int = 12
# Deepest level _scan_directories descends to; configured roots are depth 0.
_MAX_SCAN_DEPTH: int = 5
|
|
|
|
# ── Data-type to tool mapping ─────────────────────────────────────────────
# Looked up by _build_processing_tools to decide which tool lists to bind for
# the requested data types (filesystem tools are always added there).
# NOTE(review): there is no "projects" entry, so PROJECT_TOOLS is never bound
# through this mapping even though the cloud prompt advertises project tools —
# confirm whether that is intentional.
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
    "tasks": TASK_TOOLS,
    "notes": NOTE_TOOLS,
    "timelines": TIMELINE_TOOLS,
}
|
|
|
|
# ── Step 1: Classification prompt ─────────────────────────────────────────

# Human-readable definition of each data domain. _classify_file renders the
# entries for the configured data types into the ``{domain_definitions}``
# placeholder of the classification prompt.
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
    "tasks": (
        "Action items, to-dos, deliverables — anything that describes work to be done, "
        "assigned to someone, or tracked with a due date or status."
    ),
    "notes": (
        "Documentation, meeting notes, summaries, reference material — "
        "written content meant to be read and referenced rather than acted on."
    ),
    "timelines": (
        "Project milestones, deadlines, scheduled events — "
        "specific dates that mark a point in the progress of a project."
    ),
    "projects": (
        "High-level project entities — only relevant if the file clearly introduces "
        "a new project or updates the scope of an existing one."
    ),
}
|
|
|
|
_STEP1_SYSTEM_PROMPT = """\
|
|
You are a file classifier for a freelance project management tool.
|
|
|
|
Your job is to match a file to an existing project and identify which data domains to extract.
|
|
|
|
## Project matching rules (STRICT — follow in order)
|
|
|
|
1. Search the file content for any mention of a project name, client name, acronym, or topic
|
|
that overlaps with the existing projects listed below.
|
|
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
|
|
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
|
|
when the file has zero meaningful connection to any listed project.
|
|
4. When in doubt, pick the closest match from the list.
|
|
|
|
## Response format
|
|
|
|
Respond ONLY with a JSON object — no markdown, no explanation:
|
|
|
|
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
|
|
|
|
## Domain definitions (only consider domains in the allowed list)
|
|
|
|
{domain_definitions}
|
|
|
|
## Existing projects
|
|
|
|
{projects_list}
|
|
"""
|
|
|
|
# ── Step 2: Processing prompt ─────────────────────────────────────────────

# Fallback system prompt for the per-file extraction step of run_local_agent.
# Placeholders: {existing_context}, {project_context}, {data_types},
# {custom_prompt_section}.
# The caller's ``project_context`` value already starts with "Project: ", so the
# template inserts it bare (as the cloud prompt does) instead of prefixing the
# label a second time.
_PROCESSING_SYSTEM_PROMPT = """\
You are a data extraction assistant for a freelance project management tool.

Your task: extract structured data from the file content and persist it using the available tools.

## Mandatory process — follow this order for EVERY item you extract

1. READ the existing records listed below for the relevant domain.
2. SEARCH for a match by title, topic, or semantic similarity.
3. If a match exists → call the update_* tool with the existing record's id.
4. If no match exists → call the create_* tool and set isAiSuggested=1.

NEVER call create_* without first checking the existing records.
NEVER duplicate a record that already exists under a different wording.

## Existing records (source of truth)

{existing_context}

## Context

{project_context}
Domains to extract: {data_types}

{custom_prompt_section}
"""
|
|
|
|
# ── Cloud processing prompt ───────────────────────────────────────────────

# Fallback system prompt used by run_cloud_agent for each fetched message.
# Placeholders: {data_types}, {project_context}, {file_list},
# {custom_prompt_section}.
# NOTE(review): the text talks about files and read_file_content, while
# run_cloud_agent passes the message body inline and fills {file_list} with a
# message id — the wording looks inherited from the file-based flow; confirm.
_CLOUD_PROCESSING_PROMPT = """\
You are a data extraction and management assistant for a freelance project
management tool.

Available tools:
Filesystem : read_file_content, list_directory, get_file_metadata
Tasks : list_tasks, create_task, update_task, add_task_comment
Notes : list_notes, get_note, create_note, update_note
Timelines : list_timelines, create_timeline, update_timeline
Projects : list_all_projects, get_project, create_project, update_project

Your task:
1. Read the full content of each file below using read_file_content.
2. For each piece of information found, ALWAYS try to match and update an
existing record before creating a new one.
3. ONLY act on these entity types: {data_types}.
4. Do NOT invent data. Only extract what is clearly present in the files.
5. If a file contains no relevant data for the target entity types, skip it.

{project_context}

Files to process:
{file_list}

{custom_prompt_section}

After processing all files, respond with a brief summary of what you updated
and what you created.
"""
|
|
|
|
|
|
# ── LLM tool-calling loop ─────────────────────────────────────────────────
|
|
|
|
|
|
def _as_text(content: Any) -> str:
|
|
if content is None:
|
|
return ""
|
|
if isinstance(content, str):
|
|
return content
|
|
if isinstance(content, list):
|
|
parts: list[str] = []
|
|
for item in content:
|
|
if isinstance(item, str):
|
|
parts.append(item)
|
|
elif isinstance(item, dict):
|
|
text = item.get("text")
|
|
if isinstance(text, str):
|
|
parts.append(text)
|
|
return "".join(parts)
|
|
return str(content)
|
|
|
|
|
|
async def _run_agent_with_tools(
    *,
    system_prompt: str,
    user_message: str,
    tools: list[Any],
    max_steps: int,
    langfuse_handler: Any | None = None,
) -> str:
    """Run an LLM agent with tool-calling, returning the final text response.

    The conversation starts with *system_prompt* and *user_message*. For up to
    *max_steps* rounds the tool-bound model is invoked; a response without tool
    calls ends the loop and its text is returned. Otherwise every requested
    tool is executed and its stringified output is appended as a
    ``ToolMessage``. If the step budget runs out, the model is invoked once
    more without tools bound and that text is returned.

    Args:
        system_prompt: System message content.
        user_message: Human message content.
        tools: Tool objects exposing ``name`` and ``ainvoke``.
        max_steps: Maximum number of tool-calling rounds.
        langfuse_handler: Optional callback handler forwarded to ``get_llm``.

    Raises:
        Whatever the LLM or a tool raises; tool failures are not caught here,
        so callers can record them as run errors.
    """
    callbacks = [langfuse_handler] if langfuse_handler else None
    llm = get_llm(callbacks=callbacks)
    llm_with_tools = llm.bind_tools(tools)
    messages: list[Any] = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_message),
    ]

    tool_map = {tool_def.name: tool_def for tool_def in tools}

    for _ in range(max_steps):
        response: AIMessage = await llm_with_tools.ainvoke(messages)
        messages.append(response)

        if not response.tool_calls:
            return _as_text(response.content)

        for call in response.tool_calls:
            # Normalise once and reuse: a missing or None id must not raise
            # KeyError or hand a non-string id to ToolMessage.
            call_id = str(call.get("id") or "")
            call_name = str(call.get("name") or "")
            call_args = call.get("args") or {}
            logger.info(
                "agent_runner: tool_call name=%s args=%s",
                call_name,
                # default=str keeps logging from aborting the run when the
                # arguments contain values json cannot serialise.
                json.dumps(call_args, ensure_ascii=True, default=str)[:800],
            )

            tool_fn = tool_map.get(call_name)
            if tool_fn is None:
                # Report the problem back to the model instead of failing.
                tool_output = f"Unknown tool: {call_name}"
            else:
                tool_output = await tool_fn.ainvoke(call_args)

            logger.info(
                "agent_runner: tool_result name=%s output=%s",
                call_name,
                str(tool_output)[:200],
            )
            messages.append(ToolMessage(content=str(tool_output), tool_call_id=call_id))

    # Step budget exhausted: ask for a closing answer without tools bound.
    final = await llm.ainvoke(messages)
    return _as_text(final.content)
|
|
|
|
|
|
# ── Tool list builder ─────────────────────────────────────────────────────
|
|
|
|
|
|
def _build_processing_tools(data_types: list[str]) -> list[Any]:
    """Assemble the tool list for a processing run.

    Filesystem tools are always included, followed by the tool lists that
    ``_DATA_TYPE_TOOLS`` maps to each requested data type. Data types without
    a mapping are ignored.

    Repeated data types are collapsed (first occurrence wins) so the same tool
    list is never added twice, and a ``None``/empty *data_types* yields just
    the filesystem tools.
    """
    tools: list[Any] = list(FILESYSTEM_TOOLS)
    # dict.fromkeys de-duplicates while preserving the caller's order.
    for data_type in dict.fromkeys(data_types or ()):
        tools.extend(_DATA_TYPE_TOOLS.get(data_type) or ())
    return tools
|
|
|
|
|
|
# ── Code-based directory scanner ─────────────────────────────────────────
|
|
|
|
|
|
async def _scan_directories(
    paths: list[str],
    extensions: list[str],
    last_run_at: datetime | None,
) -> list[str]:
    """List candidate files under *paths* via the client's filesystem actions.

    Each root is walked recursively with ``list_directory`` down to
    ``_MAX_SCAN_DEPTH`` (roots are depth 0). A directory that cannot be listed
    is logged and skipped.

    Args:
        paths: Root directories to scan.
        extensions: Allowed file extensions, with or without a leading dot,
            case-insensitive. Empty means "all files".
        last_run_at: When given, only files whose ``modifiedAt`` metadata is
            newer are kept. A naive value is interpreted as UTC, matching how
            run_cloud_agent treats its ``last_run_at``.

    Returns:
        File paths in discovery order, each path at most once even when the
        roots overlap. Files whose modification time is missing or cannot be
        determined are kept, so a metadata problem never hides a file.
    """
    ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set()
    # A dict serves as an insertion-ordered set: overlapping roots would
    # otherwise report (and later process) the same file twice.
    found: dict[str, None] = {}

    async def _walk(path: str, depth: int) -> None:
        if depth > _MAX_SCAN_DEPTH:
            return
        try:
            result = await execute_on_client(action="list_directory", data={"path": path})
        except Exception as exc:
            logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
            return
        # Tolerate an empty result or a null "entries" field.
        for entry in (result or {}).get("entries") or []:
            entry_path = entry.get("path", "")
            if not entry_path:
                continue
            entry_type = entry.get("type")
            if entry_type == "directory":
                await _walk(entry_path, depth + 1)
            elif entry_type == "file":
                if ext_set:
                    dot_pos = entry_path.rfind(".")
                    file_ext = entry_path[dot_pos + 1:].lower() if dot_pos != -1 else ""
                    if file_ext not in ext_set:
                        continue
                found.setdefault(entry_path, None)

    for root in paths:
        await _walk(root, depth=0)

    all_files = list(found)
    if last_run_at is None:
        return all_files

    cutoff = last_run_at
    if cutoff.tzinfo is None:
        # .timestamp() on a naive datetime assumes server-local time.
        cutoff = cutoff.replace(tzinfo=timezone.utc)
    last_run_ms = int(cutoff.timestamp() * 1000)

    filtered: list[str] = []
    for file_path in all_files:
        try:
            meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
            modified_at = meta.get("modifiedAt")
            if modified_at is None:
                filtered.append(file_path)
                continue
            if isinstance(modified_at, (int, float)):
                # Numeric values are compared directly against epoch
                # milliseconds — assumes the client reports ms; TODO confirm.
                mod_ms = int(modified_at)
            else:
                iso_text = str(modified_at)
                if iso_text.endswith(("Z", "z")):
                    # Older fromisoformat() versions reject the "Z" suffix.
                    iso_text = iso_text[:-1] + "+00:00"
                mod_ms = int(datetime.fromisoformat(iso_text).timestamp() * 1000)
            if mod_ms > last_run_ms:
                filtered.append(file_path)
        except Exception as exc:
            # Best effort: keep the file rather than silently dropping it.
            logger.warning(
                "agent_runner: get_file_metadata failed %r, keeping file: %s", file_path, exc
            )
            filtered.append(file_path)

    return filtered
|
|
|
|
|
|
# ── Code-based entity fetchers ────────────────────────────────────────────
|
|
|
|
|
|
async def _fetch_projects() -> list[dict]:
    """Fetch the rows of the ``projects`` table through the client.

    Any failure is logged as a warning and reported as an empty list.
    """
    project_rows: list[dict] = []
    try:
        response = await execute_on_client(action="select", table="projects")
        project_rows = response.get("rows", [])
    except Exception as exc:
        logger.warning("agent_runner: failed to fetch projects: %s", exc)
    return project_rows
|
|
|
|
|
|
# Data domain → table name passed to the client's "select" action by
# _fetch_domain_entities. Domains missing here yield no rows there.
_DOMAIN_TABLE: dict[str, str] = {
    "tasks": "tasks",
    "notes": "notes",
    "timelines": "timelines",
    "projects": "projects",
}
|
|
|
|
|
|
async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
    """Fetch the existing rows of *domain* through the client.

    Rows are filtered by ``projectId`` unless *project_id* is the
    ``"standalone"`` sentinel or the domain is ``"projects"`` itself. Domains
    without an entry in ``_DOMAIN_TABLE`` and failed lookups both yield an
    empty list (failures are logged as warnings).
    """
    table_name = _DOMAIN_TABLE.get(domain)
    if not table_name:
        return []

    scoped_to_project = project_id != "standalone" and domain != "projects"
    query_filters: dict[str, Any] | None = (
        {"projectId": project_id} if scoped_to_project else None
    )

    entity_rows: list[dict] = []
    try:
        response = await execute_on_client(
            action="select",
            table=table_name,
            filters=query_filters,
        )
        entity_rows = response.get("rows", [])
    except Exception as exc:
        logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
    return entity_rows
|
|
|
|
|
|
def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
    """Render existing records of *domain* as a bullet list for the LLM prompt.

    Args:
        domain: One of ``"tasks"``, ``"notes"``, ``"timelines"``, ``"projects"``.
        rows: Record dicts; every row must carry an ``"id"``.

    Returns:
        ``"No existing <domain>."`` for an empty list, otherwise a header line
        followed by one entry per record. Task entries list only the metadata
        that is present, so a task without priority, assignee or due date is
        rendered as ``(id: …)`` rather than ``(, id: …)``.

    Raises:
        KeyError: If a row has no ``"id"`` key.
    """
    if not rows:
        return f"No existing {domain}."

    lines: list[str] = []
    for r in rows:
        if domain == "tasks":
            desc = r.get("description") or ""
            desc_part = f" — {desc[:120]}" if desc else ""
            assignee = r.get("assignee") or r.get("assignees") or ""
            due = r.get("dueDate") or r.get("due_date") or ""
            priority = r.get("priority")
            # Collect only the details that exist; the id always comes last.
            details = [
                part
                for part in (
                    f"priority: {priority}" if priority else "",
                    f"assignee: {assignee}" if assignee else "",
                    f"due: {due}" if due else "",
                )
                if part
            ]
            details.append(f"id: {r['id']}")
            lines.append(
                f" - [{r.get('status', '?')}] {r.get('title', '')}{desc_part}"
                f" ({', '.join(details)})"
            )
        elif domain == "notes":
            # Short single-line preview so similar notes can be told apart.
            snippet = (r.get("content") or "")[:200].replace("\n", " ")
            snippet_part = f"\n Preview: {snippet}" if snippet else ""
            lines.append(f" - {r.get('title', '')} (id: {r['id']}){snippet_part}")
        elif domain == "timelines":
            lines.append(
                f" - {r.get('title', '')} date={r.get('date', '')} (id: {r['id']})"
            )
        elif domain == "projects":
            summary = (r.get("aiSummary") or r.get("ai_summary") or "")[:120]
            summary_part = f" — {summary}" if summary else ""
            lines.append(
                f" - {r.get('name', '')} [{r.get('status', '')}]{summary_part}"
                f" (id: {r['id']})"
            )

    return f"Existing {domain}:\n" + "\n".join(lines)
|
|
|
|
|
|
# ── Step 1: LLM file classifier ───────────────────────────────────────────
|
|
|
|
|
|
async def _classify_file(
|
|
file_path: str,
|
|
file_content: str,
|
|
projects: list[dict],
|
|
config_data_types: list[str],
|
|
langfuse_handler: Any | None = None,
|
|
custom_system_prompt: str | None = None,
|
|
) -> tuple[str, list[str], str | None]:
|
|
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
|
|
|
|
if not file_content.strip():
|
|
return fallback
|
|
|
|
valid_project_ids = {p["id"] for p in projects}
|
|
|
|
def _fmt_project(p: dict) -> str:
|
|
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
|
|
summary_part = f" — {summary[:100]}" if summary else ""
|
|
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
|
|
|
|
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
|
|
|
|
domain_definitions = "\n".join(
|
|
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
|
|
for d in config_data_types
|
|
if d in _DOMAIN_DESCRIPTIONS
|
|
)
|
|
|
|
if custom_system_prompt:
|
|
# Fixture-provided prompt takes absolute priority
|
|
system = custom_system_prompt.format_map(
|
|
{"domain_definitions": domain_definitions, "projects_list": projects_list}
|
|
)
|
|
else:
|
|
system = tracing.compile_prompt(
|
|
"batch_file_classifier",
|
|
fallback=_STEP1_SYSTEM_PROMPT,
|
|
variables={
|
|
"domain_definitions": domain_definitions,
|
|
"projects_list": projects_list,
|
|
},
|
|
)
|
|
|
|
llm = get_llm(callbacks=[langfuse_handler] if langfuse_handler else None)
|
|
try:
|
|
response = await llm.ainvoke([
|
|
SystemMessage(content=system),
|
|
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
|
|
])
|
|
raw = _as_text(response.content).strip()
|
|
if raw.startswith("```"):
|
|
raw = raw.split("```")[1]
|
|
if raw.startswith("json"):
|
|
raw = raw[4:]
|
|
parsed = json.loads(raw.strip())
|
|
raw_project_id: str = str(parsed.get("project_id") or "new")
|
|
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
|
|
new_project_name: str | None = (
|
|
str(parsed["new_project_name"]).strip() or None
|
|
if project_id == "new" and parsed.get("new_project_name")
|
|
else None
|
|
)
|
|
domains: list[str] = [
|
|
d for d in parsed.get("domains", [])
|
|
if d in config_data_types
|
|
]
|
|
if not domains:
|
|
domains = list(config_data_types)
|
|
return project_id, domains, new_project_name
|
|
except Exception as exc:
|
|
logger.warning(
|
|
"agent_runner: step1 classification failed for %r: %s", file_path, exc
|
|
)
|
|
return fallback
|
|
|
|
|
|
# ── Local agent runner (two-step per file) ────────────────────────────────
|
|
|
|
|
|
async def run_local_agent(user_id: str, trigger_data: dict[str, Any], *, langfuse_handler: Any | None = None) -> None:
    """Execute a local directory agent run.

    In the microservice world, trigger_data is a serialized dict from
    the REST route (forwarded via Redis), containing the agent config
    fields and run_context.

    set_current_user() must be called BEFORE this function.

    Flow: scan the configured directories, then for every non-blank file
    classify it (step 1), resolve or create the target project, and run the
    tool-calling extraction (step 2). A failure in one file is recorded and
    the run continues. Afterwards the run log is finalised and, when a
    ``run_context`` was supplied, a ``run_complete`` message is published to
    the user's Redis channel — including when no files were found.

    Args:
        user_id: Owner of the run; used for the run log and the Redis channel.
        trigger_data: Keys read here: ``run_context`` (``agent_id``,
            ``run_id``), ``directory_paths`` or ``directory``, ``data_types``,
            ``file_extensions``, ``prompt_template``, ``last_run_at`` (ISO
            string or epoch milliseconds).
        langfuse_handler: Optional callback handler forwarded to the LLM calls.

    Raises:
        Errors from setup (an unparsable ``last_run_at``, failure to create
        the run log) propagate. The concurrency-guard entry is released in
        every case.
    """
    run_context: dict = trigger_data.get("run_context") or {}
    agent_id = run_context.get("agent_id", str(uuid.uuid4()))
    run_log_id = run_context.get("run_id")

    errors: list[str] = []
    items_processed = 0
    # NOTE(review): never incremented in this function, so the run log always
    # records 0 created items.
    items_created = 0

    _running_agents.add(agent_id)
    try:
        # Everything after registering the agent sits inside this
        # try/finally, so a setup failure cannot leave the agent marked as
        # running forever.

        # Extract config from trigger payload
        directory_paths: list[str] = trigger_data.get("directory_paths", [])
        if not directory_paths:
            directory = trigger_data.get("directory", "")
            if directory:
                directory_paths = [directory]

        data_types: list[str] = trigger_data.get("data_types", [])
        file_extensions: list[str] = trigger_data.get("file_extensions", [])
        prompt_template: str = trigger_data.get("prompt_template", "")

        last_run_at_raw = trigger_data.get("last_run_at")
        last_run_at: datetime | None = None
        if last_run_at_raw:
            if isinstance(last_run_at_raw, str):
                iso_text = last_run_at_raw
                if iso_text.endswith(("Z", "z")):
                    # Older fromisoformat() versions reject the "Z" suffix.
                    iso_text = iso_text[:-1] + "+00:00"
                last_run_at = datetime.fromisoformat(iso_text)
            elif isinstance(last_run_at_raw, (int, float)):
                last_run_at = datetime.fromtimestamp(last_run_at_raw / 1000, tz=timezone.utc)
        if last_run_at is not None and last_run_at.tzinfo is None:
            # Same convention as run_cloud_agent: naive timestamps are UTC.
            last_run_at = last_run_at.replace(tzinfo=timezone.utc)

        custom_section = (
            f"User instructions:\n{prompt_template}"
            if prompt_template
            else ""
        )

        # Create or load run log
        if not run_log_id:
            async with async_session() as db:
                run_log = AgentRunLog(
                    agent_id=agent_id,
                    agent_type="local",
                    user_id=user_id,
                    status="running",
                )
                db.add(run_log)
                await db.commit()
                await db.refresh(run_log)
                run_log_id = run_log.id

        try:
            # ── Scan directories ─────────────────────────────────────────
            logger.info("agent_runner: run=%s scanning directories user=%s", run_log_id, user_id)
            file_paths = await _scan_directories(
                paths=directory_paths,
                extensions=file_extensions,
                last_run_at=last_run_at,
            )
            logger.info(
                "agent_runner: run=%s found %d file(s) after filtering", run_log_id, len(file_paths)
            )

            # ── Fetch all projects once ──────────────────────────────────
            # With no files there is nothing to do; execution falls through to
            # the shared finalisation so the client is still notified.
            projects = await _fetch_projects() if file_paths else []

            for file_path in file_paths:
                try:
                    file_result = await execute_on_client(
                        action="read_file_content", data={"path": file_path}
                    )
                    file_content: str = file_result.get("content", "")
                    # Whitespace-only files would classify as "new" and spawn
                    # an "Untitled Project"; skip them like empty files.
                    if not file_content or not file_content.strip():
                        continue

                    items_processed += 1

                    # Step 1 — classify file
                    project_id, domains, new_project_name = await _classify_file(
                        file_path=file_path,
                        file_content=file_content,
                        projects=projects,
                        config_data_types=data_types,
                        langfuse_handler=langfuse_handler,
                    )

                    # Step 2 — resolve project_id, fetch entities, process
                    if project_id == "new":
                        proj_name = new_project_name or "Untitled Project"
                        try:
                            proj_result = await execute_on_client(
                                action="insert",
                                table="projects",
                                data={"name": proj_name, "clientId": None},
                            )
                            created = proj_result.get("row", {})
                            effective_project_id = created.get("id", "standalone")
                            if "id" in created:
                                # Let later files in this run match the new project.
                                projects.append(created)
                        except Exception as exc:
                            logger.warning("agent_runner: run=%s create project failed: %s", run_log_id, exc)
                            effective_project_id = "standalone"
                            proj_name = "unknown"
                    else:
                        effective_project_id = project_id
                        proj = next((p for p in projects if p["id"] == project_id), None)
                        proj_name = proj.get("name", project_id) if proj else project_id

                    project_context = (
                        f"Project: {proj_name} (id: {effective_project_id}). "
                        "Always set projectId to this id on every record you create."
                    )

                    # The project is resolved above; step 2 never extracts "projects".
                    domains = [d for d in domains if d != "projects"]

                    existing_blocks: list[str] = []
                    for domain in domains:
                        rows = await _fetch_domain_entities(domain, effective_project_id)
                        existing_blocks.append(_format_entities_for_context(domain, rows))

                    existing_context = "\n\n".join(existing_blocks)

                    system_prompt = tracing.compile_prompt(
                        "batch_processing",
                        fallback=_PROCESSING_SYSTEM_PROMPT,
                        variables={
                            "existing_context": existing_context,
                            "project_context": project_context,
                            "data_types": ", ".join(domains),
                            "custom_prompt_section": custom_section,
                        },
                    )

                    processing_tools = _build_processing_tools(domains)

                    result_text = await _run_agent_with_tools(
                        system_prompt=system_prompt,
                        user_message=(
                            f"Process this file and extract relevant information.\n\n"
                            f"File: {file_path}\n\nContent:\n{file_content}"
                        ),
                        tools=processing_tools,
                        max_steps=_MAX_PROCESSING_STEPS,
                        langfuse_handler=langfuse_handler,
                    )
                    logger.info(
                        "agent_runner: run=%s file=%r result=%s",
                        run_log_id, file_path, result_text[:200],
                    )

                except Exception as exc:
                    # One bad file must not abort the rest of the run.
                    errors.append(f"Error processing '{file_path}': {exc}")
                    logger.error("agent_runner: run=%s file=%r failed: %s", run_log_id, file_path, exc)

        except Exception as exc:
            errors.append(f"Agent run failed: {exc}")
            logger.error("agent_runner: run=%s failed: %s", run_log_id, exc)
    finally:
        _running_agents.discard(agent_id)

    # ── Finalise ────────────────────────────────────────────────────
    if errors and items_processed == 0:
        final_status = "error"
    elif errors:
        final_status = "partial"
    else:
        final_status = "success"

    # NOTE(review): update_config_last_run is not requested here, so this
    # function never advances LocalAgentConfig.last_run_at — presumably the
    # caller that builds trigger_data owns that; confirm.
    await _finalize_run(
        run_log_id,
        status=final_status,
        items_processed=items_processed,
        items_created=items_created,
        errors=errors,
    )

    # Notify Electron that the run is complete via Redis
    if run_context:
        try:
            channel = ws_out_channel(user_id)
            await redis_client.publish(channel, json.dumps({
                "type": "run_complete",
                "run_context": run_context,
                "status": final_status,
            }))
        except Exception as exc:
            logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_log_id, exc)
|
|
|
|
|
|
# ── Cloud agent runner ─────────────────────────────────────────────────────

# How many days back run_cloud_agent fetches when the config has no
# ``last_run_at`` yet.
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
|
|
|
|
|
|
async def run_cloud_agent(user_id: str, config_id: str, *, langfuse_handler: Any | None = None) -> None:
    """Execute a cloud connector agent run.

    Loads the CloudAgentConfig from DB, decrypts OAuth tokens, fetches
    messages from the provider, and runs LLM extraction.

    set_current_user() must be called BEFORE this function.

    A run log row is created up front and every exit path after that
    finalises it. Any exception raised while fetching from the provider ends
    the run with status ``"error"``, and a failure on a single message is
    recorded without stopping the remaining messages.

    Args:
        user_id: Owner of the run, stored on the run log.
        config_id: Primary key of the ``CloudAgentConfig`` to execute. When no
            such config exists an error is logged and nothing else happens.
        langfuse_handler: Optional callback handler forwarded to the LLM calls.
    """
    from app.integrations import decrypt_token, encrypt_token, get_provider

    async with async_session() as db:
        result = await db.execute(
            select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
        )
        config = result.scalar_one_or_none()
        if config is None:
            logger.error("agent_runner: cloud config %s not found", config_id)
            return

        # Copy what the run needs into plain locals while the row is still
        # loaded: after commit() and session close, ORM attributes may be
        # expired or detached depending on the session configuration.
        cfg_id = config.id
        cfg_name = config.name
        provider_name = config.provider
        token_encrypted = config.oauth_token_encrypted
        since: datetime | None = config.last_run_at
        filter_config = config.filter_config
        data_types: list[str] = list(config.data_types or [])
        prompt_template = config.prompt_template

        # Create run log
        run_log = AgentRunLog(
            agent_id=cfg_id,
            agent_type="cloud",
            user_id=user_id,
            status="running",
        )
        db.add(run_log)
        await db.commit()
        await db.refresh(run_log)
        run_log_id = run_log.id

    # ── Decrypt OAuth token ────────────────────────────────────────
    if not token_encrypted:
        await _finalize_run(
            run_log_id,
            status="error",
            errors=[f"No OAuth token stored for cloud agent '{cfg_name}'"],
        )
        return

    try:
        credentials_info = decrypt_token(token_encrypted)
    except ValueError as exc:
        await _finalize_run(
            run_log_id,
            status="error",
            errors=[f"Failed to decrypt OAuth token: {exc}"],
        )
        return

    # ── Instantiate provider ──────────────────────────────────────
    try:
        provider = get_provider(provider_name, credentials_info)
    except ValueError as exc:
        await _finalize_run(run_log_id, status="error", errors=[str(exc)])
        return

    # ── Fetch messages ────────────────────────────────────────────
    if since is None:
        since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
    if since.tzinfo is None:
        since = since.replace(tzinfo=timezone.utc)

    errors: list[str] = []
    items_processed = 0

    try:
        if provider_name in ("gmail", "teams"):
            raw_messages = await provider.fetch_messages(
                filter_config=filter_config,
                since=since,
            )
        elif provider_name == "outlook":
            raw_messages = await provider.fetch_emails(
                filter_config=filter_config,
                since=since,
            )
        else:
            raw_messages = []
    except Exception as exc:
        # Any fetch failure must close the run log; catching only
        # RuntimeError would leave other errors stuck in "running".
        # NOTE(review): last_run_at is advanced even though nothing was
        # fetched, which moves the next run's ``since`` past this window —
        # confirm that is intended.
        await _finalize_run(
            run_log_id,
            status="error",
            errors=[f"Provider fetch failed: {exc}"],
            update_config_last_run=True,
            config_id=cfg_id,
            config_type="cloud",
        )
        return

    logger.info(
        "agent_runner: cloud agent %s fetched %d item(s) from %s",
        cfg_id, len(raw_messages), provider_name,
    )

    # ── Extract + insert via LLM ─────────────────────────────────
    try:
        processing_tools = _build_processing_tools(data_types)
        data_types_label = ", ".join(data_types)
        custom_section = (
            f"User instructions:\n{prompt_template}"
            if prompt_template
            else ""
        )

        for msg in raw_messages:
            content_text = msg.as_text
            if not content_text:
                continue
            items_processed += 1

            processing_prompt = tracing.compile_prompt(
                "batch_cloud_processing",
                fallback=_CLOUD_PROCESSING_PROMPT,
                variables={
                    "data_types": data_types_label,
                    "project_context": "Determine the appropriate project from the message context.",
                    "file_list": f"Message from {provider_name} (id: {msg.id})",
                    "custom_prompt_section": custom_section,
                },
            )

            try:
                await _run_agent_with_tools(
                    system_prompt=processing_prompt,
                    user_message=f"Process this message content:\n\n{content_text[:8000]}",
                    tools=processing_tools,
                    max_steps=_MAX_PROCESSING_STEPS,
                    langfuse_handler=langfuse_handler,
                )
            except Exception as exc:
                # Record and move on to the next message.
                errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
    except Exception as exc:
        errors.append(f"Agent run failed: {exc}")

    # ── Persist refreshed token ───────────────────────────────────
    # Providers may expose renewed credentials after use; store them so the
    # next run does not start from the stale token.
    refreshed = getattr(provider, "refreshed_credentials", None)
    if refreshed:
        try:
            new_encrypted = encrypt_token(refreshed)
            async with async_session() as db:
                cfg_result = await db.execute(
                    select(CloudAgentConfig).where(CloudAgentConfig.id == cfg_id)
                )
                cfg_row = cfg_result.scalar_one_or_none()
                if cfg_row:
                    cfg_row.oauth_token_encrypted = new_encrypted
                    await db.commit()
        except Exception as exc:
            logger.warning("agent_runner: failed to persist refreshed token: %s", exc)

    # ── Finalise ──────────────────────────────────────────────────
    if errors and items_processed == 0:
        final_status = "error"
    elif errors:
        final_status = "partial"
    else:
        final_status = "success"

    await _finalize_run(
        run_log_id,
        status=final_status,
        items_processed=items_processed,
        items_created=0,
        errors=errors,
        update_config_last_run=True,
        config_id=cfg_id,
        config_type="cloud",
    )
|
|
|
|
|
|
# ── Internal helper ─────────────────────────────────────────────────────────
|
|
|
|
|
|
async def _finalize_run(
    run_log_id: int | str,
    *,
    status: str,
    items_processed: int = 0,
    items_created: int = 0,
    errors: list[str] | None = None,
    update_config_last_run: bool = False,
    config_id: str | None = None,
    config_type: str | None = None,
) -> None:
    """Persist the run outcome and optionally update last_run_at on the config.

    The ``AgentRunLog`` row gets its status, counters, error list and
    completion time. When *update_config_last_run* is set and *config_id* is
    given, the matching ``LocalAgentConfig`` / ``CloudAgentConfig`` (chosen by
    *config_type*: ``"local"`` or ``"cloud"``) also gets ``last_run_at`` set to
    the same instant. A missing run log is logged and nothing is written;
    database errors are logged and never raised.
    """
    finished_at = datetime.now(timezone.utc)
    try:
        async with async_session() as session:
            log_lookup = await session.execute(
                select(AgentRunLog).where(AgentRunLog.id == run_log_id)
            )
            log_row = log_lookup.scalar_one_or_none()
            if log_row is None:
                logger.warning("agent_runner: run_log %s not found for finalization", run_log_id)
                return

            log_row.status = status
            log_row.items_processed = items_processed
            log_row.items_created = items_created
            log_row.errors = errors or []
            log_row.completed_at = finished_at

            config_model = None
            if update_config_last_run and config_id:
                # Pick the config table from the agent type; other types are ignored.
                config_model = {
                    "local": LocalAgentConfig,
                    "cloud": CloudAgentConfig,
                }.get(config_type)

            if config_model is not None:
                config_lookup = await session.execute(
                    select(config_model).where(config_model.id == config_id)
                )
                config_row = config_lookup.scalar_one_or_none()
                if config_row:
                    config_row.last_run_at = finished_at

            await session.commit()
    except Exception as exc:
        logger.error("agent_runner: failed to finalize run_log=%s: %s", run_log_id, exc)
|