"""Agent run orchestrator — adapted for Batch Agent Service. Key changes from monolith app/core/agent_runner.py: - No DeviceConnectionManager — tool calls go through Redis ws_context. - set_current_user / clear_current_user replace set_client_executor. - run_local_agent accepts a serialized dict (from Redis / REST) instead of SQLAlchemy model objects. - _finalize_run writes to PostgreSQL via shared.db.async_session. - Cloud agent import path changed to app.integrations. """ from __future__ import annotations import asyncio import json import logging import uuid from datetime import datetime, timedelta, timezone from typing import Any from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from sqlalchemy import select from app.agents.filesystem_agent import FILESYSTEM_TOOLS from shared.agents.note_agent import NOTE_TOOLS from shared.agents.project_agent import PROJECT_TOOLS from shared.agents.task_agent import TASK_TOOLS from shared.agents.timeline_agent import TIMELINE_TOOLS from shared.llm import get_llm from shared.ws_context import execute_on_client, set_current_user, clear_current_user import app.tracing as tracing from shared.db import async_session from shared.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig from shared.redis import redis_client, ws_out_channel logger = logging.getLogger(__name__) # ── Concurrency guard ───────────────────────────────────────────────────── _running_agents: set[str] = set() def is_agent_running(agent_id: str) -> bool: return agent_id in _running_agents # ── Timeouts ─────────────────────────────────────────────────────────────── _TOOL_CALL_TIMEOUT: int = 30 _MAX_PROCESSING_STEPS: int = 12 _MAX_SCAN_DEPTH: int = 5 # ── Data-type to tool mapping ───────────────────────────────────────────── _DATA_TYPE_TOOLS: dict[str, list[Any]] = { "tasks": TASK_TOOLS, "notes": NOTE_TOOLS, "timelines": TIMELINE_TOOLS, } # ── Step 1: Classification prompt ───────────────────────────────────────── 
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
    "tasks": (
        "Action items, to-dos, deliverables — anything that describes work to be done, "
        "assigned to someone, or tracked with a due date or status."
    ),
    "notes": (
        "Documentation, meeting notes, summaries, reference material — "
        "written content meant to be read and referenced rather than acted on."
    ),
    "timelines": (
        "Project milestones, deadlines, scheduled events — "
        "specific dates that mark a point in the progress of a project."
    ),
    "projects": (
        "High-level project entities — only relevant if the file clearly introduces "
        "a new project or updates the scope of an existing one."
    ),
}

# Literal braces are doubled ({{ }}) so only the {placeholders} are substituted
# by str.format-style rendering (see format_map in _classify_file).
# NOTE(review): assumes tracing.compile_prompt renders the fallback the same way.
_STEP1_SYSTEM_PROMPT = """\
You are a file classifier for a freelance project management tool.
Your job is to match a file to an existing project and identify which data domains to extract.

## Project matching rules (STRICT — follow in order)
1. Search the file content for any mention of a project name, client name, acronym, or topic that overlaps with the existing projects listed below.
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort when the file has zero meaningful connection to any listed project.
4. When in doubt, pick the closest match from the list.

## Response format
Respond ONLY with a JSON object — no markdown, no explanation:
{{"project_id": "", "new_project_name": "", "domains": ["tasks", "notes"]}}

## Domain definitions (only consider domains in the allowed list)
{domain_definitions}

## Existing projects
{projects_list}
"""


# ── Step 2: Processing prompt ─────────────────────────────────────────────

_PROCESSING_SYSTEM_PROMPT = """\
You are a data extraction assistant for a freelance project management tool.
Your task: extract structured data from the file content and persist it using the available tools.

## Mandatory process — follow this order for EVERY item you extract
1. READ the existing records listed below for the relevant domain.
2. SEARCH for a match by title, topic, or semantic similarity.
3. If a match exists → call the update_* tool with the existing record's id.
4. If no match exists → call the create_* tool and set isAiSuggested=1.

NEVER call create_* without first checking the existing records.
NEVER duplicate a record that already exists under a different wording.

## Existing records (source of truth)
{existing_context}

## Context
Project: {project_context}
Domains to extract: {data_types}

{custom_prompt_section}
"""


# ── Cloud processing prompt ───────────────────────────────────────────────

_CLOUD_PROCESSING_PROMPT = """\
You are a data extraction and management assistant for a freelance project management tool.

Available tools:
Filesystem : read_file_content, list_directory, get_file_metadata
Tasks : list_tasks, create_task, update_task, add_task_comment
Notes : list_notes, get_note, create_note, update_note
Timelines : list_timelines, create_timeline, update_timeline
Projects : list_all_projects, get_project, create_project, update_project

Your task:
1. Read the full content of each file below using read_file_content.
2. For each piece of information found, ALWAYS try to match and update an existing record before creating a new one.
3. ONLY act on these entity types: {data_types}.
4. Do NOT invent data. Only extract what is clearly present in the files.
5. If a file contains no relevant data for the target entity types, skip it.

{project_context}

Files to process:
{file_list}

{custom_prompt_section}

After processing all files, respond with a brief summary of what you updated and what you created.
"""


# ── LLM tool-calling loop ─────────────────────────────────────────────────


def _as_text(content: Any) -> str:
    """Flatten an LLM message ``content`` payload into plain text.

    ``None`` becomes ``""``; strings pass through; a list contributes its
    string items and the ``"text"`` value of dict items (other parts are
    dropped); anything else is ``str()``-ed.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for item in content:
            if isinstance(item, str):
                parts.append(item)
            elif isinstance(item, dict):
                text = item.get("text")
                if isinstance(text, str):
                    parts.append(text)
        return "".join(parts)
    return str(content)


async def _invoke_tool(tool_fn: Any, call_name: str, call_args: Any) -> Any:
    """Await one tool invocation, bounded by ``_TOOL_CALL_TIMEOUT`` seconds.

    Raises:
        TimeoutError: naming the tool, when the limit is exceeded.
        Any exception raised by the tool itself propagates unchanged.
    """
    try:
        return await asyncio.wait_for(tool_fn.ainvoke(call_args), timeout=_TOOL_CALL_TIMEOUT)
    except asyncio.TimeoutError:
        raise TimeoutError(
            f"tool '{call_name}' timed out after {_TOOL_CALL_TIMEOUT}s"
        ) from None


async def _run_agent_with_tools(
    *,
    system_prompt: str,
    user_message: str,
    tools: list[Any],
    max_steps: int,
    langfuse_handler: Any | None = None,
) -> str:
    """Run an LLM agent with tool-calling, returning the final text response.

    Each round sends the conversation to the tool-bound model. A reply with no
    tool calls ends the loop and its text is returned. Otherwise every
    requested tool is executed and its result appended as a ToolMessage; an
    unknown tool name is answered with an "Unknown tool" message instead of
    failing. Each tool call is bounded by ``_TOOL_CALL_TIMEOUT``; a timeout or
    a tool exception propagates to the caller. After *max_steps* rounds the
    model is invoked once more without tools bound and that text is returned.
    """
    callbacks = [langfuse_handler] if langfuse_handler else None
    llm = get_llm(callbacks=callbacks)
    llm_with_tools = llm.bind_tools(tools)
    messages: list[Any] = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_message),
    ]
    tool_map = {tool_def.name: tool_def for tool_def in tools}

    for _ in range(max_steps):
        response: AIMessage = await llm_with_tools.ainvoke(messages)
        messages.append(response)
        if not response.tool_calls:
            return _as_text(response.content)

        for call in response.tool_calls:
            # The normalised id is reused for the ToolMessage; indexing
            # call["id"] directly raised KeyError for calls without an id.
            call_id = str(call.get("id") or "")
            call_name = str(call.get("name") or "")
            call_args = call.get("args") or {}
            logger.info(
                "agent_runner: tool_call name=%s args=%s",
                call_name,
                json.dumps(call_args, ensure_ascii=True, default=str)[:800],
            )
            tool_fn = tool_map.get(call_name)
            if tool_fn is None:
                tool_output: Any = f"Unknown tool: {call_name}"
            else:
                tool_output = await _invoke_tool(tool_fn, call_name, call_args)
            logger.info(
                "agent_runner: tool_result name=%s output=%s",
                call_name,
                str(tool_output)[:200],
            )
            messages.append(ToolMessage(content=str(tool_output), tool_call_id=call_id))

    logger.info(
        "agent_runner: max_steps=%d reached, requesting final answer without tools",
        max_steps,
    )
    final = await llm.ainvoke(messages)
    return _as_text(final.content)


# ── Tool list builder ─────────────────────────────────────────────────────


def _build_processing_tools(
    data_types: list[str],
    *,
    extra_tools: list[Any] | None = None,
) -> list[Any]:
    """Assemble the tool list for an extraction run.

    Starts with FILESYSTEM_TOOLS, adds the tools mapped to each entry of
    *data_types* in _DATA_TYPE_TOOLS (unknown types are ignored), then
    *extra_tools*. Tools are de-duplicated by ``name`` so repeated data types
    never bind the same function twice (some providers reject duplicate names).
    """
    groups: list[Any] = [FILESYSTEM_TOOLS]
    groups.extend(_DATA_TYPE_TOOLS.get(dt) or [] for dt in (data_types or []))
    if extra_tools:
        groups.append(extra_tools)

    tools: list[Any] = []
    seen: set[str] = set()
    for group in groups:
        for tool_def in group:
            if tool_def.name in seen:
                continue
            seen.add(tool_def.name)
            tools.append(tool_def)
    return tools


# ── Timestamp / path helpers ──────────────────────────────────────────────


def _parse_iso_datetime(value: str) -> datetime:
    """Parse an ISO-8601 string into an aware datetime.

    A trailing ``Z`` is accepted on every Python version. Naive values are
    taken as UTC, matching how the cloud runner treats naive datetimes.

    Raises:
        ValueError: if *value* is not valid ISO-8601.
    """
    text = value.strip()
    if text.endswith(("Z", "z")):
        text = text[:-1] + "+00:00"
    parsed = datetime.fromisoformat(text)
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def _file_extension(path: str) -> str:
    """Return the lower-cased extension of the last path component, or ``""``.

    Only the final component is inspected, so a dot in a directory name (e.g.
    ``my.project/README``) is not mistaken for an extension. Both ``/`` and
    ``\\`` separators are handled because paths come from the client machine.
    """
    name = path.replace("\\", "/").rsplit("/", 1)[-1]
    dot_pos = name.rfind(".")
    return name[dot_pos + 1:].lower() if dot_pos != -1 else ""


def _as_str_list(value: Any) -> list[str]:
    """Coerce a payload field into a list: ``None``/empty → ``[]``, a bare string → ``[value]``."""
    if not value:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


# ── Code-based directory scanner ─────────────────────────────────────────


async def _scan_directories(
    paths: list[str],
    extensions: list[str],
    last_run_at: datetime | None,
) -> list[str]:
    """Recursively list files under *paths* on the connected client.

    Each root is walked through ``list_directory`` down to ``_MAX_SCAN_DEPTH``
    levels (the root is level 0). When *extensions* is non-empty, only files
    whose extension is in it are kept (case-insensitive, leading dot optional).
    Each file appears once, in discovery order, even when roots overlap.
    A directory that cannot be listed is logged and skipped.

    When *last_run_at* is given, files whose ``modifiedAt`` is not newer are
    dropped. Files with missing or unreadable metadata are kept. Numeric
    ``modifiedAt`` values are epoch milliseconds; naive datetimes are UTC.
    """
    ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set()
    all_files: list[str] = []
    seen_files: set[str] = set()
    # Shallowest depth each directory was walked at: overlapping roots skip
    # re-listing unless they reach the directory with more depth budget left.
    walked_at_depth: dict[str, int] = {}

    async def _walk(path: str, depth: int) -> None:
        if depth > _MAX_SCAN_DEPTH:
            return
        previous = walked_at_depth.get(path)
        if previous is not None and previous <= depth:
            return
        walked_at_depth[path] = depth
        try:
            result = await execute_on_client(action="list_directory", data={"path": path})
        except Exception as exc:
            logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
            return
        for entry in (result or {}).get("entries") or []:
            entry_path = entry.get("path", "")
            if not entry_path:
                continue
            entry_type = entry.get("type")
            if entry_type == "directory":
                await _walk(entry_path, depth + 1)
            elif entry_type == "file":
                if ext_set and _file_extension(entry_path) not in ext_set:
                    continue
                if entry_path not in seen_files:
                    seen_files.add(entry_path)
                    all_files.append(entry_path)

    for root in paths:
        await _walk(root, depth=0)

    if last_run_at is None:
        return all_files

    if last_run_at.tzinfo is None:
        last_run_at = last_run_at.replace(tzinfo=timezone.utc)
    last_run_ms = int(last_run_at.timestamp() * 1000)

    filtered: list[str] = []
    for file_path in all_files:
        try:
            meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
            modified_at = (meta or {}).get("modifiedAt")
            if modified_at is None:
                filtered.append(file_path)
                continue
            if isinstance(modified_at, (int, float)):
                mod_ms = int(modified_at)
            else:
                mod_ms = int(_parse_iso_datetime(str(modified_at)).timestamp() * 1000)
        except Exception as exc:
            # Unknown freshness: keep the file rather than silently skip it.
            logger.debug("agent_runner: metadata unavailable for %r: %s", file_path, exc)
            filtered.append(file_path)
            continue
        if mod_ms > last_run_ms:
            filtered.append(file_path)
    return filtered


# ── Code-based entity fetchers ────────────────────────────────────────────


async def _fetch_projects() -> list[dict]:
    """Return every project row from the client, or ``[]`` (logged) on failure."""
    try:
        result = await execute_on_client(action="select", table="projects")
        return (result or {}).get("rows") or []
    except Exception as exc:
        logger.warning("agent_runner: failed to fetch projects: %s", exc)
        return []


_DOMAIN_TABLE: dict[str, str] = {
    "tasks": "tasks",
    "notes": "notes",
    "timelines": "timelines",
    "projects": "projects",
}


async def _fetch_domain_entities(domain: str, project_id: Any) -> list[dict]:
    """Return existing rows of *domain*, scoped to *project_id* where that applies.

    No projectId filter is applied for the ``"standalone"`` pseudo-project or
    for the projects table itself. Unknown domains and fetch failures
    (logged) yield ``[]``.
    """
    table = _DOMAIN_TABLE.get(domain)
    if not table:
        return []
    filters: dict[str, Any] = {}
    if project_id != "standalone" and domain != "projects":
        filters["projectId"] = project_id
    try:
        result = await execute_on_client(
            action="select",
            table=table,
            filters=filters or None,
        )
        return (result or {}).get("rows") or []
    except Exception as exc:
        logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
        return []


def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
    """Render existing *domain* records as a bullet list for the processing prompt.

    Every line carries the record id so the model can target update_* calls.
    An empty *rows* gives ``"No existing {domain}."``; rows of an
    unrecognised domain contribute no lines.
    """
    if not rows:
        return f"No existing {domain}."
    lines: list[str] = []
    for r in rows:
        row_id = r.get("id", "?")
        if domain == "tasks":
            desc = r.get("description") or ""
            desc_part = f" — {desc[:120]}" if desc else ""
            assignee = r.get("assignee") or r.get("assignees") or ""
            due = r.get("dueDate") or r.get("due_date") or ""
            priority = r.get("priority")
            # Only non-empty attributes are listed, id last, so a bare task
            # renders as "(id: …)" rather than "(, id: …)".
            meta = ", ".join(
                part
                for part in (
                    f"priority: {priority}" if priority else "",
                    f"assignee: {assignee}" if assignee else "",
                    f"due: {due}" if due else "",
                    f"id: {row_id}",
                )
                if part
            )
            lines.append(
                f"  - [{r.get('status', '?')}] {r.get('title', '')}{desc_part} ({meta})"
            )
        elif domain == "notes":
            snippet = (r.get("content") or "")[:200].replace("\n", " ")
            snippet_part = f"\n    Preview: {snippet}" if snippet else ""
            lines.append(f"  - {r.get('title', '')} (id: {row_id}){snippet_part}")
        elif domain == "timelines":
            lines.append(f"  - {r.get('title', '')} date={r.get('date', '')} (id: {row_id})")
        elif domain == "projects":
            summary = (r.get("aiSummary") or r.get("ai_summary") or "")[:120]
            summary_part = f" — {summary}" if summary else ""
            lines.append(
                f"  - {r.get('name', '')} [{r.get('status', '')}]{summary_part} (id: {row_id})"
            )
    return f"Existing {domain}:\n" + "\n".join(lines)


# ── Step 1: LLM file classifier ───────────────────────────────────────────


def _extract_json_object(raw: str) -> Any:
    """Decode the outermost ``{...}`` span of an LLM reply.

    Tolerates Markdown fences (any language tag or case) and stray prose
    around the JSON.

    Raises:
        ValueError: when no object is present, or (as json.JSONDecodeError)
            when the span is not valid JSON.
    """
    start = raw.find("{")
    end = raw.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no JSON object found in model response")
    return json.loads(raw[start:end + 1])


async def _classify_file(
    file_path: str,
    file_content: str,
    projects: list[dict],
    config_data_types: list[str],
    langfuse_handler: Any | None = None,
    custom_system_prompt: str | None = None,
) -> tuple[Any, list[str], str | None]:
    """Step 1: ask the LLM which project *file_path* belongs to and which domains it holds.

    Returns ``(project_id, domains, new_project_name)``:

    * ``project_id`` — the id of one of *projects* (with the row's own id
      type), or ``"new"`` when the model picked no listed project;
    * ``domains`` — a de-duplicated, order-preserving subset of
      *config_data_types*, or all of them when the model names none allowed;
    * ``new_project_name`` — only set when ``project_id == "new"``.

    An empty file, or any failure calling the model or parsing its reply,
    yields ``("new", list(config_data_types), None)``. Errors while building
    the prompt (e.g. a malformed *custom_system_prompt*) propagate.
    """
    fallback: tuple[Any, list[str], str | None] = ("new", list(config_data_types), None)
    if not file_content.strip():
        return fallback

    # Match ids as text: the model answers with a JSON string even when rows
    # carry integer ids. Map back to the row's own id value.
    ids_by_text: dict[str, Any] = {
        str(p["id"]): p["id"] for p in projects if p.get("id") is not None
    }

    def _fmt_project(p: dict) -> str:
        summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
        summary_part = f" — {summary[:100]}" if summary else ""
        return (
            f"  - id={p.get('id')} | name={p.get('name', '')} | "
            f"status={p.get('status', '')}{summary_part}"
        )

    projects_list = "\n".join(_fmt_project(p) for p in projects) or "  (none yet)"
    domain_definitions = "\n".join(
        f"  - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
        for d in config_data_types
        if d in _DOMAIN_DESCRIPTIONS
    )

    if custom_system_prompt:
        # Fixture-provided prompt takes absolute priority
        system = custom_system_prompt.format_map(
            {"domain_definitions": domain_definitions, "projects_list": projects_list}
        )
    else:
        system = tracing.compile_prompt(
            "batch_file_classifier",
            fallback=_STEP1_SYSTEM_PROMPT,
            variables={
                "domain_definitions": domain_definitions,
                "projects_list": projects_list,
            },
        )

    llm = get_llm(callbacks=[langfuse_handler] if langfuse_handler else None)
    try:
        response = await llm.ainvoke([
            SystemMessage(content=system),
            HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
        ])
        parsed = _extract_json_object(_as_text(response.content))
        if not isinstance(parsed, dict):
            raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")

        raw_project_id = str(parsed.get("project_id") or "new").strip()
        project_id: Any = ids_by_text.get(raw_project_id, "new")

        new_project_name: str | None = None
        if project_id == "new" and parsed.get("new_project_name"):
            new_project_name = str(parsed["new_project_name"]).strip() or None

        raw_domains = parsed.get("domains") or []
        if isinstance(raw_domains, str):
            raw_domains = [raw_domains]
        elif not isinstance(raw_domains, list):
            raw_domains = []
        domains: list[str] = []
        for d in raw_domains:
            if d in config_data_types and d not in domains:
                domains.append(d)
        if not domains:
            domains = list(config_data_types)
        return project_id, domains, new_project_name
    except Exception as exc:
        logger.warning(
            "agent_runner: step1 classification failed for %r: %s", file_path, exc
        )
        return fallback


# ── Run bookkeeping helpers ───────────────────────────────────────────────


def _final_status(errors: list[str], items_processed: int) -> str:
    """Map an outcome to a run status.

    No errors gives "success"; errors with nothing processed give "error";
    errors after at least one processed item give "partial".
    """
    if not errors:
        return "success"
    return "error" if items_processed == 0 else "partial"


async def _create_run_log(*, agent_id: Any, agent_type: str, user_id: str) -> Any:
    """Insert an AgentRunLog row with status "running" and return its id."""
    async with async_session() as db:
        run_log = AgentRunLog(
            agent_id=agent_id,
            agent_type=agent_type,
            user_id=user_id,
            status="running",
        )
        db.add(run_log)
        await db.commit()
        await db.refresh(run_log)
        return run_log.id


async def _publish_run_complete(
    user_id: str, run_context: dict, status: str, run_log_id: Any
) -> None:
    """Notify Electron that the run is complete via Redis (best effort; failures are logged)."""
    try:
        channel = ws_out_channel(user_id)
        await redis_client.publish(channel, json.dumps({
            "type": "run_complete",
            "run_context": run_context,
            "status": status,
        }))
    except Exception as exc:
        logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_log_id, exc)


def _parse_last_run_at(raw: Any) -> datetime | None:
    """Normalise the trigger payload's ``last_run_at`` into an aware datetime.

    Accepts ISO-8601 strings (``Z`` allowed, naive values taken as UTC),
    epoch milliseconds as int/float, or a datetime. A falsy, unsupported or
    unparseable value returns ``None``, meaning an unfiltered scan. Bad
    values are logged instead of aborting the run.
    """
    if not raw:
        return None
    try:
        if isinstance(raw, datetime):
            return raw if raw.tzinfo is not None else raw.replace(tzinfo=timezone.utc)
        if isinstance(raw, str):
            return _parse_iso_datetime(raw)
        if isinstance(raw, (int, float)):
            return datetime.fromtimestamp(raw / 1000, tz=timezone.utc)
    except (ValueError, OverflowError, OSError) as exc:
        logger.warning("agent_runner: ignoring invalid last_run_at %r: %s", raw, exc)
        return None
    logger.warning("agent_runner: ignoring unsupported last_run_at %r", raw)
    return None


# ── Local agent runner (two-step per file) ────────────────────────────────


async def _resolve_project(
    project_id: Any,
    new_project_name: str | None,
    projects: list[dict],
    run_log_id: Any,
) -> tuple[Any, str]:
    """Turn a classifier verdict into ``(effective_project_id, project_context)``.

    For ``"new"``, a project row is inserted on the client and appended to
    *projects* (mutated in place) so later files can match it. If the insert
    fails, or returns no id, ``"standalone"`` is used. The processing prompt
    already labels the context ``Project:``, so the label is not repeated here.
    """
    suffix = "Always set projectId to this id on every record you create."
    if project_id != "new":
        proj = next((p for p in projects if p.get("id") == project_id), None)
        proj_name = proj.get("name", project_id) if proj else project_id
        return project_id, f"{proj_name} (id: {project_id}). {suffix}"

    proj_name = new_project_name or "Untitled Project"
    try:
        proj_result = await execute_on_client(
            action="insert",
            table="projects",
            data={"name": proj_name, "clientId": None},
        )
        created = (proj_result or {}).get("row") or {}
        effective_project_id = created.get("id", "standalone")
        if "id" in created:
            projects.append(created)
    except Exception as exc:
        logger.warning("agent_runner: run=%s create project failed: %s", run_log_id, exc)
        effective_project_id = "standalone"
        proj_name = "unknown"
    return effective_project_id, f"{proj_name} (id: {effective_project_id}). {suffix}"


async def _process_local_file(
    *,
    file_path: str,
    file_content: str,
    projects: list[dict],
    data_types: list[str],
    custom_section: str,
    run_log_id: Any,
    langfuse_handler: Any | None,
) -> None:
    """Classify one file (step 1), then run the extraction agent on it (step 2).

    Exceptions propagate so the caller can record a per-file error.
    """
    project_id, domains, new_project_name = await _classify_file(
        file_path=file_path,
        file_content=file_content,
        projects=projects,
        config_data_types=data_types,
        langfuse_handler=langfuse_handler,
    )
    effective_project_id, project_context = await _resolve_project(
        project_id, new_project_name, projects, run_log_id
    )

    # The project is settled above; "projects" has no extraction tools.
    domains = [d for d in domains if d != "projects"]
    if not domains:
        logger.info(
            "agent_runner: run=%s file=%r has no extractable domains, skipping",
            run_log_id,
            file_path,
        )
        return

    existing_blocks: list[str] = []
    for domain in domains:
        rows = await _fetch_domain_entities(domain, effective_project_id)
        existing_blocks.append(_format_entities_for_context(domain, rows))

    system_prompt = tracing.compile_prompt(
        "batch_processing",
        fallback=_PROCESSING_SYSTEM_PROMPT,
        variables={
            "existing_context": "\n\n".join(existing_blocks),
            "project_context": project_context,
            "data_types": ", ".join(domains),
            "custom_prompt_section": custom_section,
        },
    )
    result_text = await _run_agent_with_tools(
        system_prompt=system_prompt,
        user_message=(
            f"Process this file and extract relevant information.\n\n"
            f"File: {file_path}\n\nContent:\n{file_content}"
        ),
        tools=_build_processing_tools(domains),
        max_steps=_MAX_PROCESSING_STEPS,
        langfuse_handler=langfuse_handler,
    )
    logger.info(
        "agent_runner: run=%s file=%r result=%s", run_log_id, file_path, result_text[:200]
    )


async def run_local_agent(
    user_id: str,
    trigger_data: dict[str, Any],
    *,
    langfuse_handler: Any | None = None,
) -> None:
    """Execute a local directory agent run.

    In the microservice world, trigger_data is a serialized dict from the REST
    route (forwarded via Redis), containing the agent config fields and
    run_context.

    set_current_user() must be called BEFORE this function.

    The configured directories are scanned through the connected client; each
    non-empty file is classified (step 1) and extracted (step 2). Failures are
    recorded in the run log rather than raised. The agent id is always
    released, the run log always finalised once it exists, and — when a
    run_context was supplied — a ``run_complete`` event is always published,
    including on failure or cancellation (cancellation is re-raised).
    """
    run_context: dict = trigger_data.get("run_context") or {}
    agent_id = run_context.get("agent_id") or str(uuid.uuid4())
    run_log_id: Any = run_context.get("run_id")

    # Extract config from trigger payload
    directory_paths = _as_str_list(trigger_data.get("directory_paths"))
    if not directory_paths:
        directory_paths = _as_str_list(trigger_data.get("directory"))
    data_types = _as_str_list(trigger_data.get("data_types"))
    file_extensions = _as_str_list(trigger_data.get("file_extensions"))
    prompt_template: str = trigger_data.get("prompt_template") or ""
    last_run_at = _parse_last_run_at(trigger_data.get("last_run_at"))
    custom_section = f"User instructions:\n{prompt_template}" if prompt_template else ""

    errors: list[str] = []
    items_processed = 0
    # NOTE(review): records are created by the LLM through tools, so nothing
    # here increments this counter; runs report items_created=0.
    items_created = 0

    # Register only once nothing above can raise; `finally` always releases it.
    _running_agents.add(agent_id)
    try:
        if not run_log_id:
            run_log_id = await _create_run_log(
                agent_id=agent_id, agent_type="local", user_id=user_id
            )

        logger.info("agent_runner: run=%s scanning directories user=%s", run_log_id, user_id)
        file_paths = await _scan_directories(
            paths=directory_paths,
            extensions=file_extensions,
            last_run_at=last_run_at,
        )
        logger.info(
            "agent_runner: run=%s found %d file(s) after filtering", run_log_id, len(file_paths)
        )

        if file_paths:
            projects = await _fetch_projects()  # fetched once, grows as projects are created
            for file_path in file_paths:
                try:
                    file_result = await execute_on_client(
                        action="read_file_content", data={"path": file_path}
                    )
                    file_content: str = (file_result or {}).get("content") or ""
                    if not file_content:
                        continue
                    items_processed += 1
                    await _process_local_file(
                        file_path=file_path,
                        file_content=file_content,
                        projects=projects,
                        data_types=data_types,
                        custom_section=custom_section,
                        run_log_id=run_log_id,
                        langfuse_handler=langfuse_handler,
                    )
                except Exception as exc:
                    errors.append(f"Error processing '{file_path}': {exc}")
                    logger.error(
                        "agent_runner: run=%s file=%r failed: %s", run_log_id, file_path, exc
                    )
    except asyncio.CancelledError:
        errors.append("Agent run cancelled")
        raise
    except Exception as exc:
        errors.append(f"Agent run failed: {exc}")
        logger.error("agent_runner: run=%s failed: %s", run_log_id, exc)
    finally:
        _running_agents.discard(agent_id)
        final_status = _final_status(errors, items_processed)
        if run_log_id:
            await _finalize_run(
                run_log_id,
                status=final_status,
                items_processed=items_processed,
                items_created=items_created,
                errors=errors,
            )
        if run_context:
            await _publish_run_complete(user_id, run_context, final_status, run_log_id)


# ── Cloud agent runner ─────────────────────────────────────────────────────

_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7


async def _fetch_cloud_items(
    provider: Any, provider_name: str, filter_config: Any, since: datetime
) -> list[Any]:
    """Fetch new items from *provider* using the provider-specific method.

    Unknown provider names yield ``[]`` (logged). Provider errors propagate.
    """
    if provider_name in ("gmail", "teams"):
        items = await provider.fetch_messages(filter_config=filter_config, since=since)
    elif provider_name == "outlook":
        items = await provider.fetch_emails(filter_config=filter_config, since=since)
    else:
        logger.warning("agent_runner: unsupported cloud provider %r, nothing fetched", provider_name)
        items = []
    return list(items or [])


async def _persist_refreshed_token(provider: Any, config_id: Any) -> None:
    """Re-encrypt and store credentials the provider refreshed during the run.

    Does nothing unless the provider exposes a truthy ``refreshed_credentials``
    attribute. Failures are logged and swallowed.
    """
    refreshed = getattr(provider, "refreshed_credentials", None)
    if not refreshed:
        return
    try:
        from app.integrations import encrypt_token

        new_encrypted = encrypt_token(refreshed)
        async with async_session() as db:
            cfg_result = await db.execute(
                select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
            )
            cfg_row = cfg_result.scalar_one_or_none()
            if cfg_row:
                cfg_row.oauth_token_encrypted = new_encrypted
                await db.commit()
    except Exception as exc:
        logger.warning("agent_runner: failed to persist refreshed token: %s", exc)


async def run_cloud_agent(
    user_id: str,
    config_id: str,
    *,
    langfuse_handler: Any | None = None,
) -> None:
    """Execute a cloud connector agent run.

    Loads the CloudAgentConfig from DB, decrypts OAuth tokens, fetches
    messages from the provider, and runs LLM extraction.

    set_current_user() must be called BEFORE this function.

    Once the run log exists, it is finalised on every path, including
    unexpected errors and cancellation (which is re-raised). The config's
    last_run_at is only advanced when the fetch succeeded and the run was not
    cancelled. It is set to the moment the fetch started, so items arriving
    during processing are picked up by the next run.
    """
    from app.integrations import decrypt_token, get_provider

    async with async_session() as db:
        result = await db.execute(
            select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
        )
        config = result.scalar_one_or_none()
        if config is None:
            logger.error("agent_runner: cloud config %s not found", config_id)
            return
        # Copy what we need while the session is open so nothing below reads
        # a detached ORM instance. NOTE(review): this matters if the session
        # factory uses expire_on_commit=True.
        cfg_id = config.id
        cfg_name = config.name
        cfg_provider = config.provider
        cfg_token = config.oauth_token_encrypted
        cfg_filter_config = config.filter_config
        cfg_data_types: list[str] = list(config.data_types or [])
        cfg_prompt_template: str = config.prompt_template or ""
        cfg_last_run_at: datetime | None = config.last_run_at

    run_log_id = await _create_run_log(agent_id=cfg_id, agent_type="cloud", user_id=user_id)

    # ── Decrypt OAuth token ────────────────────────────────────────
    if not cfg_token:
        await _finalize_run(
            run_log_id,
            status="error",
            errors=[f"No OAuth token stored for cloud agent '{cfg_name}'"],
        )
        return
    try:
        credentials_info = decrypt_token(cfg_token)
    except Exception as exc:
        await _finalize_run(
            run_log_id,
            status="error",
            errors=[f"Failed to decrypt OAuth token: {exc}"],
        )
        return

    # ── Instantiate provider ──────────────────────────────────────
    try:
        provider = get_provider(cfg_provider, credentials_info)
    except Exception as exc:
        await _finalize_run(run_log_id, status="error", errors=[str(exc)])
        return

    # ── Fetch messages ────────────────────────────────────────────
    since = cfg_last_run_at or (
        datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
    )
    if since.tzinfo is None:
        since = since.replace(tzinfo=timezone.utc)

    fetch_started_at = datetime.now(timezone.utc)
    try:
        raw_messages = await _fetch_cloud_items(provider, cfg_provider, cfg_filter_config, since)
    except Exception as exc:
        # Nothing was processed, so last_run_at is deliberately NOT advanced:
        # doing so would permanently skip every item in this window.
        await _persist_refreshed_token(provider, cfg_id)
        await _finalize_run(
            run_log_id,
            status="error",
            errors=[f"Provider fetch failed: {exc}"],
        )
        return

    logger.info(
        "agent_runner: cloud agent %s fetched %d item(s) from %s",
        cfg_id,
        len(raw_messages),
        cfg_provider,
    )

    # ── Extract + insert via LLM ─────────────────────────────────
    errors: list[str] = []
    items_processed = 0
    cancelled = False
    try:
        # The cloud prompt asks the model to determine the project itself and
        # advertises the project tools, so bind them alongside the data tools.
        processing_tools = _build_processing_tools(cfg_data_types, extra_tools=PROJECT_TOOLS)
        custom_section = (
            f"User instructions:\n{cfg_prompt_template}" if cfg_prompt_template else ""
        )
        data_types_text = ", ".join(cfg_data_types)
        for msg in raw_messages:
            content_text = msg.as_text
            if not content_text:
                continue
            items_processed += 1
            processing_prompt = tracing.compile_prompt(
                "batch_cloud_processing",
                fallback=_CLOUD_PROCESSING_PROMPT,
                variables={
                    "data_types": data_types_text,
                    "project_context": "Determine the appropriate project from the message context.",
                    "file_list": f"Message from {cfg_provider} (id: {msg.id})",
                    "custom_prompt_section": custom_section,
                },
            )
            try:
                await _run_agent_with_tools(
                    system_prompt=processing_prompt,
                    user_message=f"Process this message content:\n\n{content_text[:8000]}",
                    tools=processing_tools,
                    max_steps=_MAX_PROCESSING_STEPS,
                    langfuse_handler=langfuse_handler,
                )
            except Exception as exc:
                errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
                logger.warning(
                    "agent_runner: cloud agent %s message %r failed: %s", cfg_id, msg.id, exc
                )
    except asyncio.CancelledError:
        cancelled = True
        errors.append("Agent run cancelled")
        raise
    except Exception as exc:
        errors.append(f"Agent run failed: {exc}")
        logger.error("agent_runner: cloud agent %s failed: %s", cfg_id, exc)
    finally:
        # ── Persist refreshed token, then finalise ───────────────
        await _persist_refreshed_token(provider, cfg_id)
        await _finalize_run(
            run_log_id,
            status=_final_status(errors, items_processed),
            items_processed=items_processed,
            items_created=0,
            errors=errors,
            update_config_last_run=not cancelled,
            config_id=cfg_id,
            config_type="cloud",
            last_run_at=fetch_started_at,
        )


# ── Internal helper ─────────────────────────────────────────────────────────

# Config models whose last_run_at _finalize_run may stamp, keyed by config_type.
_CONFIG_MODELS: dict[str, Any] = {
    "local": LocalAgentConfig,
    "cloud": CloudAgentConfig,
}


async def _finalize_run(
    run_log_id: int | str,
    *,
    status: str,
    items_processed: int = 0,
    items_created: int = 0,
    errors: list[str] | None = None,
    update_config_last_run: bool = False,
    config_id: str | None = None,
    config_type: str | None = None,
    last_run_at: datetime | None = None,
) -> None:
    """Persist the run outcome and optionally update last_run_at on the config.

    Args:
        run_log_id: primary key of the AgentRunLog row to close.
        status: final status to store ("success" / "partial" / "error").
        items_processed, items_created, errors: counters and messages to store.
        update_config_last_run: when true (and *config_id* is set), also stamp
            the config row selected by *config_type* ("local" or "cloud").
        last_run_at: value for the config's last_run_at; defaults to the
            completion time.

    A missing run log skips the config update. Database errors are logged,
    never raised.
    """
    now = datetime.now(timezone.utc)
    try:
        async with async_session() as db:
            result = await db.execute(
                select(AgentRunLog).where(AgentRunLog.id == run_log_id)
            )
            managed = result.scalar_one_or_none()
            if managed is None:
                logger.warning("agent_runner: run_log %s not found for finalization", run_log_id)
                return
            managed.status = status
            managed.items_processed = items_processed
            managed.items_created = items_created
            managed.errors = list(errors or [])
            managed.completed_at = now

            if update_config_last_run and config_id:
                model = _CONFIG_MODELS.get(config_type or "")
                if model is None:
                    logger.warning(
                        "agent_runner: unknown config_type %r, last_run_at not updated",
                        config_type,
                    )
                else:
                    cfg_result = await db.execute(select(model).where(model.id == config_id))
                    cfg = cfg_result.scalar_one_or_none()
                    if cfg:
                        cfg.last_run_at = last_run_at or now

            await db.commit()
    except Exception as exc:
        logger.error("agent_runner: failed to finalize run_log=%s: %s", run_log_id, exc)