diff --git a/services/batch-agent/Dockerfile b/services/batch-agent/Dockerfile new file mode 100644 index 0000000..1604b12 --- /dev/null +++ b/services/batch-agent/Dockerfile @@ -0,0 +1,36 @@ +# ── builder ────────────────────────────────────────────────────────────────── +FROM python:3.12-slim AS builder + +WORKDIR /build + +COPY services/batch-agent/requirements.txt ./requirements.txt +RUN pip install --upgrade pip && \ + pip install --no-cache-dir --prefix=/install -r requirements.txt + +# ── runtime ────────────────────────────────────────────────────────────────── +FROM python:3.12-slim AS runtime + +RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser + +WORKDIR /app + +COPY --from=builder /install /usr/local + +# Shared module +COPY shared/ shared/ + +# Service source +COPY services/batch-agent/app/ app/ + +RUN chown -R appuser:appgroup /app + +USER appuser + +EXPOSE 8000 + +# Batch runs are long-lived — use a longer timeout than chat (300s vs 120s) +CMD ["gunicorn", "app.main:app", \ + "-k", "uvicorn.workers.UvicornWorker", \ + "--bind", "0.0.0.0:8000", \ + "--workers", "2", \ + "--timeout", "300"] diff --git a/services/batch-agent/app/agent_runner.py b/services/batch-agent/app/agent_runner.py new file mode 100644 index 0000000..c8c40fa --- /dev/null +++ b/services/batch-agent/app/agent_runner.py @@ -0,0 +1,884 @@ +"""Agent run orchestrator — adapted for Batch Agent Service. + +Key changes from monolith app/core/agent_runner.py: + - No DeviceConnectionManager — tool calls go through Redis ws_context. + - set_current_user / clear_current_user replace set_client_executor. + - run_local_agent accepts a serialized dict (from Redis / REST) instead + of SQLAlchemy model objects. + - _finalize_run writes to PostgreSQL via shared.db.async_session. + - Cloud agent import path changed to app.integrations. 
+""" + +from __future__ import annotations + +import asyncio +import json +import logging +import uuid +from datetime import datetime, timedelta, timezone +from typing import Any + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from sqlalchemy import select + +from app.agents.filesystem_agent import FILESYSTEM_TOOLS +from app.agents.note_agent import NOTE_TOOLS +from app.agents.project_agent import PROJECT_TOOLS +from app.agents.task_agent import TASK_TOOLS +from app.agents.timeline_agent import TIMELINE_TOOLS +from app.llm import get_llm +from app.ws_context import execute_on_client, set_current_user, clear_current_user +from shared.db import async_session +from shared.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig +from shared.redis import redis_client, ws_out_channel + +logger = logging.getLogger(__name__) + +# ── Concurrency guard ───────────────────────────────────────────────────── +_running_agents: set[str] = set() + + +def is_agent_running(agent_id: str) -> bool: + return agent_id in _running_agents + + +# ── Timeouts ─────────────────────────────────────────────────────────────── +_TOOL_CALL_TIMEOUT: int = 30 +_MAX_PROCESSING_STEPS: int = 12 +_MAX_SCAN_DEPTH: int = 5 + +# ── Data-type to tool mapping ───────────────────────────────────────────── +_DATA_TYPE_TOOLS: dict[str, list[Any]] = { + "tasks": TASK_TOOLS, + "notes": NOTE_TOOLS, + "timelines": TIMELINE_TOOLS, +} + +# ── Step 1: Classification prompt ───────────────────────────────────────── + +_DOMAIN_DESCRIPTIONS: dict[str, str] = { + "tasks": ( + "Action items, to-dos, deliverables — anything that describes work to be done, " + "assigned to someone, or tracked with a due date or status." + ), + "notes": ( + "Documentation, meeting notes, summaries, reference material — " + "written content meant to be read and referenced rather than acted on." 
+ ), + "timelines": ( + "Project milestones, deadlines, scheduled events — " + "specific dates that mark a point in the progress of a project." + ), + "projects": ( + "High-level project entities — only relevant if the file clearly introduces " + "a new project or updates the scope of an existing one." + ), +} + +_STEP1_SYSTEM_PROMPT = """\ +You are a file classifier for a freelance project management tool. + +Your job is to match a file to an existing project and identify which data domains to extract. + +## Project matching rules (STRICT — follow in order) + +1. Search the file content for any mention of a project name, client name, acronym, or topic + that overlaps with the existing projects listed below. +2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough. +3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort + when the file has zero meaningful connection to any listed project. +4. When in doubt, pick the closest match from the list. + +## Response format + +Respond ONLY with a JSON object — no markdown, no explanation: + +{{"project_id": "", "new_project_name": "", "domains": ["tasks", "notes"]}} + +## Domain definitions (only consider domains in the allowed list) + +{domain_definitions} + +## Existing projects + +{projects_list} +""" + +# ── Step 2: Processing prompt ───────────────────────────────────────────── + +_PROCESSING_SYSTEM_PROMPT = """\ +You are a data extraction assistant for a freelance project management tool. + +Your task: extract structured data from the file content and persist it using the available tools. + +## Mandatory process — follow this order for EVERY item you extract + +1. READ the existing records listed below for the relevant domain. +2. SEARCH for a match by title, topic, or semantic similarity. +3. If a match exists → call the update_* tool with the existing record's id. +4. 
If no match exists → call the create_* tool and set isAiSuggested=1. + +NEVER call create_* without first checking the existing records. +NEVER duplicate a record that already exists under a different wording. + +## Existing records (source of truth) + +{existing_context} + +## Context + +Project: {project_context} +Domains to extract: {data_types} + +{custom_prompt_section} +""" + +# ── Cloud processing prompt ─────────────────────────────────────────────── + +_CLOUD_PROCESSING_PROMPT = """\ +You are a data extraction and management assistant for a freelance project +management tool. + +Available tools: + Filesystem : read_file_content, list_directory, get_file_metadata + Tasks : list_tasks, create_task, update_task, add_task_comment + Notes : list_notes, get_note, create_note, update_note + Timelines : list_timelines, create_timeline, update_timeline + Projects : list_all_projects, get_project, create_project, update_project + +Your task: +1. Read the full content of each file below using read_file_content. +2. For each piece of information found, ALWAYS try to match and update an + existing record before creating a new one. +3. ONLY act on these entity types: {data_types}. +4. Do NOT invent data. Only extract what is clearly present in the files. +5. If a file contains no relevant data for the target entity types, skip it. + +{project_context} + +Files to process: +{file_list} + +{custom_prompt_section} + +After processing all files, respond with a brief summary of what you updated +and what you created. 
+""" + + +# ── LLM tool-calling loop ───────────────────────────────────────────────── + + +def _as_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict): + text = item.get("text") + if isinstance(text, str): + parts.append(text) + return "".join(parts) + return str(content) + + +async def _run_agent_with_tools( + *, + system_prompt: str, + user_message: str, + tools: list[Any], + max_steps: int, +) -> str: + """Run an LLM agent with tool-calling, returning the final text response.""" + llm = get_llm() + llm_with_tools = llm.bind_tools(tools) + messages: list[Any] = [ + SystemMessage(content=system_prompt), + HumanMessage(content=user_message), + ] + + tool_map = {tool_def.name: tool_def for tool_def in tools} + + for _ in range(max_steps): + response: AIMessage = await llm_with_tools.ainvoke(messages) + messages.append(response) + + if not response.tool_calls: + return _as_text(response.content) + + for call in response.tool_calls: + call_id = str(call.get("id", "")) + call_name = str(call.get("name", "")) + call_args = call.get("args", {}) + logger.info( + "agent_runner: tool_call name=%s args=%s", + call_name, + json.dumps(call_args, ensure_ascii=True)[:800], + ) + + tool_fn = tool_map.get(call_name) + if tool_fn is None: + tool_output = f"Unknown tool: {call_name}" + else: + tool_output = await tool_fn.ainvoke(call_args) + + logger.info( + "agent_runner: tool_result name=%s output=%s", + call_name, + str(tool_output)[:200], + ) + messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) + + final = await llm.ainvoke(messages) + return _as_text(final.content) + + +# ── Tool list builder ───────────────────────────────────────────────────── + + +def _build_processing_tools(data_types: list[str]) -> list[Any]: + tools: list[Any] = 
list(FILESYSTEM_TOOLS) + for dt in data_types: + dt_tools = _DATA_TYPE_TOOLS.get(dt) + if dt_tools: + tools.extend(dt_tools) + return tools + + +# ── Code-based directory scanner ───────────────────────────────────────── + + +async def _scan_directories( + paths: list[str], + extensions: list[str], + last_run_at: datetime | None, +) -> list[str]: + all_files: list[str] = [] + ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set() + + async def _walk(path: str, depth: int) -> None: + if depth > _MAX_SCAN_DEPTH: + return + try: + result = await execute_on_client(action="list_directory", data={"path": path}) + except Exception as exc: + logger.warning("agent_runner: list_directory failed %r: %s", path, exc) + return + for entry in result.get("entries", []): + entry_path = entry.get("path", "") + if not entry_path: + continue + if entry.get("type") == "directory": + await _walk(entry_path, depth + 1) + elif entry.get("type") == "file": + if ext_set: + dot_pos = entry_path.rfind(".") + file_ext = entry_path[dot_pos + 1:].lower() if dot_pos != -1 else "" + if file_ext not in ext_set: + continue + all_files.append(entry_path) + + for root in paths: + await _walk(root, depth=0) + + if last_run_at is None: + return all_files + + last_run_ms = int(last_run_at.timestamp() * 1000) + filtered: list[str] = [] + for file_path in all_files: + try: + meta = await execute_on_client(action="get_file_metadata", data={"path": file_path}) + modified_at = meta.get("modifiedAt") + if modified_at is None: + filtered.append(file_path) + continue + if isinstance(modified_at, (int, float)): + mod_ms = int(modified_at) + else: + mod_ms = int(datetime.fromisoformat(str(modified_at)).timestamp() * 1000) + if mod_ms > last_run_ms: + filtered.append(file_path) + except Exception: + filtered.append(file_path) + + return filtered + + +# ── Code-based entity fetchers ──────────────────────────────────────────── + + +async def _fetch_projects() -> list[dict]: + try: + result = 
await execute_on_client(action="select", table="projects") + return result.get("rows", []) + except Exception as exc: + logger.warning("agent_runner: failed to fetch projects: %s", exc) + return [] + + +_DOMAIN_TABLE: dict[str, str] = { + "tasks": "tasks", + "notes": "notes", + "timelines": "timelines", + "projects": "projects", +} + + +async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]: + table = _DOMAIN_TABLE.get(domain) + if not table: + return [] + filters: dict[str, Any] = {} + if project_id != "standalone" and domain != "projects": + filters["projectId"] = project_id + try: + result = await execute_on_client( + action="select", + table=table, + filters=filters if filters else None, + ) + return result.get("rows", []) + except Exception as exc: + logger.warning("agent_runner: failed to fetch %s: %s", domain, exc) + return [] + + +def _format_entities_for_context(domain: str, rows: list[dict]) -> str: + if not rows: + return f"No existing {domain}." + lines: list[str] = [] + for r in rows: + if domain == "tasks": + desc = r.get("description") or "" + desc_part = f" — {desc[:120]}" if desc else "" + assignee = r.get("assignee") or r.get("assignees") or "" + due = r.get("dueDate") or r.get("due_date") or "" + meta = ", ".join(filter(None, [ + f"priority: {r.get('priority', '')}" if r.get("priority") else "", + f"assignee: {assignee}" if assignee else "", + f"due: {due}" if due else "", + ])) + lines.append( + f" - [{r.get('status', '?')}] {r.get('title', '')}{desc_part}" + f" ({meta}, id: {r['id']})" + ) + elif domain == "notes": + snippet = (r.get("content") or "")[:200].replace("\n", " ") + snippet_part = f"\n Preview: {snippet}" if snippet else "" + lines.append( + f" - {r.get('title', '')} (id: {r['id']}){snippet_part}" + ) + elif domain == "timelines": + lines.append( + f" - {r.get('title', '')} date={r.get('date', '')} (id: {r['id']})" + ) + elif domain == "projects": + summary = (r.get("aiSummary") or r.get("ai_summary") or 
"")[:120] + summary_part = f" — {summary}" if summary else "" + lines.append( + f" - {r.get('name', '')} [{r.get('status', '')}]{summary_part}" + f" (id: {r['id']})" + ) + return f"Existing {domain}:\n" + "\n".join(lines) + + +# ── Step 1: LLM file classifier ─────────────────────────────────────────── + + +async def _classify_file( + file_path: str, + file_content: str, + projects: list[dict], + config_data_types: list[str], +) -> tuple[str, list[str], str | None]: + fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None) + + if not file_content.strip(): + return fallback + + valid_project_ids = {p["id"] for p in projects} + + def _fmt_project(p: dict) -> str: + summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip() + summary_part = f" — {summary[:100]}" if summary else "" + return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}" + + projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)" + + domain_definitions = "\n".join( + f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}" + for d in config_data_types + if d in _DOMAIN_DESCRIPTIONS + ) + + system = _STEP1_SYSTEM_PROMPT.format( + domain_definitions=domain_definitions, + projects_list=projects_list, + ) + + llm = get_llm() + try: + response = await llm.ainvoke([ + SystemMessage(content=system), + HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"), + ]) + raw = _as_text(response.content).strip() + if raw.startswith("```"): + raw = raw.split("```")[1] + if raw.startswith("json"): + raw = raw[4:] + parsed = json.loads(raw.strip()) + raw_project_id: str = str(parsed.get("project_id") or "new") + project_id = raw_project_id if raw_project_id in valid_project_ids else "new" + new_project_name: str | None = ( + str(parsed["new_project_name"]).strip() or None + if project_id == "new" and parsed.get("new_project_name") + else None + ) + domains: list[str] = [ + d for d in parsed.get("domains", []) + if d 
in config_data_types + ] + if not domains: + domains = list(config_data_types) + return project_id, domains, new_project_name + except Exception as exc: + logger.warning( + "agent_runner: step1 classification failed for %r: %s", file_path, exc + ) + return fallback + + +# ── Local agent runner (two-step per file) ──────────────────────────────── + + +async def run_local_agent(user_id: str, trigger_data: dict[str, Any]) -> None: + """Execute a local directory agent run. + + In the microservice world, trigger_data is a serialized dict from + the REST route (forwarded via Redis), containing the agent config + fields and run_context. + + set_current_user() must be called BEFORE this function. + """ + run_context: dict = trigger_data.get("run_context", {}) + agent_id = run_context.get("agent_id", str(uuid.uuid4())) + run_id = run_context.get("run_id") + + _running_agents.add(agent_id) + + # Extract config from trigger payload + directory_paths: list[str] = trigger_data.get("directory_paths", []) + if not directory_paths: + directory = trigger_data.get("directory", "") + if directory: + directory_paths = [directory] + + data_types: list[str] = trigger_data.get("data_types", []) + file_extensions: list[str] = trigger_data.get("file_extensions", []) + prompt_template: str = trigger_data.get("prompt_template", "") + last_run_at_raw = trigger_data.get("last_run_at") + last_run_at: datetime | None = None + if last_run_at_raw: + if isinstance(last_run_at_raw, str): + last_run_at = datetime.fromisoformat(last_run_at_raw) + elif isinstance(last_run_at_raw, (int, float)): + last_run_at = datetime.fromtimestamp(last_run_at_raw / 1000, tz=timezone.utc) + + errors: list[str] = [] + items_processed = 0 + items_created = 0 + + custom_section = ( + f"User instructions:\n{prompt_template}" + if prompt_template + else "" + ) + + # Create or load run log + run_log_id = run_id + if not run_log_id: + async with async_session() as db: + run_log = AgentRunLog( + agent_id=agent_id, + 
agent_type="local", + user_id=user_id, + status="running", + ) + db.add(run_log) + await db.commit() + await db.refresh(run_log) + run_log_id = run_log.id + + try: + # ── Scan directories ───────────────────────────────────────── + logger.info("agent_runner: run=%s scanning directories user=%s", run_log_id, user_id) + file_paths = await _scan_directories( + paths=directory_paths, + extensions=file_extensions, + last_run_at=last_run_at, + ) + logger.info( + "agent_runner: run=%s found %d file(s) after filtering", run_log_id, len(file_paths) + ) + + if not file_paths: + await _finalize_run(run_log_id, status="success", items_processed=0, items_created=0) + return + + # ── Fetch all projects once ────────────────────────────────── + projects = await _fetch_projects() + + for file_path in file_paths: + try: + file_result = await execute_on_client( + action="read_file_content", data={"path": file_path} + ) + file_content: str = file_result.get("content", "") + if not file_content: + continue + + items_processed += 1 + + # Step 1 — classify file + project_id, domains, new_project_name = await _classify_file( + file_path=file_path, + file_content=file_content, + projects=projects, + config_data_types=data_types, + ) + + # Step 2 — resolve project_id, fetch entities, process + if project_id == "new": + proj_name = new_project_name or "Untitled Project" + try: + proj_result = await execute_on_client( + action="insert", + table="projects", + data={"name": proj_name, "clientId": None}, + ) + created = proj_result.get("row", {}) + effective_project_id = created.get("id", "standalone") + if "id" in created: + projects.append(created) + except Exception as exc: + logger.warning("agent_runner: run=%s create project failed: %s", run_log_id, exc) + effective_project_id = "standalone" + proj_name = "unknown" + project_context = ( + f"Project: {proj_name} (id: {effective_project_id}). " + "Always set projectId to this id on every record you create." 
+ ) + else: + effective_project_id = project_id + proj = next((p for p in projects if p["id"] == project_id), None) + proj_name = proj.get("name", project_id) if proj else project_id + project_context = ( + f"Project: {proj_name} (id: {project_id}). " + "Always set projectId to this id on every record you create." + ) + + domains = [d for d in domains if d != "projects"] + + existing_blocks: list[str] = [] + for domain in domains: + rows = await _fetch_domain_entities(domain, effective_project_id) + existing_blocks.append(_format_entities_for_context(domain, rows)) + + existing_context = "\n\n".join(existing_blocks) + + system_prompt = _PROCESSING_SYSTEM_PROMPT.format( + existing_context=existing_context, + project_context=project_context, + data_types=", ".join(domains), + custom_prompt_section=custom_section, + ) + + processing_tools = _build_processing_tools(domains) + + result_text = await _run_agent_with_tools( + system_prompt=system_prompt, + user_message=( + f"Process this file and extract relevant information.\n\n" + f"File: {file_path}\n\nContent:\n{file_content}" + ), + tools=processing_tools, + max_steps=_MAX_PROCESSING_STEPS, + ) + logger.info( + "agent_runner: run=%s file=%r result=%s", + run_log_id, file_path, result_text[:200], + ) + + except Exception as exc: + errors.append(f"Error processing '{file_path}': {exc}") + logger.error("agent_runner: run=%s file=%r failed: %s", run_log_id, file_path, exc) + + except Exception as exc: + errors.append(f"Agent run failed: {exc}") + logger.error("agent_runner: run=%s failed: %s", run_log_id, exc) + finally: + _running_agents.discard(agent_id) + + # ── Finalise ──────────────────────────────────────────────────── + if errors and items_processed == 0: + final_status = "error" + elif errors: + final_status = "partial" + else: + final_status = "success" + + await _finalize_run( + run_log_id, + status=final_status, + items_processed=items_processed, + items_created=items_created, + errors=errors, + ) + + # Notify 
Electron that the run is complete via Redis + if run_context: + try: + channel = ws_out_channel(user_id) + await redis_client.publish(channel, json.dumps({ + "type": "run_complete", + "run_context": run_context, + "status": final_status, + })) + except Exception as exc: + logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_log_id, exc) + + +# ── Cloud agent runner ───────────────────────────────────────────────────── + +_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7 + + +async def run_cloud_agent(user_id: str, config_id: str) -> None: + """Execute a cloud connector agent run. + + Loads the CloudAgentConfig from DB, decrypts OAuth tokens, fetches + messages from the provider, and runs LLM extraction. + + set_current_user() must be called BEFORE this function. + """ + from app.integrations import decrypt_token, encrypt_token, get_provider + + async with async_session() as db: + result = await db.execute( + select(CloudAgentConfig).where(CloudAgentConfig.id == config_id) + ) + config = result.scalar_one_or_none() + if config is None: + logger.error("agent_runner: cloud config %s not found", config_id) + return + + # Create run log + run_log = AgentRunLog( + agent_id=config.id, + agent_type="cloud", + user_id=user_id, + status="running", + ) + db.add(run_log) + await db.commit() + await db.refresh(run_log) + run_log_id = run_log.id + + # ── Decrypt OAuth token ──────────────────────────────────────── + if not config.oauth_token_encrypted: + await _finalize_run( + run_log_id, + status="error", + errors=[f"No OAuth token stored for cloud agent '{config.name}'"], + ) + return + + try: + credentials_info = decrypt_token(config.oauth_token_encrypted) + except ValueError as exc: + await _finalize_run( + run_log_id, + status="error", + errors=[f"Failed to decrypt OAuth token: {exc}"], + ) + return + + # ── Instantiate provider ────────────────────────────────────── + try: + provider = get_provider(config.provider, credentials_info) + except ValueError as exc: + 
await _finalize_run(run_log_id, status="error", errors=[str(exc)]) + return + + # ── Fetch messages ──────────────────────────────────────────── + since: datetime | None = config.last_run_at + if since is None: + since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS) + if since.tzinfo is None: + since = since.replace(tzinfo=timezone.utc) + + errors: list[str] = [] + items_processed = 0 + + try: + if config.provider == "gmail": + raw_messages = await provider.fetch_messages( + filter_config=config.filter_config, + since=since, + ) + elif config.provider == "outlook": + raw_messages = await provider.fetch_emails( + filter_config=config.filter_config, + since=since, + ) + elif config.provider == "teams": + raw_messages = await provider.fetch_messages( + filter_config=config.filter_config, + since=since, + ) + else: + raw_messages = [] + except RuntimeError as exc: + await _finalize_run( + run_log_id, + status="error", + errors=[f"Provider fetch failed: {exc}"], + update_config_last_run=True, + config_id=config.id, + config_type="cloud", + ) + return + + logger.info( + "agent_runner: cloud agent %s fetched %d item(s) from %s", + config.id, len(raw_messages), config.provider, + ) + + # ── Extract + insert via LLM ───────────────────────────────── + try: + processing_tools = _build_processing_tools(config.data_types) + custom_section = ( + f"User instructions:\n{config.prompt_template}" + if config.prompt_template + else "" + ) + + for msg in raw_messages: + content_text = msg.as_text + if not content_text: + continue + items_processed += 1 + + processing_prompt = _CLOUD_PROCESSING_PROMPT.format( + data_types=", ".join(config.data_types), + project_context="Determine the appropriate project from the message context.", + file_list=f"Message from {config.provider} (id: {msg.id})", + custom_prompt_section=custom_section, + ) + + try: + await _run_agent_with_tools( + system_prompt=processing_prompt, + user_message=f"Process this message 
content:\n\n{content_text[:8000]}", + tools=processing_tools, + max_steps=_MAX_PROCESSING_STEPS, + ) + except Exception as exc: + errors.append(f"LLM processing error for message {msg.id!r}: {exc}") + except Exception as exc: + errors.append(f"Agent run failed: {exc}") + + # ── Persist refreshed token ─────────────────────────────────── + refreshed = getattr(provider, "refreshed_credentials", None) + if refreshed: + try: + new_encrypted = encrypt_token(refreshed) + async with async_session() as db: + cfg_result = await db.execute( + select(CloudAgentConfig).where(CloudAgentConfig.id == config.id) + ) + cfg_row = cfg_result.scalar_one_or_none() + if cfg_row: + cfg_row.oauth_token_encrypted = new_encrypted + await db.commit() + except Exception as exc: + logger.warning("agent_runner: failed to persist refreshed token: %s", exc) + + # ── Finalise ────────────────────────────────────────────────── + if errors and items_processed == 0: + final_status = "error" + elif errors: + final_status = "partial" + else: + final_status = "success" + + await _finalize_run( + run_log_id, + status=final_status, + items_processed=items_processed, + items_created=0, + errors=errors, + update_config_last_run=True, + config_id=config.id, + config_type="cloud", + ) + + +# ── Internal helper ───────────────────────────────────────────────────────── + + +async def _finalize_run( + run_log_id: int | str, + *, + status: str, + items_processed: int = 0, + items_created: int = 0, + errors: list[str] | None = None, + update_config_last_run: bool = False, + config_id: str | None = None, + config_type: str | None = None, +) -> None: + """Persist the run outcome and optionally update last_run_at on the config.""" + now = datetime.now(timezone.utc) + try: + async with async_session() as db: + result = await db.execute( + select(AgentRunLog).where(AgentRunLog.id == run_log_id) + ) + managed = result.scalar_one_or_none() + if managed is None: + logger.warning("agent_runner: run_log %s not found for 
finalization", run_log_id) + return + + managed.status = status + managed.items_processed = items_processed + managed.items_created = items_created + managed.errors = errors or [] + managed.completed_at = now + + if update_config_last_run and config_id: + if config_type == "local": + cfg_result = await db.execute( + select(LocalAgentConfig).where(LocalAgentConfig.id == config_id) + ) + cfg = cfg_result.scalar_one_or_none() + if cfg: + cfg.last_run_at = now + elif config_type == "cloud": + cfg_result = await db.execute( + select(CloudAgentConfig).where(CloudAgentConfig.id == config_id) + ) + cfg = cfg_result.scalar_one_or_none() + if cfg: + cfg.last_run_at = now + + await db.commit() + except Exception as exc: + logger.error("agent_runner: failed to finalize run_log=%s: %s", run_log_id, exc) diff --git a/services/batch-agent/app/agents/__init__.py b/services/batch-agent/app/agents/__init__.py new file mode 100644 index 0000000..50e7414 --- /dev/null +++ b/services/batch-agent/app/agents/__init__.py @@ -0,0 +1 @@ +"""Batch Agent Service domain agents and filesystem tools.""" diff --git a/services/batch-agent/app/agents/filesystem_agent.py b/services/batch-agent/app/agents/filesystem_agent.py new file mode 100644 index 0000000..921caaa --- /dev/null +++ b/services/batch-agent/app/agents/filesystem_agent.py @@ -0,0 +1,83 @@ +"""Filesystem agent — tools for reading local directories and files on Electron. + +Adapted for Batch Agent Service: import from app.ws_context. +""" + +from __future__ import annotations + +from typing import Any + +from langchain_core.tools import tool + +from app.ws_context import execute_on_client + + +@tool +async def list_directory(path: str) -> str: + """List files and folders in a local directory on the user's device. + + Returns a formatted listing of entries with name, type (file/directory), + and full path. 
+ """ + result = await execute_on_client( + action="list_directory", + data={"path": path}, + ) + entries: list[dict[str, Any]] = result.get("entries", []) + if not entries: + return f"Directory '{path}' is empty or does not exist." + lines: list[str] = [] + for entry in entries: + entry_type = entry.get("type", "unknown") + entry_name = entry.get("name", "") + entry_path = entry.get("path", "") + lines.append(f"- [{entry_type}] {entry_name} ({entry_path})") + return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines) + + +@tool +async def read_file_content(path: str) -> str: + """Read the text content of a local file on the user's device. + + Returns the file content as a string. Large files may be truncated + by the Electron client. + """ + result = await execute_on_client( + action="read_file_content", + data={"path": path}, + ) + content: str = result.get("content", "") + if not content: + return f"File '{path}' is empty or could not be read." + return content + + +@tool +async def get_file_metadata(path: str) -> str: + """Get metadata for a local file: size, creation date, modification date, extension. + + Returns a formatted summary of the file's metadata. 
+ """ + result = await execute_on_client( + action="get_file_metadata", + data={"path": path}, + ) + size = result.get("size", "unknown") + created = result.get("createdAt", "unknown") + modified = result.get("modifiedAt", "unknown") + extension = result.get("extension", "unknown") + name = result.get("name", path) + return ( + f"File: {name}\n" + f" Extension: {extension}\n" + f" Size: {size} bytes\n" + f" Created: {created}\n" + f" Modified: {modified}" + ) + + +FILESYSTEM_TOOLS: list[Any] = [ + list_directory, + read_file_content, + get_file_metadata, +] diff --git a/services/batch-agent/app/agents/note_agent.py b/services/batch-agent/app/agents/note_agent.py new file mode 100644 index 0000000..7b48046 --- /dev/null +++ b/services/batch-agent/app/agents/note_agent.py @@ -0,0 +1,110 @@ +"""Note agent — Markdown note management. + +Adapted for Batch Agent Service: import from app.ws_context and app.llm. +""" + +from __future__ import annotations + +import re +from typing import Any + +from langchain_core.tools import tool + +from app.llm import embed +from app.ws_context import execute_on_client + +_UUID_RE = re.compile( + r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$" +) + + +def _is_uuid(value: str) -> bool: + return bool(_UUID_RE.match(value)) + + +@tool +async def list_notes(project_id: str = "") -> str: + """List notes, optionally scoped to a project by project_id.""" + normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else "" + result = await execute_on_client( + action="select", + table="notes", + filters={"projectId": normalized_project_id or None}, + ) + rows = result.get("rows", []) + if not rows: + return "No notes found." 
+ lines = [f"- {r['title']} (id: {r['id']})" for r in rows] + return f"Found {len(rows)} note(s):\n" + "\n".join(lines) + + +@tool +async def get_note(note_id: str) -> str: + """Fetch a single note by its UUID to read its full Markdown content.""" + result = await execute_on_client(action="get", table="notes", data={"id": note_id}) + row = result.get("row") + if not row: + return f"Note {note_id} not found." + return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}" + + +@tool +async def create_note(title: str, content: str, project_id: str = "") -> str: + """Create a new note.""" + result = await execute_on_client( + action="insert", + table="notes", + data={ + "title": title, + "content": content, + "projectId": project_id or None, + }, + ) + row = result["row"] + vector = await embed(content) + await execute_on_client( + action="vector_upsert", + data={"id": row["id"], "projectId": row.get("projectId"), "content": content}, + vector=vector, + ) + return f"Note created: '{row['title']}' (id: {row['id']})." + + +@tool +async def update_note(note_id: str, title: str = "", content: str = "") -> str: + """Update an existing note. Only pass fields that should change.""" + updates: dict[str, Any] = {} + if title: + updates["title"] = title + if content: + updates["content"] = content + result = await execute_on_client( + action="update", + table="notes", + data={"id": note_id, "updates": updates}, + ) + row = result["row"] + if content: + vector = await embed(content) + await execute_on_client( + action="vector_upsert", + data={"id": note_id, "projectId": row.get("projectId"), "content": content}, + vector=vector, + ) + return f"Note updated: '{row['title']}' (id: {row['id']})." + + +@tool +async def delete_note(note_id: str) -> str: + """Delete a note permanently by its UUID.""" + await execute_on_client(action="delete", table="notes", data={"id": note_id}) + return f"Note {note_id} deleted." 
@tool
async def list_projects(client_id: str = "", include_archived: int = 0) -> str:
    """List projects, optionally filtered by client_id."""
    # An empty client_id is sent as None; include_archived arrives as an int
    # flag and is forwarded as a bool.
    query_filters = {
        "clientId": client_id if client_id else None,
        "includeArchived": bool(include_archived),
    }
    response = await execute_on_client(
        action="select",
        table="projects",
        filters=query_filters,
    )
    projects = response.get("rows", [])
    if projects:
        bullets = [
            f"- {p['name']} (status: {p['status']}, id: {p['id']})" for p in projects
        ]
        return f"Found {len(projects)} project(s):\n" + "\n".join(bullets)
    return "No projects found."
@tool
async def update_project(
    project_id: str,
    name: str = "",
    client_id: str = "",
    status: str = "",
    ai_summary: str = "",
) -> str:
    """Update a project. Only pass fields that should change."""
    # Map each optional argument to its camelCase column; an empty string
    # means "not supplied" and is left out of the update payload.
    candidates = {
        "name": name,
        "clientId": client_id,
        "status": status,
        "aiSummary": ai_summary,
    }
    updates: dict[str, Any] = {
        column: value for column, value in candidates.items() if value
    }
    response = await execute_on_client(
        action="update",
        table="projects",
        data={"id": project_id, "updates": updates},
    )
    updated = response["row"]
    return f"Project updated: '{updated['name']}' (id: {updated['id']}, status: {updated['status']})"
+""" + +from __future__ import annotations + +from datetime import datetime, timezone +import re +from typing import Any + +from langchain_core.tools import tool + +from app.ws_context import execute_on_client + +_UUID_RE = re.compile( + r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$" +) + + +def _is_uuid(value: str) -> bool: + return bool(_UUID_RE.match(value)) + + +@tool +async def list_tasks( + project_id: str = "", + status: str = "", + search: str = "", + order_by: str = "", +) -> str: + """List tasks, optionally filtered by project_id, status, search, or order_by.""" + normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else "" + result = await execute_on_client( + action="select", + table="tasks", + filters={ + "projectId": normalized_project_id or None, + "status": status or None, + "search": search or None, + "orderBy": order_by or None, + }, + ) + rows = result.get("rows", []) + if not rows: + return "No tasks found matching the given filters." 
@tool
async def create_task(
    title: str,
    description: str = "",
    status: str = "todo",
    priority: str = "medium",
    assignees: str = "[]",
    due_date: int = 0,
    project_id: str = "",
    is_ai_suggested: int = 0,
) -> str:
    """Create a new task."""
    # Falsy optional values (empty string, 0) are stored as None. The
    # assignees string is forwarded unchanged under the "assignee" key.
    payload: dict[str, Any] = {
        "title": title,
        "description": description if description else None,
        "status": status,
        "priority": priority,
        "assignee": assignees,
        "dueDate": due_date if due_date else None,
        "projectId": project_id if project_id else None,
        "isAiSuggested": is_ai_suggested,
    }
    response = await execute_on_client(action="insert", table="tasks", data=payload)
    created = response["row"]
    headline = f"Task created: '{created['title']}' "
    details = f"(id: {created['id']}, status: {created['status']}, priority: {created['priority']})"
    return headline + details
@tool
async def list_tasks_due_today() -> str:
    """List all tasks whose due date falls on today's date."""
    # "Today" is the current UTC calendar day, expressed as an inclusive
    # millisecond window from 00:00:00.000 to 23:59:59.999.
    midnight_utc = datetime.now(tz=timezone.utc).replace(
        hour=0, minute=0, second=0, microsecond=0
    )
    window_start = int(midnight_utc.timestamp() * 1000)
    window_end = window_start + 86_400_000 - 1
    response = await execute_on_client(
        action="select",
        table="tasks",
        filters={"dueDateFrom": window_start, "dueDateTo": window_end},
    )
    due = response.get("rows", [])
    if not due:
        return "No tasks are due today."
    bullets = "\n".join(
        f"- {t['title']} (priority: {t['priority']}, status: {t['status']}, id: {t['id']})"
        for t in due
    )
    return f"Tasks due today ({len(due)}):\n" + bullets
+ lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows] + return f"Found {len(rows)} comment(s):\n" + "\n".join(lines) + + +@tool +async def add_task_comment(task_id: str, author: str, content: str) -> str: + """Add a comment to a task.""" + result = await execute_on_client( + action="insert", + table="taskComments", + data={"taskId": task_id, "author": author, "content": content}, + ) + row = result.get("row", {}) + row_author = row.get("author", author) + row_task_id = row.get("taskId") or row.get("task_id") or task_id + row_comment_id = row.get("id", "unknown") + return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})." + + +@tool +async def delete_task_comment(comment_id: str) -> str: + """Delete a task comment by its UUID.""" + await execute_on_client(action="delete", table="taskComments", data={"id": comment_id}) + return f"Comment {comment_id} deleted." + + +TASK_TOOLS: list[Any] = [ + list_tasks, + create_task, + update_task, + delete_task, + list_tasks_due_today, + list_task_comments, + add_task_comment, + delete_task_comment, +] diff --git a/services/batch-agent/app/agents/timeline_agent.py b/services/batch-agent/app/agents/timeline_agent.py new file mode 100644 index 0000000..1e54582 --- /dev/null +++ b/services/batch-agent/app/agents/timeline_agent.py @@ -0,0 +1,88 @@ +"""Timeline agent — project milestone management. + +Adapted for Batch Agent Service: import from app.ws_context. +""" + +from __future__ import annotations + +import re +from typing import Any + +from langchain_core.tools import tool + +from app.ws_context import execute_on_client + +_UUID_RE = re.compile( + r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$" +) + + +def _is_uuid(value: str) -> bool: + return bool(_UUID_RE.match(value)) + + +@tool +async def list_timelines(project_id: str = "") -> str: + """List timelines. 
@tool
async def update_timeline(timeline_id: str, title: str = "", date: int = -1) -> str:
    """Update a timeline. Only pass fields that should change."""
    # An empty title and the sentinel date -1 both mean "leave unchanged".
    changes: dict[str, Any] = {"title": title} if title else {}
    if date != -1:
        changes["date"] = date
    response = await execute_on_client(
        action="update",
        table="timelines",
        data={"id": timeline_id, "updates": changes},
    )
    updated = response["row"]
    return f"Timeline updated: '{updated['title']}' (id: {updated['id']})"
+ + +TIMELINE_TOOLS: list[Any] = [ + list_timelines, + create_timeline, + update_timeline, + delete_timeline, +] diff --git a/services/batch-agent/app/integrations/__init__.py b/services/batch-agent/app/integrations/__init__.py new file mode 100644 index 0000000..0fb83db --- /dev/null +++ b/services/batch-agent/app/integrations/__init__.py @@ -0,0 +1,108 @@ +"""Cloud provider integration utilities. + +Adapted for Batch Agent Service: import from shared.config instead of app.config. + +Provides: + * Shared message dataclasses (EmailMessage, ChatMessage) + * get_provider() — factory for Gmail/MS Graph clients + * encrypt_token() / decrypt_token() — Fernet-based OAuth token encryption +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass, field +from datetime import datetime +from typing import TYPE_CHECKING + +from cryptography.fernet import Fernet, InvalidToken + +from shared.config import settings + +if TYPE_CHECKING: + from app.integrations.gmail import GmailClient + from app.integrations.ms_graph import MSGraphClient + +logger = logging.getLogger(__name__) + + +@dataclass +class EmailMessage: + id: str + subject: str + sender: str + body_text: str + date: datetime + labels: list[str] = field(default_factory=list) + + @property + def as_text(self) -> str: + date_str = self.date.strftime("%Y-%m-%d %H:%M") + labels_str = f" [{', '.join(self.labels)}]" if self.labels else "" + return ( + f"From: {self.sender}\n" + f"Date: {date_str}{labels_str}\n" + f"Subject: {self.subject}\n\n" + f"{self.body_text}" + ) + + +@dataclass +class ChatMessage: + id: str + content: str + sender: str + channel: str | None + date: datetime + + @property + def as_text(self) -> str: + date_str = self.date.strftime("%Y-%m-%d %H:%M") + channel_str = f" [channel: {self.channel}]" if self.channel else "" + return ( + f"From: {self.sender}\n" + f"Date: {date_str}{channel_str}\n\n" + f"{self.content}" + ) + + +def _get_fernet() -> Fernet: + 
def get_provider(
    provider: str,
    credentials_info: dict,
) -> "GmailClient | MSGraphClient":
    """Return a client instance for the named cloud provider.

    ``"gmail"`` maps to ``GmailClient``; ``"outlook"`` and ``"teams"`` both
    map to ``MSGraphClient``. The client modules are imported only when the
    matching provider is requested.

    Raises:
        ValueError: if *provider* is not one of the supported names.
    """
    if provider == "gmail":
        from app.integrations.gmail import GmailClient as client_cls
    elif provider in {"outlook", "teams"}:
        from app.integrations.ms_graph import MSGraphClient as client_cls
    else:
        raise ValueError(
            f"Unknown cloud provider {provider!r}. "
            "Supported: 'gmail', 'outlook', 'teams'."
        )
    return client_cls(credentials_info)
+""" + +from __future__ import annotations + +import asyncio +import base64 +import email +import html +import logging +import re +from datetime import datetime, timezone +from typing import Any + +from app.integrations import EmailMessage + +logger = logging.getLogger(__name__) + +_GMAIL_DATE_FMT = "%Y/%m/%d" +_BODY_TRUNCATE = 8_000 +_MAX_MESSAGES = 200 + + +def _build_gmail_query( + filter_config: dict[str, Any] | None, + since: datetime | None, +) -> str: + parts: list[str] = [] + cfg = filter_config or {} + + labels: list[str] = cfg.get("labels", []) + if labels: + if len(labels) == 1: + parts.append(f"label:{labels[0]}") + else: + label_expr = " OR ".join(f"label:{lbl}" for lbl in labels) + parts.append(f"({label_expr})") + + senders: list[str] = cfg.get("senders", []) + for sender in senders: + parts.append(f"from:{sender}") + + date_range: dict = cfg.get("date_range", {}) + from_str: str | None = date_range.get("from") + to_str: str | None = date_range.get("to") + + effective_since: datetime | None = since + if from_str: + try: + cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00")) + if cfg_since.tzinfo is None: + cfg_since = cfg_since.replace(tzinfo=timezone.utc) + if effective_since is None or cfg_since > effective_since: + effective_since = cfg_since + except ValueError: + logger.warning("gmail: invalid date_range.from %r — ignoring", from_str) + + if effective_since: + parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}") + + if to_str: + try: + to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00")) + parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}") + except ValueError: + logger.warning("gmail: invalid date_range.to %r — ignoring", to_str) + + return " ".join(parts) + + +def _strip_html(raw_html: str) -> str: + no_tags = re.sub(r"<[^>]+>", " ", raw_html) + decoded = html.unescape(no_tags) + return re.sub(r"\s+", " ", decoded).strip() + + +def _parse_body(payload: dict[str, Any]) -> str: + mime_type: str = 
payload.get("mimeType", "") + body: dict = payload.get("body", {}) + parts: list[dict] = payload.get("parts", []) + + if mime_type == "text/plain": + data = body.get("data", "") + if data: + return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace") + return "" + + if mime_type == "text/html": + data = body.get("data", "") + if data: + raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace") + return _strip_html(raw) + return "" + + plain_fallback = "" + for part in parts: + part_mime = part.get("mimeType", "") + if part_mime == "text/plain": + return _parse_body(part) + if part_mime == "text/html" and not plain_fallback: + plain_fallback = _parse_body(part) + if part_mime.startswith("multipart/"): + nested = _parse_body(part) + if nested: + return nested + return plain_fallback + + +def _parse_date(raw: str) -> datetime: + try: + parsed = email.utils.parsedate_to_datetime(raw) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return parsed.astimezone(timezone.utc) + except Exception: + return datetime.now(timezone.utc) + + +class GmailClient: + def __init__(self, credentials_info: dict[str, Any]) -> None: + from google.oauth2.credentials import Credentials + + self._credentials_info = credentials_info + expiry_str: str | None = credentials_info.get("expiry") + expiry: datetime | None = None + if expiry_str: + try: + expiry = datetime.fromisoformat( + expiry_str.replace("Z", "+00:00") + ).replace(tzinfo=timezone.utc) + except ValueError: + pass + + self._credentials = Credentials( + token=credentials_info.get("token"), + refresh_token=credentials_info.get("refresh_token"), + token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"), + client_id=credentials_info.get("client_id"), + client_secret=credentials_info.get("client_secret"), + scopes=credentials_info.get("scopes"), + expiry=expiry, + ) + + async def fetch_messages( + self, + filter_config: dict[str, Any] | None = 
None, + since: datetime | None = None, + ) -> list[EmailMessage]: + query = _build_gmail_query(filter_config, since) + logger.debug("gmail: executing search query %r", query) + return await asyncio.to_thread(self._fetch_sync, query) + + @property + def refreshed_credentials(self) -> dict[str, Any] | None: + creds = self._credentials + if not creds.valid and creds.expired: + return None + if creds.token != self._credentials_info.get("token"): + result = { + "token": creds.token, + "refresh_token": creds.refresh_token, + "token_uri": creds.token_uri, + "client_id": creds.client_id, + "client_secret": creds.client_secret, + "scopes": list(creds.scopes or []), + } + if creds.expiry: + result["expiry"] = creds.expiry.isoformat() + return result + return None + + def _fetch_sync(self, query: str) -> list[EmailMessage]: + import googleapiclient.discovery + import googleapiclient.errors + from google.auth.transport.requests import Request + + if self._credentials.expired and self._credentials.refresh_token: + try: + self._credentials.refresh(Request()) + except Exception as exc: + raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc + + service = googleapiclient.discovery.build( + "gmail", "v1", credentials=self._credentials, cache_discovery=False + ) + user_api = service.users() + + ids: list[str] = [] + page_token: str | None = None + while len(ids) < _MAX_MESSAGES: + batch_size = min(100, _MAX_MESSAGES - len(ids)) + kwargs: dict[str, Any] = { + "userId": "me", + "maxResults": batch_size, + } + if query: + kwargs["q"] = query + if page_token: + kwargs["pageToken"] = page_token + + try: + resp = user_api.messages().list(**kwargs).execute() + except googleapiclient.errors.HttpError as exc: + raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc + + for msg in resp.get("messages", []): + ids.append(msg["id"]) + + page_token = resp.get("nextPageToken") + if not page_token: + break + + if not ids: + return [] + + logger.info("gmail: fetching %d 
message(s)", len(ids)) + + messages: list[EmailMessage] = [] + for msg_id in ids: + try: + msg = user_api.messages().get( + userId="me", id=msg_id, format="full" + ).execute() + + headers: dict[str, str] = { + h["name"].lower(): h["value"] + for h in msg.get("payload", {}).get("headers", []) + } + subject = headers.get("subject", "(no subject)") + sender = headers.get("from", "unknown") + date_raw = headers.get("date", "") + date = _parse_date(date_raw) if date_raw else datetime.now(timezone.utc) + + body_text = _parse_body(msg.get("payload", {}))[:_BODY_TRUNCATE] + labels = msg.get("labelIds", []) + + messages.append(EmailMessage( + id=msg_id, + subject=subject, + sender=sender, + body_text=body_text, + date=date, + labels=labels, + )) + except Exception as exc: + logger.warning("gmail: skipping message %s: %s", msg_id, exc) + + logger.info("gmail: returned %d message(s)", len(messages)) + return messages diff --git a/services/batch-agent/app/integrations/ms_graph.py b/services/batch-agent/app/integrations/ms_graph.py new file mode 100644 index 0000000..15a29e9 --- /dev/null +++ b/services/batch-agent/app/integrations/ms_graph.py @@ -0,0 +1,266 @@ +"""Microsoft Graph API client for Outlook and Teams. + +Adapted for Batch Agent Service: import settings from shared.config. 
+""" + +from __future__ import annotations + +import logging +import re +from datetime import datetime, timezone +from typing import Any + +import httpx + +from shared.config import settings +from app.integrations import ChatMessage, EmailMessage + +logger = logging.getLogger(__name__) + +_GRAPH_BASE = "https://graph.microsoft.com/v1.0" + +_MAX_EMAILS = 200 +_MAX_MESSAGES = 200 +_BODY_TRUNCATE = 8_000 + + +def _strip_html(raw: str) -> str: + no_tags = re.sub(r"<[^>]+>", " ", raw) + import html as _html + decoded = _html.unescape(no_tags) + return re.sub(r"\s+", " ", decoded).strip() + + +def _odata_datetime(dt: datetime) -> str: + utc = dt.astimezone(timezone.utc) + return utc.strftime("%Y-%m-%dT%H:%M:%SZ") + + +def _build_email_filter( + filter_config: dict[str, Any] | None, + since: datetime | None, +) -> str: + clauses: list[str] = [] + cfg = filter_config or {} + + senders: list[str] = cfg.get("senders", []) + if senders: + sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders] + clauses.append("(" + " or ".join(sender_clauses) + ")") + + date_range: dict = cfg.get("date_range", {}) + from_str: str | None = date_range.get("from") + + effective_since: datetime | None = since + if from_str: + try: + cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00")) + if cfg_since.tzinfo is None: + cfg_since = cfg_since.replace(tzinfo=timezone.utc) + if effective_since is None or cfg_since > effective_since: + effective_since = cfg_since + except ValueError: + logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str) + + if effective_since: + clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}") + + to_str: str | None = date_range.get("to") + if to_str: + try: + to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00")) + if to_dt.tzinfo is None: + to_dt = to_dt.replace(tzinfo=timezone.utc) + clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}") + except ValueError: + logger.warning("ms_graph: 
invalid date_range.to %r — ignoring", to_str) + + return " and ".join(clauses) + + +class MSGraphClient: + def __init__(self, credentials_info: dict[str, Any]) -> None: + self._credentials_info = credentials_info + self._access_token: str = credentials_info.get("access_token", "") + self._original_access_token: str = self._access_token + self._refresh_token: str | None = credentials_info.get("refresh_token") + + def _auth_headers(self) -> dict[str, str]: + return {"Authorization": f"Bearer {self._access_token}"} + + async def _refresh_access_token(self) -> None: + import msal + + app = msal.ConfidentialClientApplication( + client_id=settings.MS_CLIENT_ID, + client_credential=settings.MS_CLIENT_SECRET, + authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}", + ) + scopes: list[str] = self._credentials_info.get("scope", "").split() + if not scopes: + scopes = ["https://graph.microsoft.com/.default"] + + result = app.acquire_token_by_refresh_token( + self._refresh_token, + scopes=scopes, + ) + if "access_token" not in result: + error = result.get("error_description", result.get("error", "unknown")) + raise RuntimeError(f"MS Graph token refresh failed: {error}") + + self._access_token = result["access_token"] + if "refresh_token" in result: + self._refresh_token = result["refresh_token"] + self._credentials_info["refresh_token"] = result["refresh_token"] + self._credentials_info["access_token"] = self._access_token + + @property + def refreshed_credentials(self) -> dict[str, Any] | None: + if self._access_token != self._original_access_token: + return {**self._credentials_info, "access_token": self._access_token} + return None + + async def _get( + self, + client: httpx.AsyncClient, + url: str, + params: dict[str, Any] | None = None, + *, + retry_on_401: bool = True, + ) -> dict[str, Any]: + resp = await client.get(url, params=params, headers=self._auth_headers()) + if resp.status_code == 401 and retry_on_401 and self._refresh_token: + await 
async def fetch_emails(
    self,
    filter_config: dict[str, Any] | None = None,
    since: datetime | None = None,
) -> list[EmailMessage]:
    """Fetch Outlook messages, newest first, following pagination links.

    Stops once ``_MAX_EMAILS`` messages are collected or the response
    carries no ``@odata.nextLink``.
    """
    query: dict[str, Any] = {
        "$top": 50,
        "$select": "id,subject,from,receivedDateTime,body,bodyPreview",
        "$orderby": "receivedDateTime desc",
    }
    odata_filter = _build_email_filter(filter_config, since)
    if odata_filter:
        query["$filter"] = odata_filter

    collected: list[EmailMessage] = []
    next_url = f"{_GRAPH_BASE}/me/messages"

    async with httpx.AsyncClient(timeout=30.0) as http:
        while next_url and len(collected) < _MAX_EMAILS:
            page_params = query if next_url.startswith(_GRAPH_BASE) else None
            page = await self._get(http, next_url, page_params)
            # Take only as many items as still fit under the cap.
            remaining = _MAX_EMAILS - len(collected)
            collected.extend(
                self._parse_email(item) for item in page.get("value", [])[:remaining]
            )
            next_url = page.get("@odata.nextLink", "")
            # Query options are only supplied for the first request.
            query = {}

    logger.info("ms_graph: fetched %d Outlook email(s)", len(collected))
    return collected
@staticmethod
def _parse_email(item: dict[str, Any]) -> EmailMessage:
    """Convert one Graph message dict into an ``EmailMessage``.

    Missing or null fields fall back to defaults: ``"(no subject)"``,
    ``"unknown"`` sender, the current UTC time for an unparseable date,
    and ``bodyPreview`` when a non-HTML body is empty.
    """
    subject: str = item.get("subject") or "(no subject)"

    from_block = item.get("from") or {}
    address_block = from_block.get("emailAddress") or {}
    sender_addr = address_block.get("address", "unknown")

    received: str = item.get("receivedDateTime", "")
    try:
        date = datetime.fromisoformat(received.replace("Z", "+00:00"))
    except Exception:
        date = datetime.now(timezone.utc)

    body_block = item.get("body") or {}
    raw_body: str = body_block.get("content", "")
    if body_block.get("contentType", "text") == "html":
        text = _strip_html(raw_body)
    else:
        text = raw_body or item.get("bodyPreview", "")

    return EmailMessage(
        id=item.get("id", ""),
        subject=subject,
        sender=sender_addr,
        body_text=text[:_BODY_TRUNCATE],
        date=date,
    )
body_block.get("content", "") + content = _strip_html(raw_content) if content_type == "html" else raw_content + content = content[:_BODY_TRUNCATE] + + return ChatMessage( + id=msg_id, + content=content, + sender=sender, + channel=channel, + date=date, + ) diff --git a/services/batch-agent/app/journey.py b/services/batch-agent/app/journey.py new file mode 100644 index 0000000..5f18922 --- /dev/null +++ b/services/batch-agent/app/journey.py @@ -0,0 +1,385 @@ +"""Chatbot Journey — guided conversation to build an agent prompt_template. + +Adapted for Batch Agent Service: imports from app.agents.filesystem_agent +and app.llm instead of monolith paths. Session state is in-memory (could +be moved to Redis for horizontal scaling in the future). + +Journey flow: + 1. Redis consumer dispatches ``journey_start`` with basic agent config. + 2. Server creates an in-memory session, runs the setup LLM with + file-system tools to explore the directory, returns first question. + 3. ``journey_message`` frames drive the conversation. + 4. After 3-5 turns the LLM emits PROMPT_TEMPLATE_START / _END block. + 5. Server parses the block and returns ``journey_reply`` with ``done=True``. +""" + +from __future__ import annotations + +import json +import logging +import time +import uuid +from dataclasses import dataclass, field +from typing import Any + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage + +from app.agents.filesystem_agent import FILESYSTEM_TOOLS +from app.llm import get_llm + +logger = logging.getLogger(__name__) + +# ── Session TTL ─────────────────────────────────────────────────────────── + +_SESSION_TTL_SECONDS: int = 1800 # 30 minutes + +# Sentinel strings used to delimit the LLM-produced prompt_template. 
@dataclass
class JourneySession:
    """In-memory state of one guided prompt-building conversation."""

    session_id: str
    user_id: str
    agent_type: str  # "local" | "cloud"
    directory: str
    data_types: list[str]
    # Alternating {"role": ..., "content": ...} turns replayed to the LLM.
    history: list[dict[str, Any]] = field(default_factory=list)
    system_prompt: str = ""
    # Monotonic clock, so wall-clock adjustments cannot expire a session.
    created_at: float = field(default_factory=time.monotonic)

    def is_expired(self) -> bool:
        """Return True once the session is older than the session TTL."""
        return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS


# session_id → session
_sessions: dict[str, JourneySession] = {}


def _purge_expired_sessions() -> None:
    """Drop every expired session from the in-memory store."""
    # Collect ids first: the dict must not change size while being iterated.
    expired_ids = [sid for sid, sess in _sessions.items() if sess.is_expired()]
    for sid in expired_ids:
        _sessions.pop(sid, None)


def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
    """Retrieve session; return None on missing, expired, or wrong owner.

    Each lookup also sweeps *all* expired sessions. Previously a session
    was evicted only when its own id was requested again, so journeys that
    users abandoned stayed in memory for the life of the process.
    """
    _purge_expired_sessions()
    session = _sessions.get(session_id)
    if session is None or session.user_id != user_id:
        return None
    return session
+ +You have access to file-system tools to explore the user's directory: +- list_directory: to see folder structure +- read_file_content: to peek at file contents +- get_file_metadata: to check file info + +The user's configured directory is: {directory} +Target data types: {data_types} + +IMPORTANT — project assignment is handled automatically by the main agent runner +before the custom prompt is ever used. You MUST NOT ask the user about projects, +projectId, or how to link records to projects. Never include projectId logic or +project creation instructions in the generated prompt_template. + +Start by exploring the directory to understand its structure. Then ask concise, +focused questions one at a time. Cover these topics (not necessarily in this order): + 1. The type and format of the source content (confirmed by your exploration). + 2. How fields should be mapped (e.g. filename → task title). + 3. Priority or status rules (e.g. "urgent" keyword → high priority). + 4. Any special handling, date extraction, or exclusions. + +Once you reach 90% confidence, output the final prompt_template between these exact +markers on their own lines: + +{template_start} + +{template_end} + +The prompt_template must be a self-contained instruction for an AI that reads files +and must perform CRUD operations using tools to create records. It should specify: + - What entity types to create (tasks, notes, timelines) — never projects. + - How to map file content to record fields (camelCase: title, status, priority, + dueDate, content, etc.) — never include projectId. + - That isAiSuggested must be set to 1 on every new record. + - Concrete examples of mappings based on what you discovered in the directory. + +{existing_section}\ +Keep asking clarifying questions until you are at least 90% confident you have +enough information to generate an accurate prompt_template. Once you reach that +confidence level, stop asking and produce the final template immediately. 
+Begin by exploring the directory, then ask your first question.\ +""" + + +def _build_system_prompt( + directory: str, + data_types: list[str], + existing_template: str | None = None, +) -> str: + existing_section = ( + f"\nThe user already has the following prompt_template — refine it based on their answers:\n" + f"---\n{existing_template}\n---\n" + if existing_template + else "" + ) + return _SYSTEM_PROMPT_TEMPLATE.format( + directory=directory, + data_types=", ".join(data_types), + template_start=_TEMPLATE_START, + template_end=_TEMPLATE_END, + existing_section=existing_section, + ) + + +# ── Template extraction ─────────────────────────────────────────────────── + + +def _extract_template(text: str) -> str | None: + """Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None.""" + if _TEMPLATE_START not in text or _TEMPLATE_END not in text: + return None + start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START) + end_idx = text.index(_TEMPLATE_END) + return text[start_idx:end_idx].strip() or None + + +# ── LLM call with tool support ─────────────────────────────────────────── + + +def _as_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict): + text = item.get("text") + if isinstance(text, str): + parts.append(text) + return "".join(parts) + return str(content) + + +async def _call_llm_with_tools( + system_prompt: str, + history: list[dict[str, Any]], + tools: list[Any], +) -> str: + """Build LangChain messages from history and invoke the LLM with tools. + + Handles tool-calling loops: if the LLM calls tools, execute them and + continue until a final text response is produced. 
async def handle_journey_start(
    user_id: str,
    frame: dict[str, Any],
) -> dict[str, Any]:
    """Handle a ``journey_start`` request.

    Creates a session, runs the setup LLM with directory exploration,
    and returns the ``journey_reply`` payload.

    Args:
        user_id: Owner of the new session.
        frame: Request frame; reads ``agent_type``, ``directory``,
            ``data_types``, ``existing_template`` and ``session_id``.

    Returns:
        A ``journey_reply`` dict. ``done`` is True (and the session is
        discarded) when the very first LLM reply already contains a
        complete prompt_template block.
    """
    agent_type = frame.get("agent_type", "local")
    directory = frame.get("directory", "")
    existing_template = frame.get("existing_template")

    # An explicit null would crash ", ".join(...) in the prompt builder, and
    # a bare string would be joined character by character.
    data_types = frame.get("data_types") or []
    if isinstance(data_types, str):
        data_types = [data_types]

    session_id = frame.get("session_id") or str(uuid.uuid4())
    # The id may be chosen by the client; never let it overwrite a session
    # that belongs to a different user.
    clash = _sessions.get(session_id)
    if clash is not None and clash.user_id != user_id:
        session_id = str(uuid.uuid4())

    system_prompt = _build_system_prompt(directory, data_types, existing_template)

    session = JourneySession(
        session_id=session_id,
        user_id=user_id,
        agent_type=agent_type,
        directory=directory,
        data_types=data_types,
        system_prompt=system_prompt,
    )

    seed_history: list[dict[str, Any]] = [
        {"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
    ]
    ai_reply = await _call_llm_with_tools(
        system_prompt=system_prompt,
        history=seed_history,
        tools=list(FILESYSTEM_TOOLS),
    )

    # Register the session only after the first LLM call succeeded, so a
    # failed start leaves nothing behind in the store.
    session.history.extend(seed_history)
    session.history.append({"role": "assistant", "content": ai_reply})
    _sessions[session_id] = session

    logger.info(
        "journey: session %s started for user %s (directory=%s)",
        session_id,
        user_id,
        directory,
    )

    prompt_template = _extract_template(ai_reply)
    done = prompt_template is not None

    display_message = ai_reply
    if done:
        # Show only the prose that precedes the template block.
        display_message = (
            ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
            or "Here is your agent configuration. You can save it or continue refining."
        )
        _sessions.pop(session_id, None)

    return {
        "type": "journey_reply",
        "session_id": session_id,
        "message": display_message,
        "done": done,
        "prompt_template": prompt_template,
    }
+ """ + session_id = frame.get("session_id", "") + message = frame.get("message", "") + + session = get_journey_session(session_id, user_id) + if session is None: + return { + "type": "journey_reply", + "session_id": session_id, + "message": "Journey session not found or expired. Please start a new setup.", + "done": True, + "prompt_template": None, + } + + session.history.append({"role": "user", "content": message}) + + ai_reply = await _call_llm_with_tools( + system_prompt=session.system_prompt, + history=session.history, + tools=list(FILESYSTEM_TOOLS), + ) + + session.history.append({"role": "assistant", "content": ai_reply}) + + prompt_template = _extract_template(ai_reply) + done = prompt_template is not None + + if not done: + turns = sum(1 for t in session.history if t["role"] == "user") + if turns >= _MAX_TURNS: + nudge_content = ( + "[System: You have enough information. Please generate the final " + f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]" + ) + session.history.append({"role": "user", "content": nudge_content}) + + nudge_reply = await _call_llm_with_tools( + system_prompt=session.system_prompt, + history=session.history, + tools=list(FILESYSTEM_TOOLS), + ) + session.history.append({"role": "assistant", "content": nudge_reply}) + + prompt_template = _extract_template(nudge_reply) + if prompt_template is not None: + done = True + ai_reply = nudge_reply + + display_message = ai_reply + if done: + display_message = ( + ai_reply[: ai_reply.index(_TEMPLATE_START)].strip() + if _TEMPLATE_START in ai_reply + else "Here is your agent configuration. You can save it or continue refining." 
def _api_key_for_model(model: str) -> str | None:
    """Pick the configured API key matching *model*'s provider prefix.

    Prefixes are checked in order; a model with no recognised prefix uses
    the OpenAI key. Empty settings values are normalised to None, and
    ``github_copilot/`` models never get a key.

    NOTE(review): ``get_llm`` in this file calls this helper only for model
    names without a "/", so the prefixed branches look unreachable from
    there — confirm whether other callers exist.
    """
    # (prefixes, settings attribute); None means "no key for this provider".
    # Attributes are resolved lazily so only the matching setting is read.
    provider_table: tuple[tuple[tuple[str, ...], str | None], ...] = (
        (("anthropic/",), "ANTHROPIC_API_KEY"),
        (("gemini/", "google/"), "GOOGLE_API_KEY"),
        (("cerebras/",), "CEREBRAS_API_KEY"),
        (("github_copilot/",), None),
    )
    for prefixes, setting_name in provider_table:
        if not model.startswith(prefixes):
            continue
        if setting_name is None:
            return None
        return getattr(settings, setting_name) or None

    return settings.OPENAI_API_KEY or None
async def embed(text: str) -> list[float]:
    """Return the embedding vector for *text* using ``settings.LLM_EMBED_MODEL``.

    Provider-prefixed model names (anything containing "/") are routed
    through LiteLLM; bare model names go straight to the OpenAI API.
    """
    model = settings.LLM_EMBED_MODEL

    # "github_copilot/..." already contains "/", so one check covers both.
    if "/" in model:
        response = await litellm.aembedding(model=model, input=[text])
        return response.data[0]["embedding"]

    # Close the per-call client on exit so its HTTP connection pool is
    # released instead of piling up in a long-lived worker.
    async with AsyncOpenAI(api_key=settings.OPENAI_API_KEY) as client:
        response = await client.embeddings.create(model=model, input=text)
    return response.data[0].embedding
async def _handle_journey_start(user_id: str, data: dict[str, Any]) -> None:
    """Handle a journey_start request from WS Gateway.

    Runs the journey handler with *user_id* bound as the current user and
    publishes its reply. On failure a terminal ``journey_reply`` frame
    (``done=True``) is published instead. The user binding is always
    cleared afterwards.
    """
    from app.journey import handle_journey_start

    set_current_user(user_id)
    try:
        reply = await handle_journey_start(user_id, data)
        await _publish_to_user(user_id, reply)
    except Exception as exc:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("batch-agent: journey_start failed user=%s: %s", user_id, exc)
        failure_frame = {
            "type": "journey_reply",
            "session_id": data.get("session_id", ""),
            "message": f"Journey setup failed: {exc}",
            "done": True,
            "prompt_template": None,
        }
        try:
            await _publish_to_user(user_id, failure_frame)
        except Exception:
            # The original failure may itself be a Redis outage; this runs
            # in a fire-and-forget task, so log rather than raise again.
            logger.exception(
                "batch-agent: could not deliver journey_start failure to user=%s",
                user_id,
            )
    finally:
        clear_current_user()
async def start_consumer() -> None:
    """Subscribe to batch:request:* and dispatch incoming frames.

    Each valid frame is handled in its own task so a slow handler never
    blocks the subscription loop. Runs until cancelled; the pattern
    subscription is removed on the way out.
    """
    pubsub = redis_client.pubsub()
    await pubsub.psubscribe("batch:request:*")
    logger.info("batch-agent: subscribed to batch:request:*")

    # The event loop holds only weak references to tasks; without a strong
    # reference a running dispatch can be garbage-collected mid-flight.
    in_flight: set[asyncio.Task[None]] = set()

    try:
        async for message in pubsub.listen():
            if message["type"] != "pmessage":
                continue

            channel = message["channel"]
            if isinstance(channel, bytes):
                channel = channel.decode()

            # Extract user_id from channel: batch:request:{user_id}
            parts = channel.split(":", 2)
            if len(parts) < 3 or not parts[2]:
                continue
            user_id = parts[2]

            raw = message["data"]
            if isinstance(raw, bytes):
                raw = raw.decode()

            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                logger.warning("batch-agent: invalid JSON on channel %s", channel)
                continue

            if not isinstance(data, dict):
                # Handlers call .get() on the frame; arrays or scalars
                # would only fail later inside the task.
                logger.warning("batch-agent: non-object JSON on channel %s", channel)
                continue

            # Dispatch in a separate task to avoid blocking the consumer
            task = asyncio.create_task(_dispatch(user_id, data))
            in_flight.add(task)
            task.add_done_callback(in_flight.discard)
    except asyncio.CancelledError:
        logger.info("batch-agent: consumer shutting down")
    finally:
        await pubsub.punsubscribe("batch:request:*")
+FEATURES: dict[str, dict] = { + "free": {"batch_active": 1, "batch_runs_per_day": 3}, + "pro": {"batch_active": 5, "batch_runs_per_day": 20}, + "power": {"batch_active": 20, "batch_runs_per_day": 100}, + "team": {"batch_active": -1, "batch_runs_per_day": -1}, +} + + +def _dt_ms(dt: datetime) -> int: + return int(dt.timestamp() * 1000) + + +def _dt_ms_opt(dt: datetime | None) -> int | None: + return int(dt.timestamp() * 1000) if dt else None + + +def _to_data_types(values: list[str]) -> list[str]: + normalize = { + "task": "tasks", "tasks": "tasks", + "note": "notes", "notes": "notes", + "timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines", + "project": "projects", "projects": "projects", + } + seen: set[str] = set() + result: list[str] = [] + for v in values: + mapped = normalize.get(v) + if mapped and mapped not in seen: + seen.add(mapped) + result.append(mapped) + return result + + +def _enforce_agent_limit(tier: str, current_count: int) -> int: + limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"] + if limit != -1 and current_count >= limit: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Agent limit ({limit}) reached for your tier. 
async def _enforce_run_frequency(tier: str, user_id: str) -> None:
    """Reject the request once *user_id* has used up today's batch runs.

    The cap comes from ``FEATURES`` (unknown tiers use the ``free`` entry);
    ``-1`` means unlimited. Runs are counted from midnight UTC.

    Raises:
        HTTPException: 402 when the number of runs started today has
            reached the tier's cap.
    """
    tier_features = FEATURES.get(tier, FEATURES["free"])
    daily_cap: int = tier_features["batch_runs_per_day"]
    if daily_cap == -1:
        return

    midnight_utc = datetime.now(timezone.utc).replace(
        hour=0, minute=0, second=0, microsecond=0
    )
    count_query = select(func.count(AgentRunLog.id)).where(
        AgentRunLog.user_id == user_id,
        AgentRunLog.started_at >= midnight_utc,
    )
    async with async_session() as db:
        used_today: int = (await db.execute(count_query)).scalar_one()

    if used_today < daily_cap:
        return

    raise HTTPException(
        status_code=status.HTTP_402_PAYMENT_REQUIRED,
        detail=f"Daily batch run limit ({daily_cap}) reached for your tier.",
    )
@router.post("/trigger", status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
    body: dict,
    x_user_id: str = Header(..., alias="X-User-Id"),
    x_user_tier: str = Header("free", alias="X-User-Tier"),
) -> dict:
    """Trigger a local agent run — creates run log and dispatches via Redis.

    Order of work: tier checks → duplicate-run check → insert a
    ``running`` run-log row → publish an ``agent_trigger`` frame on the
    user's batch request channel.

    Raises:
        HTTPException: 403/402 from the tier checks, 409 when the agent is
            already running, 503 when the trigger cannot be published.
    """
    active_agents = body.get("active_agents", 0)
    _enforce_agent_limit(x_user_tier, active_agents)
    await _enforce_run_frequency(x_user_tier, x_user_id)

    stable_agent_id = body.get("agent_id") or str(uuid.uuid4())

    if is_agent_running(stable_agent_id):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Agent is already running.",
        )

    # Create run log in DB
    async with async_session() as db:
        run_log = AgentRunLog(
            agent_id=stable_agent_id,
            agent_type="local",
            user_id=x_user_id,
            status="running",
        )
        db.add(run_log)
        await db.commit()
        await db.refresh(run_log)
        # Read generated values while the session is still open.
        run_log_id = run_log.id
        started_at = run_log.started_at

    run_context = {
        "type": "agent_batch",
        "run_id": run_log_id,
        "agent_id": stable_agent_id,
    }

    # Dispatch to the Redis consumer for processing
    directory = body.get("directory", "")
    trigger_data = {
        "type": "agent_trigger",
        "directory": directory,
        "directory_paths": [directory] if directory else [],
        "data_types": _to_data_types(body.get("what_to_extract", [])),
        "file_extensions": body.get("file_extensions", []),
        "prompt_template": body.get("custom_agent_prompt", ""),
        "device_id": body.get("device_id", ""),
        "run_context": run_context,
    }

    channel = batch_request_channel(x_user_id)
    try:
        await redis_client.publish(channel, json.dumps(trigger_data))
    except Exception as exc:
        # Nothing will ever process this run. Do not leave the row stuck in
        # "running", where it would keep counting against the daily quota.
        # NOTE(review): "error" mirrors the run_complete status the consumer
        # emits — confirm it is a valid AgentRunLog.status value.
        async with async_session() as db:
            orphan = (
                await db.execute(
                    select(AgentRunLog).where(AgentRunLog.id == run_log_id)
                )
            ).scalar_one_or_none()
            if orphan is not None:
                orphan.status = "error"
                await db.commit()
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Could not dispatch the agent run. Please try again later.",
        ) from exc

    return {
        "id": run_log_id,
        "agent_id": stable_agent_id,
        "agent_type": "local",
        "status": "running",
        "items_processed": 0,
        "items_created": 0,
        "errors": [],
        "started_at": _dt_ms(started_at),
        "completed_at": None,
    }
diff --git a/services/batch-agent/app/ws_context.py b/services/batch-agent/app/ws_context.py new file mode 100644 index 0000000..ea3694b --- /dev/null +++ b/services/batch-agent/app/ws_context.py @@ -0,0 +1,135 @@ +"""WebSocket context for Batch Agent Service — Redis-based tool call round-trip. + +Same pattern as services/chat/app/ws_context.py: publishes tool_call frames +to Redis ws:out:{user_id} and awaits BRPOP on tool:result:{call_id}. + +Additionally provides set_client_executor / clear_client_executor stubs +for backward compatibility with the agent_runner code (which originally +used a DeviceConnectionManager callback). In the microservice world these +are no-ops — execute_on_client() always uses the Redis path. +""" + +from __future__ import annotations + +import json +import logging +from contextvars import ContextVar +from typing import Any, Callable, Coroutine +from uuid import uuid4 + +from shared.redis import redis_client, tool_result_key, ws_out_channel + +logger = logging.getLogger(__name__) + +_TOOL_CALL_TIMEOUT = 30 # seconds — BRPOP timeout + +# Per-request user_id context var (set before agent run) +_current_user_id: ContextVar[str | None] = ContextVar("_current_user_id", default=None) + +# Optional collector for debug / logging +_tool_result_collector: ContextVar[list[dict] | None] = ContextVar( + "_tool_result_collector", default=None +) + + +def set_current_user(user_id: str) -> None: + _current_user_id.set(user_id) + + +def clear_current_user() -> None: + _current_user_id.set(None) + + +def set_tool_result_collector(lst: list[dict]) -> None: + _tool_result_collector.set(lst) + + +def clear_tool_result_collector() -> None: + _tool_result_collector.set(None) + + +# ── Compatibility shims ────────────────────────────────────────────────── +# agent_runner.py originally called set_client_executor / clear_client_executor +# with a DeviceConnectionManager callback. 
async def execute_on_client(
    action: str,
    table: str | None = None,
    data: dict[str, Any] | None = None,
    filters: dict[str, Any] | None = None,
    vector: list[float] | None = None,
    limit: int | None = None,
) -> dict[str, Any]:
    """Run one tool call on the user's client and return its decoded result.

    A ``tool_call`` frame is published on the current user's outbound
    channel; the reply is then awaited with a blocking pop on the result
    key derived from the call id. Optional arguments left as None are
    omitted from the frame, and None-valued entries are removed from
    *filters*.

    Raises:
        RuntimeError: when no current user is bound, or when no result
            arrives within ``_TOOL_CALL_TIMEOUT`` seconds.
    """
    user_id = _current_user_id.get()
    if not user_id:
        raise RuntimeError(
            "execute_on_client() called without a user_id — "
            "set_current_user() must be called first."
        )

    call_id = str(uuid4())

    # Declared in the order the fields appear in the outgoing frame.
    optional_fields: dict[str, Any] = {
        "table": table,
        "data": data,
        "filters": (
            None
            if filters is None
            else {key: val for key, val in filters.items() if val is not None}
        ),
        "vector": vector,
        "limit": limit,
    }
    payload: dict[str, Any] = {"type": "tool_call", "id": call_id, "action": action}
    payload.update(
        {name: value for name, value in optional_fields.items() if value is not None}
    )

    # Publish tool_call to WS Gateway → Electron
    await redis_client.publish(ws_out_channel(user_id), json.dumps(payload))

    # Wait for Electron's tool_result
    popped = await redis_client.brpop(tool_result_key(call_id), timeout=_TOOL_CALL_TIMEOUT)
    if popped is None:
        raise RuntimeError(
            f"Tool call {call_id} timed out after {_TOOL_CALL_TIMEOUT}s — "
            f"device may be offline or unresponsive."
        )

    # brpop yields a (key, value) pair; only the value is needed.
    _key, encoded = popped
    result = json.loads(encoded)

    # Collect for debug if requested
    sink = _tool_result_collector.get(None)
    if sink is not None:
        sink.append({"action": action, "table": table, "data": result})

    return result