feat(batch-agent): extract Batch Agent Service (Step 3)

- agent_runner: local directory + cloud agent orchestration via Redis
- 5 domain agents: filesystem, task, note, project, timeline
- integrations: Gmail, MS Graph (Outlook + Teams)
- journey: guided chatbot conversation to build prompt_template
- routes: REST endpoints (catalog, can-create, trigger)
- redis_consumer: subscribes to batch:request:* pattern
- ws_context: Redis-based execute_on_client for tool round-trip
- Dockerfile with 300s timeout for long-running batch jobs
This commit is contained in:
Roberto Musso
2026-03-23 07:19:02 +01:00
parent 229e20d073
commit 333bba6fdd
18 changed files with 3157 additions and 0 deletions

View File

@@ -0,0 +1,36 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
COPY services/batch-agent/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
COPY --from=builder /install /usr/local
# Shared module
COPY shared/ shared/
# Service source
COPY services/batch-agent/app/ app/
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
# Batch runs are long-lived — use a longer timeout than chat (300s vs 120s)
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "2", \
"--timeout", "300"]

View File

@@ -0,0 +1,884 @@
"""Agent run orchestrator — adapted for Batch Agent Service.
Key changes from monolith app/core/agent_runner.py:
- No DeviceConnectionManager — tool calls go through Redis ws_context.
- set_current_user / clear_current_user replace set_client_executor.
- run_local_agent accepts a serialized dict (from Redis / REST) instead
of SQLAlchemy model objects.
- _finalize_run writes to PostgreSQL via shared.db.async_session.
- Cloud agent import path changed to app.integrations.
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from sqlalchemy import select
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.llm import get_llm
from app.ws_context import execute_on_client, set_current_user, clear_current_user
from shared.db import async_session
from shared.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
from shared.redis import redis_client, ws_out_channel
logger = logging.getLogger(__name__)
# ── Concurrency guard ─────────────────────────────────────────────────────
_running_agents: set[str] = set()
def is_agent_running(agent_id: str) -> bool:
return agent_id in _running_agents
# ── Timeouts ───────────────────────────────────────────────────────────────
_TOOL_CALL_TIMEOUT: int = 30
_MAX_PROCESSING_STEPS: int = 12
_MAX_SCAN_DEPTH: int = 5
# ── Data-type to tool mapping ─────────────────────────────────────────────
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
"tasks": TASK_TOOLS,
"notes": NOTE_TOOLS,
"timelines": TIMELINE_TOOLS,
}
# ── Step 1: Classification prompt ─────────────────────────────────────────
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
"tasks": (
"Action items, to-dos, deliverables — anything that describes work to be done, "
"assigned to someone, or tracked with a due date or status."
),
"notes": (
"Documentation, meeting notes, summaries, reference material — "
"written content meant to be read and referenced rather than acted on."
),
"timelines": (
"Project milestones, deadlines, scheduled events — "
"specific dates that mark a point in the progress of a project."
),
"projects": (
"High-level project entities — only relevant if the file clearly introduces "
"a new project or updates the scope of an existing one."
),
}
_STEP1_SYSTEM_PROMPT = """\
You are a file classifier for a freelance project management tool.
Your job is to match a file to an existing project and identify which data domains to extract.
## Project matching rules (STRICT — follow in order)
1. Search the file content for any mention of a project name, client name, acronym, or topic
that overlaps with the existing projects listed below.
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
when the file has zero meaningful connection to any listed project.
4. When in doubt, pick the closest match from the list.
## Response format
Respond ONLY with a JSON object — no markdown, no explanation:
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
## Domain definitions (only consider domains in the allowed list)
{domain_definitions}
## Existing projects
{projects_list}
"""
# ── Step 2: Processing prompt ─────────────────────────────────────────────
_PROCESSING_SYSTEM_PROMPT = """\
You are a data extraction assistant for a freelance project management tool.
Your task: extract structured data from the file content and persist it using the available tools.
## Mandatory process — follow this order for EVERY item you extract
1. READ the existing records listed below for the relevant domain.
2. SEARCH for a match by title, topic, or semantic similarity.
3. If a match exists → call the update_* tool with the existing record's id.
4. If no match exists → call the create_* tool and set isAiSuggested=1.
NEVER call create_* without first checking the existing records.
NEVER duplicate a record that already exists under a different wording.
## Existing records (source of truth)
{existing_context}
## Context
Project: {project_context}
Domains to extract: {data_types}
{custom_prompt_section}
"""
# ── Cloud processing prompt ───────────────────────────────────────────────
_CLOUD_PROCESSING_PROMPT = """\
You are a data extraction and management assistant for a freelance project
management tool.
Available tools:
Filesystem : read_file_content, list_directory, get_file_metadata
Tasks : list_tasks, create_task, update_task, add_task_comment
Notes : list_notes, get_note, create_note, update_note
Timelines : list_timelines, create_timeline, update_timeline
Projects : list_all_projects, get_project, create_project, update_project
Your task:
1. Read the full content of each file below using read_file_content.
2. For each piece of information found, ALWAYS try to match and update an
existing record before creating a new one.
3. ONLY act on these entity types: {data_types}.
4. Do NOT invent data. Only extract what is clearly present in the files.
5. If a file contains no relevant data for the target entity types, skip it.
{project_context}
Files to process:
{file_list}
{custom_prompt_section}
After processing all files, respond with a brief summary of what you updated
and what you created.
"""
# ── LLM tool-calling loop ─────────────────────────────────────────────────
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
async def _run_agent_with_tools(
*,
system_prompt: str,
user_message: str,
tools: list[Any],
max_steps: int,
) -> str:
"""Run an LLM agent with tool-calling, returning the final text response."""
llm = get_llm()
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
HumanMessage(content=user_message),
]
tool_map = {tool_def.name: tool_def for tool_def in tools}
for _ in range(max_steps):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return _as_text(response.content)
for call in response.tool_calls:
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"agent_runner: tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"agent_runner: tool_result name=%s output=%s",
call_name,
str(tool_output)[:200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
final = await llm.ainvoke(messages)
return _as_text(final.content)
# ── Tool list builder ─────────────────────────────────────────────────────
def _build_processing_tools(data_types: list[str]) -> list[Any]:
tools: list[Any] = list(FILESYSTEM_TOOLS)
for dt in data_types:
dt_tools = _DATA_TYPE_TOOLS.get(dt)
if dt_tools:
tools.extend(dt_tools)
return tools
# ── Code-based directory scanner ─────────────────────────────────────────
async def _scan_directories(
paths: list[str],
extensions: list[str],
last_run_at: datetime | None,
) -> list[str]:
all_files: list[str] = []
ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set()
async def _walk(path: str, depth: int) -> None:
if depth > _MAX_SCAN_DEPTH:
return
try:
result = await execute_on_client(action="list_directory", data={"path": path})
except Exception as exc:
logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
return
for entry in result.get("entries", []):
entry_path = entry.get("path", "")
if not entry_path:
continue
if entry.get("type") == "directory":
await _walk(entry_path, depth + 1)
elif entry.get("type") == "file":
if ext_set:
dot_pos = entry_path.rfind(".")
file_ext = entry_path[dot_pos + 1:].lower() if dot_pos != -1 else ""
if file_ext not in ext_set:
continue
all_files.append(entry_path)
for root in paths:
await _walk(root, depth=0)
if last_run_at is None:
return all_files
last_run_ms = int(last_run_at.timestamp() * 1000)
filtered: list[str] = []
for file_path in all_files:
try:
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
modified_at = meta.get("modifiedAt")
if modified_at is None:
filtered.append(file_path)
continue
if isinstance(modified_at, (int, float)):
mod_ms = int(modified_at)
else:
mod_ms = int(datetime.fromisoformat(str(modified_at)).timestamp() * 1000)
if mod_ms > last_run_ms:
filtered.append(file_path)
except Exception:
filtered.append(file_path)
return filtered
# ── Code-based entity fetchers ────────────────────────────────────────────
async def _fetch_projects() -> list[dict]:
try:
result = await execute_on_client(action="select", table="projects")
return result.get("rows", [])
except Exception as exc:
logger.warning("agent_runner: failed to fetch projects: %s", exc)
return []
_DOMAIN_TABLE: dict[str, str] = {
"tasks": "tasks",
"notes": "notes",
"timelines": "timelines",
"projects": "projects",
}
async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
table = _DOMAIN_TABLE.get(domain)
if not table:
return []
filters: dict[str, Any] = {}
if project_id != "standalone" and domain != "projects":
filters["projectId"] = project_id
try:
result = await execute_on_client(
action="select",
table=table,
filters=filters if filters else None,
)
return result.get("rows", [])
except Exception as exc:
logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
return []
def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
if not rows:
return f"No existing {domain}."
lines: list[str] = []
for r in rows:
if domain == "tasks":
desc = r.get("description") or ""
desc_part = f"{desc[:120]}" if desc else ""
assignee = r.get("assignee") or r.get("assignees") or ""
due = r.get("dueDate") or r.get("due_date") or ""
meta = ", ".join(filter(None, [
f"priority: {r.get('priority', '')}" if r.get("priority") else "",
f"assignee: {assignee}" if assignee else "",
f"due: {due}" if due else "",
]))
lines.append(
f" - [{r.get('status', '?')}] {r.get('title', '')}{desc_part}"
f" ({meta}, id: {r['id']})"
)
elif domain == "notes":
snippet = (r.get("content") or "")[:200].replace("\n", " ")
snippet_part = f"\n Preview: {snippet}" if snippet else ""
lines.append(
f" - {r.get('title', '')} (id: {r['id']}){snippet_part}"
)
elif domain == "timelines":
lines.append(
f" - {r.get('title', '')} date={r.get('date', '')} (id: {r['id']})"
)
elif domain == "projects":
summary = (r.get("aiSummary") or r.get("ai_summary") or "")[:120]
summary_part = f"{summary}" if summary else ""
lines.append(
f" - {r.get('name', '')} [{r.get('status', '')}]{summary_part}"
f" (id: {r['id']})"
)
return f"Existing {domain}:\n" + "\n".join(lines)
# ── Step 1: LLM file classifier ───────────────────────────────────────────
async def _classify_file(
file_path: str,
file_content: str,
projects: list[dict],
config_data_types: list[str],
) -> tuple[str, list[str], str | None]:
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
if not file_content.strip():
return fallback
valid_project_ids = {p["id"] for p in projects}
def _fmt_project(p: dict) -> str:
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
summary_part = f"{summary[:100]}" if summary else ""
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
domain_definitions = "\n".join(
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
for d in config_data_types
if d in _DOMAIN_DESCRIPTIONS
)
system = _STEP1_SYSTEM_PROMPT.format(
domain_definitions=domain_definitions,
projects_list=projects_list,
)
llm = get_llm()
try:
response = await llm.ainvoke([
SystemMessage(content=system),
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
])
raw = _as_text(response.content).strip()
if raw.startswith("```"):
raw = raw.split("```")[1]
if raw.startswith("json"):
raw = raw[4:]
parsed = json.loads(raw.strip())
raw_project_id: str = str(parsed.get("project_id") or "new")
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
new_project_name: str | None = (
str(parsed["new_project_name"]).strip() or None
if project_id == "new" and parsed.get("new_project_name")
else None
)
domains: list[str] = [
d for d in parsed.get("domains", [])
if d in config_data_types
]
if not domains:
domains = list(config_data_types)
return project_id, domains, new_project_name
except Exception as exc:
logger.warning(
"agent_runner: step1 classification failed for %r: %s", file_path, exc
)
return fallback
# ── Local agent runner (two-step per file) ────────────────────────────────
async def run_local_agent(user_id: str, trigger_data: dict[str, Any]) -> None:
"""Execute a local directory agent run.
In the microservice world, trigger_data is a serialized dict from
the REST route (forwarded via Redis), containing the agent config
fields and run_context.
set_current_user() must be called BEFORE this function.
"""
run_context: dict = trigger_data.get("run_context", {})
agent_id = run_context.get("agent_id", str(uuid.uuid4()))
run_id = run_context.get("run_id")
_running_agents.add(agent_id)
# Extract config from trigger payload
directory_paths: list[str] = trigger_data.get("directory_paths", [])
if not directory_paths:
directory = trigger_data.get("directory", "")
if directory:
directory_paths = [directory]
data_types: list[str] = trigger_data.get("data_types", [])
file_extensions: list[str] = trigger_data.get("file_extensions", [])
prompt_template: str = trigger_data.get("prompt_template", "")
last_run_at_raw = trigger_data.get("last_run_at")
last_run_at: datetime | None = None
if last_run_at_raw:
if isinstance(last_run_at_raw, str):
last_run_at = datetime.fromisoformat(last_run_at_raw)
elif isinstance(last_run_at_raw, (int, float)):
last_run_at = datetime.fromtimestamp(last_run_at_raw / 1000, tz=timezone.utc)
errors: list[str] = []
items_processed = 0
items_created = 0
custom_section = (
f"User instructions:\n{prompt_template}"
if prompt_template
else ""
)
# Create or load run log
run_log_id = run_id
if not run_log_id:
async with async_session() as db:
run_log = AgentRunLog(
agent_id=agent_id,
agent_type="local",
user_id=user_id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_log_id = run_log.id
try:
# ── Scan directories ─────────────────────────────────────────
logger.info("agent_runner: run=%s scanning directories user=%s", run_log_id, user_id)
file_paths = await _scan_directories(
paths=directory_paths,
extensions=file_extensions,
last_run_at=last_run_at,
)
logger.info(
"agent_runner: run=%s found %d file(s) after filtering", run_log_id, len(file_paths)
)
if not file_paths:
await _finalize_run(run_log_id, status="success", items_processed=0, items_created=0)
return
# ── Fetch all projects once ──────────────────────────────────
projects = await _fetch_projects()
for file_path in file_paths:
try:
file_result = await execute_on_client(
action="read_file_content", data={"path": file_path}
)
file_content: str = file_result.get("content", "")
if not file_content:
continue
items_processed += 1
# Step 1 — classify file
project_id, domains, new_project_name = await _classify_file(
file_path=file_path,
file_content=file_content,
projects=projects,
config_data_types=data_types,
)
# Step 2 — resolve project_id, fetch entities, process
if project_id == "new":
proj_name = new_project_name or "Untitled Project"
try:
proj_result = await execute_on_client(
action="insert",
table="projects",
data={"name": proj_name, "clientId": None},
)
created = proj_result.get("row", {})
effective_project_id = created.get("id", "standalone")
if "id" in created:
projects.append(created)
except Exception as exc:
logger.warning("agent_runner: run=%s create project failed: %s", run_log_id, exc)
effective_project_id = "standalone"
proj_name = "unknown"
project_context = (
f"Project: {proj_name} (id: {effective_project_id}). "
"Always set projectId to this id on every record you create."
)
else:
effective_project_id = project_id
proj = next((p for p in projects if p["id"] == project_id), None)
proj_name = proj.get("name", project_id) if proj else project_id
project_context = (
f"Project: {proj_name} (id: {project_id}). "
"Always set projectId to this id on every record you create."
)
domains = [d for d in domains if d != "projects"]
existing_blocks: list[str] = []
for domain in domains:
rows = await _fetch_domain_entities(domain, effective_project_id)
existing_blocks.append(_format_entities_for_context(domain, rows))
existing_context = "\n\n".join(existing_blocks)
system_prompt = _PROCESSING_SYSTEM_PROMPT.format(
existing_context=existing_context,
project_context=project_context,
data_types=", ".join(domains),
custom_prompt_section=custom_section,
)
processing_tools = _build_processing_tools(domains)
result_text = await _run_agent_with_tools(
system_prompt=system_prompt,
user_message=(
f"Process this file and extract relevant information.\n\n"
f"File: {file_path}\n\nContent:\n{file_content}"
),
tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS,
)
logger.info(
"agent_runner: run=%s file=%r result=%s",
run_log_id, file_path, result_text[:200],
)
except Exception as exc:
errors.append(f"Error processing '{file_path}': {exc}")
logger.error("agent_runner: run=%s file=%r failed: %s", run_log_id, file_path, exc)
except Exception as exc:
errors.append(f"Agent run failed: {exc}")
logger.error("agent_runner: run=%s failed: %s", run_log_id, exc)
finally:
_running_agents.discard(agent_id)
# ── Finalise ────────────────────────────────────────────────────
if errors and items_processed == 0:
final_status = "error"
elif errors:
final_status = "partial"
else:
final_status = "success"
await _finalize_run(
run_log_id,
status=final_status,
items_processed=items_processed,
items_created=items_created,
errors=errors,
)
# Notify Electron that the run is complete via Redis
if run_context:
try:
channel = ws_out_channel(user_id)
await redis_client.publish(channel, json.dumps({
"type": "run_complete",
"run_context": run_context,
"status": final_status,
}))
except Exception as exc:
logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_log_id, exc)
# ── Cloud agent runner ─────────────────────────────────────────────────────
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
async def run_cloud_agent(user_id: str, config_id: str) -> None:
"""Execute a cloud connector agent run.
Loads the CloudAgentConfig from DB, decrypts OAuth tokens, fetches
messages from the provider, and runs LLM extraction.
set_current_user() must be called BEFORE this function.
"""
from app.integrations import decrypt_token, encrypt_token, get_provider
async with async_session() as db:
result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
)
config = result.scalar_one_or_none()
if config is None:
logger.error("agent_runner: cloud config %s not found", config_id)
return
# Create run log
run_log = AgentRunLog(
agent_id=config.id,
agent_type="cloud",
user_id=user_id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_log_id = run_log.id
# ── Decrypt OAuth token ────────────────────────────────────────
if not config.oauth_token_encrypted:
await _finalize_run(
run_log_id,
status="error",
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
)
return
try:
credentials_info = decrypt_token(config.oauth_token_encrypted)
except ValueError as exc:
await _finalize_run(
run_log_id,
status="error",
errors=[f"Failed to decrypt OAuth token: {exc}"],
)
return
# ── Instantiate provider ──────────────────────────────────────
try:
provider = get_provider(config.provider, credentials_info)
except ValueError as exc:
await _finalize_run(run_log_id, status="error", errors=[str(exc)])
return
# ── Fetch messages ────────────────────────────────────────────
since: datetime | None = config.last_run_at
if since is None:
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
if since.tzinfo is None:
since = since.replace(tzinfo=timezone.utc)
errors: list[str] = []
items_processed = 0
try:
if config.provider == "gmail":
raw_messages = await provider.fetch_messages(
filter_config=config.filter_config,
since=since,
)
elif config.provider == "outlook":
raw_messages = await provider.fetch_emails(
filter_config=config.filter_config,
since=since,
)
elif config.provider == "teams":
raw_messages = await provider.fetch_messages(
filter_config=config.filter_config,
since=since,
)
else:
raw_messages = []
except RuntimeError as exc:
await _finalize_run(
run_log_id,
status="error",
errors=[f"Provider fetch failed: {exc}"],
update_config_last_run=True,
config_id=config.id,
config_type="cloud",
)
return
logger.info(
"agent_runner: cloud agent %s fetched %d item(s) from %s",
config.id, len(raw_messages), config.provider,
)
# ── Extract + insert via LLM ─────────────────────────────────
try:
processing_tools = _build_processing_tools(config.data_types)
custom_section = (
f"User instructions:\n{config.prompt_template}"
if config.prompt_template
else ""
)
for msg in raw_messages:
content_text = msg.as_text
if not content_text:
continue
items_processed += 1
processing_prompt = _CLOUD_PROCESSING_PROMPT.format(
data_types=", ".join(config.data_types),
project_context="Determine the appropriate project from the message context.",
file_list=f"Message from {config.provider} (id: {msg.id})",
custom_prompt_section=custom_section,
)
try:
await _run_agent_with_tools(
system_prompt=processing_prompt,
user_message=f"Process this message content:\n\n{content_text[:8000]}",
tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS,
)
except Exception as exc:
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
except Exception as exc:
errors.append(f"Agent run failed: {exc}")
# ── Persist refreshed token ───────────────────────────────────
refreshed = getattr(provider, "refreshed_credentials", None)
if refreshed:
try:
new_encrypted = encrypt_token(refreshed)
async with async_session() as db:
cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
)
cfg_row = cfg_result.scalar_one_or_none()
if cfg_row:
cfg_row.oauth_token_encrypted = new_encrypted
await db.commit()
except Exception as exc:
logger.warning("agent_runner: failed to persist refreshed token: %s", exc)
# ── Finalise ──────────────────────────────────────────────────
if errors and items_processed == 0:
final_status = "error"
elif errors:
final_status = "partial"
else:
final_status = "success"
await _finalize_run(
run_log_id,
status=final_status,
items_processed=items_processed,
items_created=0,
errors=errors,
update_config_last_run=True,
config_id=config.id,
config_type="cloud",
)
# ── Internal helper ─────────────────────────────────────────────────────────
async def _finalize_run(
run_log_id: int | str,
*,
status: str,
items_processed: int = 0,
items_created: int = 0,
errors: list[str] | None = None,
update_config_last_run: bool = False,
config_id: str | None = None,
config_type: str | None = None,
) -> None:
"""Persist the run outcome and optionally update last_run_at on the config."""
now = datetime.now(timezone.utc)
try:
async with async_session() as db:
result = await db.execute(
select(AgentRunLog).where(AgentRunLog.id == run_log_id)
)
managed = result.scalar_one_or_none()
if managed is None:
logger.warning("agent_runner: run_log %s not found for finalization", run_log_id)
return
managed.status = status
managed.items_processed = items_processed
managed.items_created = items_created
managed.errors = errors or []
managed.completed_at = now
if update_config_last_run and config_id:
if config_type == "local":
cfg_result = await db.execute(
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
)
cfg = cfg_result.scalar_one_or_none()
if cfg:
cfg.last_run_at = now
elif config_type == "cloud":
cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
)
cfg = cfg_result.scalar_one_or_none()
if cfg:
cfg.last_run_at = now
await db.commit()
except Exception as exc:
logger.error("agent_runner: failed to finalize run_log=%s: %s", run_log_id, exc)

View File

@@ -0,0 +1 @@
"""Batch Agent Service domain agents and filesystem tools."""

View File

@@ -0,0 +1,83 @@
"""Filesystem agent — tools for reading local directories and files on Electron.
Adapted for Batch Agent Service: import from app.ws_context.
"""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.ws_context import execute_on_client
@tool
async def list_directory(path: str) -> str:
"""List files and folders in a local directory on the user's device.
Returns a formatted listing of entries with name, type (file/directory),
and full path.
"""
result = await execute_on_client(
action="list_directory",
data={"path": path},
)
entries: list[dict[str, Any]] = result.get("entries", [])
if not entries:
return f"Directory '{path}' is empty or does not exist."
lines: list[str] = []
for entry in entries:
entry_type = entry.get("type", "unknown")
entry_name = entry.get("name", "")
entry_path = entry.get("path", "")
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
@tool
async def read_file_content(path: str) -> str:
"""Read the text content of a local file on the user's device.
Returns the file content as a string. Large files may be truncated
by the Electron client.
"""
result = await execute_on_client(
action="read_file_content",
data={"path": path},
)
content: str = result.get("content", "")
if not content:
return f"File '{path}' is empty or could not be read."
return content
@tool
async def get_file_metadata(path: str) -> str:
"""Get metadata for a local file: size, creation date, modification date, extension.
Returns a formatted summary of the file's metadata.
"""
result = await execute_on_client(
action="get_file_metadata",
data={"path": path},
)
size = result.get("size", "unknown")
created = result.get("createdAt", "unknown")
modified = result.get("modifiedAt", "unknown")
extension = result.get("extension", "unknown")
name = result.get("name", path)
return (
f"File: {name}\n"
f" Extension: {extension}\n"
f" Size: {size} bytes\n"
f" Created: {created}\n"
f" Modified: {modified}"
)
FILESYSTEM_TOOLS: list[Any] = [
list_directory,
read_file_content,
get_file_metadata,
]

View File

@@ -0,0 +1,110 @@
"""Note agent — Markdown note management.
Adapted for Batch Agent Service: import from app.ws_context and app.llm.
"""
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
from app.llm import embed
from app.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
@tool
async def list_notes(project_id: str = "") -> str:
"""List notes, optionally scoped to a project by project_id."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="notes",
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
return "No notes found."
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
@tool
async def get_note(note_id: str) -> str:
"""Fetch a single note by its UUID to read its full Markdown content."""
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
row = result.get("row")
if not row:
return f"Note {note_id} not found."
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
@tool
async def create_note(title: str, content: str, project_id: str = "") -> str:
"""Create a new note."""
result = await execute_on_client(
action="insert",
table="notes",
data={
"title": title,
"content": content,
"projectId": project_id or None,
},
)
row = result["row"]
vector = await embed(content)
await execute_on_client(
action="vector_upsert",
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note created: '{row['title']}' (id: {row['id']})."
@tool
async def update_note(note_id: str, title: str = "", content: str = "") -> str:
"""Update an existing note. Only pass fields that should change."""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if content:
updates["content"] = content
result = await execute_on_client(
action="update",
table="notes",
data={"id": note_id, "updates": updates},
)
row = result["row"]
if content:
vector = await embed(content)
await execute_on_client(
action="vector_upsert",
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note updated: '{row['title']}' (id: {row['id']})."
@tool
async def delete_note(note_id: str) -> str:
"""Delete a note permanently by its UUID."""
await execute_on_client(action="delete", table="notes", data={"id": note_id})
return f"Note {note_id} deleted."
NOTE_TOOLS: list[Any] = [
list_notes,
get_note,
create_note,
update_note,
delete_note,
]

View File

@@ -0,0 +1,110 @@
"""Project agent — full lifecycle management.
Adapted for Batch Agent Service: import from app.ws_context.
"""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.ws_context import execute_on_client
@tool
async def list_projects(client_id: str = "", include_archived: int = 0) -> str:
"""List projects, optionally filtered by client_id."""
result = await execute_on_client(
action="select",
table="projects",
filters={
"clientId": client_id or None,
"includeArchived": bool(include_archived),
},
)
rows = result.get("rows", [])
if not rows:
return "No projects found."
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
@tool
async def list_all_projects() -> str:
"""List every project regardless of client or status."""
result = await execute_on_client(action="select", table="projects")
rows = result.get("rows", [])
if not rows:
return "No projects found."
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
@tool
async def get_project(project_id: str) -> str:
"""Fetch a single project by its UUID."""
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
row = result.get("row")
if not row:
return f"Project {project_id} not found."
return (
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
f"clientId: {row.get('clientId', 'none')})"
)
@tool
async def create_project(name: str, client_id: str = "") -> str:
"""Create a new project."""
result = await execute_on_client(
action="insert",
table="projects",
data={"name": name, "clientId": client_id or None},
)
row = result["row"]
return f"Project created: '{row['name']}' (id: {row['id']})"
@tool
async def update_project(
project_id: str,
name: str = "",
client_id: str = "",
status: str = "",
ai_summary: str = "",
) -> str:
"""Update a project. Only pass fields that should change."""
updates: dict[str, Any] = {}
if name:
updates["name"] = name
if client_id:
updates["clientId"] = client_id
if status:
updates["status"] = status
if ai_summary:
updates["aiSummary"] = ai_summary
result = await execute_on_client(
action="update",
table="projects",
data={"id": project_id, "updates": updates},
)
row = result["row"]
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
@tool
async def delete_project(project_id: str) -> str:
"""Permanently delete a project."""
await execute_on_client(action="delete", table="projects", data={"id": project_id})
return f"Project {project_id} permanently deleted."
PROJECT_TOOLS: list[Any] = [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]

View File

@@ -0,0 +1,197 @@
"""Task agent — full CRUD for tasks and task comments.
Adapted for Batch Agent Service: import from app.ws_context.
"""
from __future__ import annotations
from datetime import datetime, timezone
import re
from typing import Any
from langchain_core.tools import tool
from app.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
@tool
async def list_tasks(
project_id: str = "",
status: str = "",
search: str = "",
order_by: str = "",
) -> str:
"""List tasks, optionally filtered by project_id, status, search, or order_by."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="tasks",
filters={
"projectId": normalized_project_id or None,
"status": status or None,
"search": search or None,
"orderBy": order_by or None,
},
)
rows = result.get("rows", [])
if not rows:
return "No tasks found matching the given filters."
lines = [
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
for r in rows
]
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
@tool
async def create_task(
title: str,
description: str = "",
status: str = "todo",
priority: str = "medium",
assignees: str = "[]",
due_date: int = 0,
project_id: str = "",
is_ai_suggested: int = 0,
) -> str:
"""Create a new task."""
result = await execute_on_client(
action="insert",
table="tasks",
data={
"title": title,
"description": description or None,
"status": status,
"priority": priority,
"assignee": assignees,
"dueDate": due_date or None,
"projectId": project_id or None,
"isAiSuggested": is_ai_suggested,
},
)
row = result["row"]
return (
f"Task created: '{row['title']}' "
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
)
@tool
async def update_task(
task_id: str,
title: str = "",
description: str = "",
status: str = "",
priority: str = "",
assignees: str = "",
due_date: int = -1,
project_id: str = "",
) -> str:
"""Update fields on an existing task. Only pass fields you want to change."""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if description:
updates["description"] = description
if status:
updates["status"] = status
if priority:
updates["priority"] = priority
if assignees:
updates["assignee"] = assignees
if due_date != -1:
updates["dueDate"] = due_date or None
if project_id:
updates["projectId"] = project_id
result = await execute_on_client(
action="update",
table="tasks",
data={"id": task_id, "updates": updates},
)
row = result["row"]
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
@tool
async def delete_task(task_id: str) -> str:
"""Delete a task permanently by its UUID."""
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
return f"Task {task_id} deleted."
@tool
async def list_tasks_due_today() -> str:
"""List all tasks whose due date falls on today's date."""
now = datetime.now(tz=timezone.utc)
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
end_ms = start_ms + 86_400_000 - 1
result = await execute_on_client(
action="select",
table="tasks",
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
)
rows = result.get("rows", [])
if not rows:
return "No tasks are due today."
lines = [
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
for r in rows
]
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
@tool
async def list_task_comments(task_id: str) -> str:
"""List all comments on a task by its UUID."""
result = await execute_on_client(
action="select",
table="taskComments",
filters={"taskId": task_id},
)
rows = result.get("rows", [])
if not rows:
return f"No comments found for task {task_id}."
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
@tool
async def add_task_comment(task_id: str, author: str, content: str) -> str:
"""Add a comment to a task."""
result = await execute_on_client(
action="insert",
table="taskComments",
data={"taskId": task_id, "author": author, "content": content},
)
row = result.get("row", {})
row_author = row.get("author", author)
row_task_id = row.get("taskId") or row.get("task_id") or task_id
row_comment_id = row.get("id", "unknown")
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
@tool
async def delete_task_comment(comment_id: str) -> str:
"""Delete a task comment by its UUID."""
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
return f"Comment {comment_id} deleted."
TASK_TOOLS: list[Any] = [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]

View File

@@ -0,0 +1,88 @@
"""Timeline agent — project milestone management.
Adapted for Batch Agent Service: import from app.ws_context.
"""
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
from app.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
@tool
async def list_timelines(project_id: str = "") -> str:
"""List timelines. Provide project_id to scope to a specific project."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="timelines",
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
return "No timelines found."
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
@tool
async def create_timeline(
project_id: str, title: str, date: int, is_ai_suggested: int = 0,
) -> str:
"""Create a project timeline (milestone)."""
result = await execute_on_client(
action="insert",
table="timelines",
data={
"projectId": project_id,
"title": title,
"date": date,
"isAiSuggested": is_ai_suggested,
},
)
row = result["row"]
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
@tool
async def update_timeline(timeline_id: str, title: str = "", date: int = -1) -> str:
"""Update a timeline. Only pass fields that should change."""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if date != -1:
updates["date"] = date
result = await execute_on_client(
action="update",
table="timelines",
data={"id": timeline_id, "updates": updates},
)
row = result["row"]
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
@tool
async def delete_timeline(timeline_id: str) -> str:
"""Delete a timeline permanently by its UUID."""
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
return f"Timeline {timeline_id} deleted."
TIMELINE_TOOLS: list[Any] = [
list_timelines,
create_timeline,
update_timeline,
delete_timeline,
]

View File

@@ -0,0 +1,108 @@
"""Cloud provider integration utilities.
Adapted for Batch Agent Service: import from shared.config instead of app.config.
Provides:
* Shared message dataclasses (EmailMessage, ChatMessage)
* get_provider() — factory for Gmail/MS Graph clients
* encrypt_token() / decrypt_token() — Fernet-based OAuth token encryption
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING
from cryptography.fernet import Fernet, InvalidToken
from shared.config import settings
if TYPE_CHECKING:
from app.integrations.gmail import GmailClient
from app.integrations.ms_graph import MSGraphClient
logger = logging.getLogger(__name__)
@dataclass
class EmailMessage:
id: str
subject: str
sender: str
body_text: str
date: datetime
labels: list[str] = field(default_factory=list)
@property
def as_text(self) -> str:
date_str = self.date.strftime("%Y-%m-%d %H:%M")
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
return (
f"From: {self.sender}\n"
f"Date: {date_str}{labels_str}\n"
f"Subject: {self.subject}\n\n"
f"{self.body_text}"
)
@dataclass
class ChatMessage:
id: str
content: str
sender: str
channel: str | None
date: datetime
@property
def as_text(self) -> str:
date_str = self.date.strftime("%Y-%m-%d %H:%M")
channel_str = f" [channel: {self.channel}]" if self.channel else ""
return (
f"From: {self.sender}\n"
f"Date: {date_str}{channel_str}\n\n"
f"{self.content}"
)
def _get_fernet() -> Fernet:
key = settings.OAUTH_ENCRYPTION_KEY
if not key:
raise RuntimeError(
"OAUTH_ENCRYPTION_KEY is not set. "
"Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
)
return Fernet(key.encode() if isinstance(key, str) else key)
def encrypt_token(token_info: dict) -> str:
if not isinstance(token_info, dict) or not token_info:
raise ValueError("token_info must be a non-empty dict")
plaintext = json.dumps(token_info).encode("utf-8")
return _get_fernet().encrypt(plaintext).decode("utf-8")
def decrypt_token(encrypted: str) -> dict:
try:
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
return json.loads(plaintext)
except (InvalidToken, json.JSONDecodeError) as exc:
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
def get_provider(
provider: str,
credentials_info: dict,
) -> "GmailClient | MSGraphClient":
if provider == "gmail":
from app.integrations.gmail import GmailClient
return GmailClient(credentials_info)
if provider in {"outlook", "teams"}:
from app.integrations.ms_graph import MSGraphClient
return MSGraphClient(credentials_info)
raise ValueError(
f"Unknown cloud provider {provider!r}. "
"Supported: 'gmail', 'outlook', 'teams'."
)

View File

@@ -0,0 +1,252 @@
"""Gmail API client for cloud agent integration.
Adapted for Batch Agent Service: import from app.integrations instead of
app.integrations (same relative path within the service).
"""
from __future__ import annotations
import asyncio
import base64
import email
import html
import logging
import re
from datetime import datetime, timezone
from typing import Any
from app.integrations import EmailMessage
logger = logging.getLogger(__name__)
_GMAIL_DATE_FMT = "%Y/%m/%d"
_BODY_TRUNCATE = 8_000
_MAX_MESSAGES = 200
def _build_gmail_query(
filter_config: dict[str, Any] | None,
since: datetime | None,
) -> str:
parts: list[str] = []
cfg = filter_config or {}
labels: list[str] = cfg.get("labels", [])
if labels:
if len(labels) == 1:
parts.append(f"label:{labels[0]}")
else:
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
parts.append(f"({label_expr})")
senders: list[str] = cfg.get("senders", [])
for sender in senders:
parts.append(f"from:{sender}")
date_range: dict = cfg.get("date_range", {})
from_str: str | None = date_range.get("from")
to_str: str | None = date_range.get("to")
effective_since: datetime | None = since
if from_str:
try:
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
if cfg_since.tzinfo is None:
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
if effective_since is None or cfg_since > effective_since:
effective_since = cfg_since
except ValueError:
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
if effective_since:
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
if to_str:
try:
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
except ValueError:
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
return " ".join(parts)
def _strip_html(raw_html: str) -> str:
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
decoded = html.unescape(no_tags)
return re.sub(r"\s+", " ", decoded).strip()
def _parse_body(payload: dict[str, Any]) -> str:
mime_type: str = payload.get("mimeType", "")
body: dict = payload.get("body", {})
parts: list[dict] = payload.get("parts", [])
if mime_type == "text/plain":
data = body.get("data", "")
if data:
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
return ""
if mime_type == "text/html":
data = body.get("data", "")
if data:
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
return _strip_html(raw)
return ""
plain_fallback = ""
for part in parts:
part_mime = part.get("mimeType", "")
if part_mime == "text/plain":
return _parse_body(part)
if part_mime == "text/html" and not plain_fallback:
plain_fallback = _parse_body(part)
if part_mime.startswith("multipart/"):
nested = _parse_body(part)
if nested:
return nested
return plain_fallback
def _parse_date(raw: str) -> datetime:
try:
parsed = email.utils.parsedate_to_datetime(raw)
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed.astimezone(timezone.utc)
except Exception:
return datetime.now(timezone.utc)
class GmailClient:
def __init__(self, credentials_info: dict[str, Any]) -> None:
from google.oauth2.credentials import Credentials
self._credentials_info = credentials_info
expiry_str: str | None = credentials_info.get("expiry")
expiry: datetime | None = None
if expiry_str:
try:
expiry = datetime.fromisoformat(
expiry_str.replace("Z", "+00:00")
).replace(tzinfo=timezone.utc)
except ValueError:
pass
self._credentials = Credentials(
token=credentials_info.get("token"),
refresh_token=credentials_info.get("refresh_token"),
token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
client_id=credentials_info.get("client_id"),
client_secret=credentials_info.get("client_secret"),
scopes=credentials_info.get("scopes"),
expiry=expiry,
)
async def fetch_messages(
self,
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[EmailMessage]:
query = _build_gmail_query(filter_config, since)
logger.debug("gmail: executing search query %r", query)
return await asyncio.to_thread(self._fetch_sync, query)
@property
def refreshed_credentials(self) -> dict[str, Any] | None:
creds = self._credentials
if not creds.valid and creds.expired:
return None
if creds.token != self._credentials_info.get("token"):
result = {
"token": creds.token,
"refresh_token": creds.refresh_token,
"token_uri": creds.token_uri,
"client_id": creds.client_id,
"client_secret": creds.client_secret,
"scopes": list(creds.scopes or []),
}
if creds.expiry:
result["expiry"] = creds.expiry.isoformat()
return result
return None
def _fetch_sync(self, query: str) -> list[EmailMessage]:
import googleapiclient.discovery
import googleapiclient.errors
from google.auth.transport.requests import Request
if self._credentials.expired and self._credentials.refresh_token:
try:
self._credentials.refresh(Request())
except Exception as exc:
raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc
service = googleapiclient.discovery.build(
"gmail", "v1", credentials=self._credentials, cache_discovery=False
)
user_api = service.users()
ids: list[str] = []
page_token: str | None = None
while len(ids) < _MAX_MESSAGES:
batch_size = min(100, _MAX_MESSAGES - len(ids))
kwargs: dict[str, Any] = {
"userId": "me",
"maxResults": batch_size,
}
if query:
kwargs["q"] = query
if page_token:
kwargs["pageToken"] = page_token
try:
resp = user_api.messages().list(**kwargs).execute()
except googleapiclient.errors.HttpError as exc:
raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc
for msg in resp.get("messages", []):
ids.append(msg["id"])
page_token = resp.get("nextPageToken")
if not page_token:
break
if not ids:
return []
logger.info("gmail: fetching %d message(s)", len(ids))
messages: list[EmailMessage] = []
for msg_id in ids:
try:
msg = user_api.messages().get(
userId="me", id=msg_id, format="full"
).execute()
headers: dict[str, str] = {
h["name"].lower(): h["value"]
for h in msg.get("payload", {}).get("headers", [])
}
subject = headers.get("subject", "(no subject)")
sender = headers.get("from", "unknown")
date_raw = headers.get("date", "")
date = _parse_date(date_raw) if date_raw else datetime.now(timezone.utc)
body_text = _parse_body(msg.get("payload", {}))[:_BODY_TRUNCATE]
labels = msg.get("labelIds", [])
messages.append(EmailMessage(
id=msg_id,
subject=subject,
sender=sender,
body_text=body_text,
date=date,
labels=labels,
))
except Exception as exc:
logger.warning("gmail: skipping message %s: %s", msg_id, exc)
logger.info("gmail: returned %d message(s)", len(messages))
return messages

View File

@@ -0,0 +1,266 @@
"""Microsoft Graph API client for Outlook and Teams.
Adapted for Batch Agent Service: import settings from shared.config.
"""
from __future__ import annotations
import logging
import re
from datetime import datetime, timezone
from typing import Any
import httpx
from shared.config import settings
from app.integrations import ChatMessage, EmailMessage
logger = logging.getLogger(__name__)
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
_MAX_EMAILS = 200
_MAX_MESSAGES = 200
_BODY_TRUNCATE = 8_000
def _strip_html(raw: str) -> str:
no_tags = re.sub(r"<[^>]+>", " ", raw)
import html as _html
decoded = _html.unescape(no_tags)
return re.sub(r"\s+", " ", decoded).strip()
def _odata_datetime(dt: datetime) -> str:
utc = dt.astimezone(timezone.utc)
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
def _build_email_filter(
filter_config: dict[str, Any] | None,
since: datetime | None,
) -> str:
clauses: list[str] = []
cfg = filter_config or {}
senders: list[str] = cfg.get("senders", [])
if senders:
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
clauses.append("(" + " or ".join(sender_clauses) + ")")
date_range: dict = cfg.get("date_range", {})
from_str: str | None = date_range.get("from")
effective_since: datetime | None = since
if from_str:
try:
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
if cfg_since.tzinfo is None:
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
if effective_since is None or cfg_since > effective_since:
effective_since = cfg_since
except ValueError:
logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
if effective_since:
clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")
to_str: str | None = date_range.get("to")
if to_str:
try:
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
if to_dt.tzinfo is None:
to_dt = to_dt.replace(tzinfo=timezone.utc)
clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")
except ValueError:
logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
return " and ".join(clauses)
class MSGraphClient:
def __init__(self, credentials_info: dict[str, Any]) -> None:
self._credentials_info = credentials_info
self._access_token: str = credentials_info.get("access_token", "")
self._original_access_token: str = self._access_token
self._refresh_token: str | None = credentials_info.get("refresh_token")
def _auth_headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self._access_token}"}
async def _refresh_access_token(self) -> None:
import msal
app = msal.ConfidentialClientApplication(
client_id=settings.MS_CLIENT_ID,
client_credential=settings.MS_CLIENT_SECRET,
authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}",
)
scopes: list[str] = self._credentials_info.get("scope", "").split()
if not scopes:
scopes = ["https://graph.microsoft.com/.default"]
result = app.acquire_token_by_refresh_token(
self._refresh_token,
scopes=scopes,
)
if "access_token" not in result:
error = result.get("error_description", result.get("error", "unknown"))
raise RuntimeError(f"MS Graph token refresh failed: {error}")
self._access_token = result["access_token"]
if "refresh_token" in result:
self._refresh_token = result["refresh_token"]
self._credentials_info["refresh_token"] = result["refresh_token"]
self._credentials_info["access_token"] = self._access_token
@property
def refreshed_credentials(self) -> dict[str, Any] | None:
if self._access_token != self._original_access_token:
return {**self._credentials_info, "access_token": self._access_token}
return None
async def _get(
self,
client: httpx.AsyncClient,
url: str,
params: dict[str, Any] | None = None,
*,
retry_on_401: bool = True,
) -> dict[str, Any]:
resp = await client.get(url, params=params, headers=self._auth_headers())
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
await self._refresh_access_token()
resp = await client.get(url, params=params, headers=self._auth_headers())
if resp.status_code == 429:
raise RuntimeError("MS Graph rate limit hit (429). Try again later.")
resp.raise_for_status()
return resp.json()
async def fetch_emails(
self,
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[EmailMessage]:
odata_filter = _build_email_filter(filter_config, since)
params: dict[str, Any] = {
"$top": 50,
"$select": "id,subject,from,receivedDateTime,body,bodyPreview",
"$orderby": "receivedDateTime desc",
}
if odata_filter:
params["$filter"] = odata_filter
emails: list[EmailMessage] = []
url = f"{_GRAPH_BASE}/me/messages"
async with httpx.AsyncClient(timeout=30.0) as client:
while url and len(emails) < _MAX_EMAILS:
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
for item in data.get("value", []):
emails.append(self._parse_email(item))
if len(emails) >= _MAX_EMAILS:
break
url = data.get("@odata.nextLink", "")
params = {}
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
return emails
async def fetch_messages(
self,
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[ChatMessage]:
cfg = filter_config or {}
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
params: dict[str, Any] = {"$top": 50}
if since:
params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}"
messages: list[ChatMessage] = []
url = f"{_GRAPH_BASE}/me/chats/getAllMessages"
async with httpx.AsyncClient(timeout=30.0) as client:
while url and len(messages) < _MAX_MESSAGES:
try:
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
except httpx.HTTPStatusError as exc:
if exc.response.status_code in (403, 404):
logger.warning(
"ms_graph: /me/chats/getAllMessages not available (%d)",
exc.response.status_code,
)
break
raise
for item in data.get("value", []):
msg = self._parse_teams_message(item)
if channel_filter and msg.channel:
if not any(c in msg.channel.lower() for c in channel_filter):
continue
messages.append(msg)
if len(messages) >= _MAX_MESSAGES:
break
url = data.get("@odata.nextLink", "")
params = {}
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
return messages
@staticmethod
def _parse_email(item: dict[str, Any]) -> EmailMessage:
subject: str = item.get("subject", "(no subject)") or "(no subject)"
sender_block = item.get("from", {}) or {}
sender_addr = (
(sender_block.get("emailAddress") or {}).get("address", "unknown")
)
date_str: str = item.get("receivedDateTime", "")
try:
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
except Exception:
date = datetime.now(timezone.utc)
body_block = item.get("body", {}) or {}
content_type: str = body_block.get("contentType", "text")
raw_body: str = body_block.get("content", "")
if content_type == "html":
body_text = _strip_html(raw_body)
else:
body_text = raw_body or item.get("bodyPreview", "")
body_text = body_text[:_BODY_TRUNCATE]
return EmailMessage(
id=item.get("id", ""),
subject=subject,
sender=sender_addr,
body_text=body_text,
date=date,
)
@staticmethod
def _parse_teams_message(item: dict[str, Any]) -> ChatMessage:
msg_id: str = item.get("id", "")
sender_block = (item.get("from") or {}).get("user") or {}
sender: str = sender_block.get("displayName", "unknown")
channel: str | None = (item.get("channelIdentity") or {}).get("channelId")
date_str: str = item.get("createdDateTime", "")
try:
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
except Exception:
date = datetime.now(timezone.utc)
body_block = item.get("body", {}) or {}
content_type: str = body_block.get("contentType", "text")
raw_content: str = body_block.get("content", "")
content = _strip_html(raw_content) if content_type == "html" else raw_content
content = content[:_BODY_TRUNCATE]
return ChatMessage(
id=msg_id,
content=content,
sender=sender,
channel=channel,
date=date,
)

View File

@@ -0,0 +1,385 @@
"""Chatbot Journey — guided conversation to build an agent prompt_template.
Adapted for Batch Agent Service: imports from app.agents.filesystem_agent
and app.llm instead of monolith paths. Session state is in-memory (could
be moved to Redis for horizontal scaling in the future).
Journey flow:
1. Redis consumer dispatches ``journey_start`` with basic agent config.
2. Server creates an in-memory session, runs the setup LLM with
file-system tools to explore the directory, returns first question.
3. ``journey_message`` frames drive the conversation.
4. After 3-5 turns the LLM emits PROMPT_TEMPLATE_START / _END block.
5. Server parses the block and returns ``journey_reply`` with ``done=True``.
"""
from __future__ import annotations
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
from app.llm import get_llm
logger = logging.getLogger(__name__)
# ── Session TTL ───────────────────────────────────────────────────────────
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
# Sentinel strings used to delimit the LLM-produced prompt_template.
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
_MIN_TURNS_BEFORE_NUDGE: int = 3
_MAX_TURNS: int = 15
_MAX_TOOL_STEPS: int = 6
# ── In-memory session store ───────────────────────────────────────────────
@dataclass
class JourneySession:
session_id: str
user_id: str
agent_type: str # "local" | "cloud"
directory: str
data_types: list[str]
history: list[dict[str, Any]] = field(default_factory=list)
system_prompt: str = ""
created_at: float = field(default_factory=time.monotonic)
def is_expired(self) -> bool:
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
# session_id → session
_sessions: dict[str, JourneySession] = {}
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
"""Retrieve session; return None on missing, expired, or wrong owner."""
s = _sessions.get(session_id)
if s is None or s.is_expired():
_sessions.pop(session_id, None)
return None
if s.user_id != user_id:
return None
return s
# ── System prompt builder ─────────────────────────────────────────────────
_SYSTEM_PROMPT_TEMPLATE = """\
You are a friendly assistant helping a freelancer configure a data-extraction agent.
Your job is to understand exactly what data the user wants to extract from their
local directory and produce a detailed prompt_template that a separate AI will use
as its instruction set.
The extraction agent already has this base behaviour built in:
- Reads each file using file-system tools.
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
- Sets isAiSuggested=1 on every new record.
- Only extracts data explicitly present in the files — it never invents information.
The user's custom prompt is appended AFTER this base behaviour, so focus on
what to look for and how to map it — not on the general extraction mechanics.
You have access to file-system tools to explore the user's directory:
- list_directory: to see folder structure
- read_file_content: to peek at file contents
- get_file_metadata: to check file info
The user's configured directory is: {directory}
Target data types: {data_types}
IMPORTANT — project assignment is handled automatically by the main agent runner
before the custom prompt is ever used. You MUST NOT ask the user about projects,
projectId, or how to link records to projects. Never include projectId logic or
project creation instructions in the generated prompt_template.
Start by exploring the directory to understand its structure. Then ask concise,
focused questions one at a time. Cover these topics (not necessarily in this order):
1. The type and format of the source content (confirmed by your exploration).
2. How fields should be mapped (e.g. filename → task title).
3. Priority or status rules (e.g. "urgent" keyword → high priority).
4. Any special handling, date extraction, or exclusions.
Once you reach 90% confidence, output the final prompt_template between these exact
markers on their own lines:
{template_start}
<the complete extraction prompt here>
{template_end}
The prompt_template must be a self-contained instruction for an AI that reads files
and must perform CRUD operations using tools to create records. It should specify:
- What entity types to create (tasks, notes, timelines) — never projects.
- How to map file content to record fields (camelCase: title, status, priority,
dueDate, content, etc.) — never include projectId.
- That isAiSuggested must be set to 1 on every new record.
- Concrete examples of mappings based on what you discovered in the directory.
{existing_section}\
Keep asking clarifying questions until you are at least 90% confident you have
enough information to generate an accurate prompt_template. Once you reach that
confidence level, stop asking and produce the final template immediately.
Begin by exploring the directory, then ask your first question.\
"""
def _build_system_prompt(
directory: str,
data_types: list[str],
existing_template: str | None = None,
) -> str:
existing_section = (
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
f"---\n{existing_template}\n---\n"
if existing_template
else ""
)
return _SYSTEM_PROMPT_TEMPLATE.format(
directory=directory,
data_types=", ".join(data_types),
template_start=_TEMPLATE_START,
template_end=_TEMPLATE_END,
existing_section=existing_section,
)
# ── Template extraction ───────────────────────────────────────────────────
def _extract_template(text: str) -> str | None:
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
return None
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
end_idx = text.index(_TEMPLATE_END)
return text[start_idx:end_idx].strip() or None
# ── LLM call with tool support ───────────────────────────────────────────
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
async def _call_llm_with_tools(
system_prompt: str,
history: list[dict[str, Any]],
tools: list[Any],
) -> str:
"""Build LangChain messages from history and invoke the LLM with tools.
Handles tool-calling loops: if the LLM calls tools, execute them and
continue until a final text response is produced.
"""
messages: list[Any] = [SystemMessage(content=system_prompt)]
for turn in history:
if turn["role"] == "user":
messages.append(HumanMessage(content=turn["content"]))
else:
messages.append(AIMessage(content=turn["content"]))
llm = get_llm(model=None, temperature=0.4)
llm_with_tools = llm.bind_tools(tools)
tool_map = {tool_def.name: tool_def for tool_def in tools}
for _ in range(_MAX_TOOL_STEPS):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return _as_text(response.content)
for call in response.tool_calls:
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"journey: tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:500],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"journey: tool_result name=%s output=%s",
call_name,
str(tool_output)[:800],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
# Fallback: exceeded max tool steps.
final = await llm.ainvoke(messages)
return _as_text(final.content)
# ── Journey handlers (called from redis_consumer) ────────────────────────
async def handle_journey_start(
user_id: str,
frame: dict[str, Any],
) -> dict[str, Any]:
"""Handle a ``journey_start`` request.
Creates a session, runs the setup LLM with directory exploration,
and returns the ``journey_reply`` payload.
"""
agent_type = frame.get("agent_type", "local")
directory = frame.get("directory", "")
data_types = frame.get("data_types", [])
existing_template = frame.get("existing_template")
session_id = frame.get("session_id") or str(uuid.uuid4())
system_prompt = _build_system_prompt(directory, data_types, existing_template)
session = JourneySession(
session_id=session_id,
user_id=user_id,
agent_type=agent_type,
directory=directory,
data_types=data_types,
system_prompt=system_prompt,
)
seed_history: list[dict[str, Any]] = [
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
]
ai_reply = await _call_llm_with_tools(
system_prompt=system_prompt,
history=seed_history,
tools=list(FILESYSTEM_TOOLS),
)
session.history.extend(seed_history)
session.history.append({"role": "assistant", "content": ai_reply})
_sessions[session_id] = session
logger.info(
"journey: session %s started for user %s (directory=%s)",
session_id,
user_id,
directory,
)
prompt_template = _extract_template(ai_reply)
done = prompt_template is not None
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
or "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"prompt_template": prompt_template,
}
async def handle_journey_message(
user_id: str,
frame: dict[str, Any],
) -> dict[str, Any]:
"""Handle a ``journey_message`` request.
Appends the user message, calls the LLM, and returns the
``journey_reply`` payload.
"""
session_id = frame.get("session_id", "")
message = frame.get("message", "")
session = get_journey_session(session_id, user_id)
if session is None:
return {
"type": "journey_reply",
"session_id": session_id,
"message": "Journey session not found or expired. Please start a new setup.",
"done": True,
"prompt_template": None,
}
session.history.append({"role": "user", "content": message})
ai_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
)
session.history.append({"role": "assistant", "content": ai_reply})
prompt_template = _extract_template(ai_reply)
done = prompt_template is not None
if not done:
turns = sum(1 for t in session.history if t["role"] == "user")
if turns >= _MAX_TURNS:
nudge_content = (
"[System: You have enough information. Please generate the final "
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
)
session.history.append({"role": "user", "content": nudge_content})
nudge_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
)
session.history.append({"role": "assistant", "content": nudge_reply})
prompt_template = _extract_template(nudge_reply)
if prompt_template is not None:
done = True
ai_reply = nudge_reply
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
if _TEMPLATE_START in ai_reply
else "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
logger.info("journey: session %s completed for user %s", session_id, user_id)
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"prompt_template": prompt_template,
}

View File

@@ -0,0 +1,76 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Identical to services/chat/app/llm.py. Uses shared.config.settings.
"""
from __future__ import annotations
import os
import warnings
from openai import AsyncOpenAI
import litellm
from langchain_openai import ChatOpenAI
from langchain_litellm import ChatLiteLLM
from shared.config import settings
litellm.drop_params = True
warnings.filterwarnings(
"ignore",
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None:
if model.startswith("anthropic/"):
return settings.ANTHROPIC_API_KEY or None
if model.startswith("gemini/") or model.startswith("google/"):
return settings.GOOGLE_API_KEY or None
if model.startswith("cerebras/"):
return settings.CEREBRAS_API_KEY or None
if model.startswith("github_copilot/"):
return None
return settings.OPENAI_API_KEY or None
def get_llm(
*,
model: str | None = None,
temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM:
model = model or settings.LLM_MODEL
if settings.GITHUB_COPILOT_TOKEN_DIR:
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
if "/" in model:
return ChatLiteLLM(model=model, temperature=temperature)
return ChatOpenAI(
model=model,
temperature=temperature,
api_key=_api_key_for_model(model),
)
def get_router_llm(
*,
temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM:
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
async def embed(text: str) -> list[float]:
model = settings.LLM_EMBED_MODEL
if model.startswith("github_copilot/") or "/" in model:
response = await litellm.aembedding(model=model, input=[text])
return response.data[0]["embedding"]
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
response = await client.embeddings.create(model=model, input=text)
return response.data[0].embedding

View File

@@ -0,0 +1,57 @@
"""Batch Agent Service — FastAPI application.
Owns: agent_runner (local directory + cloud connectors), journey builder,
filesystem_agent, integrations (Gmail, MS Graph).
Communicates with WS Gateway via Redis:
- Subscribes to batch:request:{user_id} (journey_start, journey_message)
- Publishes to ws:out:{user_id} (journey replies + tool calls)
- BRPOP on tool:result:{call_id} (tool-call round-trip, 30s timeout)
- SET+EX on journey:{user_id} (journey session state, TTL 1800s)
"""
from __future__ import annotations
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.redis_consumer import start_consumer
from app.routes import router
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
logger.info("batch-agent: starting Redis consumer")
task = asyncio.create_task(start_consumer())
yield
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
logger.info("batch-agent: Redis consumer stopped")
app = FastAPI(title="Adiuva Batch Agent Service", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
app.include_router(router)
@app.get("/health")
async def health() -> dict[str, str]:
return {"status": "ok", "service": "batch-agent"}

View File

@@ -0,0 +1,141 @@
"""Redis consumer for the Batch Agent Service.
Subscribes to batch:request:* (pattern) and dispatches:
- journey_start → handle_journey_start
- journey_message → handle_journey_message
- agent_trigger → run_local_agent / run_cloud_agent
Results are published back to ws:out:{user_id} via Redis.
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any
from shared.redis import redis_client, batch_request_channel, ws_out_channel
from app.ws_context import set_current_user, clear_current_user
logger = logging.getLogger(__name__)
async def _publish_to_user(user_id: str, payload: dict[str, Any]) -> None:
"""Publish a frame to the user's WS outbound channel."""
channel = ws_out_channel(user_id)
await redis_client.publish(channel, json.dumps(payload))
async def _handle_journey_start(user_id: str, data: dict[str, Any]) -> None:
"""Handle a journey_start request from WS Gateway."""
from app.journey import handle_journey_start
set_current_user(user_id)
try:
reply = await handle_journey_start(user_id, data)
await _publish_to_user(user_id, reply)
except Exception as exc:
logger.error("batch-agent: journey_start failed user=%s: %s", user_id, exc)
await _publish_to_user(user_id, {
"type": "journey_reply",
"session_id": data.get("session_id", ""),
"message": f"Journey setup failed: {exc}",
"done": True,
"prompt_template": None,
})
finally:
clear_current_user()
async def _handle_journey_message(user_id: str, data: dict[str, Any]) -> None:
"""Handle a journey_message from WS Gateway."""
from app.journey import handle_journey_message
set_current_user(user_id)
try:
reply = await handle_journey_message(user_id, data)
await _publish_to_user(user_id, reply)
except Exception as exc:
logger.error("batch-agent: journey_message failed user=%s: %s", user_id, exc)
await _publish_to_user(user_id, {
"type": "journey_reply",
"session_id": data.get("session_id", ""),
"message": f"Journey processing failed: {exc}",
"done": True,
"prompt_template": None,
})
finally:
clear_current_user()
async def _handle_agent_trigger(user_id: str, data: dict[str, Any]) -> None:
"""Handle an agent_trigger request from the REST route (forwarded via Redis)."""
from app.agent_runner import run_local_agent
set_current_user(user_id)
try:
await run_local_agent(user_id, data)
except Exception as exc:
logger.error("batch-agent: agent_trigger failed user=%s: %s", user_id, exc)
await _publish_to_user(user_id, {
"type": "run_complete",
"status": "error",
"run_context": data.get("run_context", {}),
})
finally:
clear_current_user()
async def _dispatch(user_id: str, message_data: dict[str, Any]) -> None:
"""Route a batch request to the correct handler."""
msg_type = message_data.get("type", "")
if msg_type == "journey_start":
await _handle_journey_start(user_id, message_data)
elif msg_type == "journey_message":
await _handle_journey_message(user_id, message_data)
elif msg_type == "agent_trigger":
await _handle_agent_trigger(user_id, message_data)
else:
logger.warning("batch-agent: unknown message type %r from user=%s", msg_type, user_id)
async def start_consumer() -> None:
"""Subscribe to batch:request:* and dispatch incoming frames."""
pubsub = redis_client.pubsub()
await pubsub.psubscribe("batch:request:*")
logger.info("batch-agent: subscribed to batch:request:*")
try:
async for message in pubsub.listen():
if message["type"] != "pmessage":
continue
channel: str = message["channel"]
if isinstance(channel, bytes):
channel = channel.decode()
# Extract user_id from channel: batch:request:{user_id}
parts = channel.split(":", 2)
if len(parts) < 3:
continue
user_id = parts[2]
raw = message["data"]
if isinstance(raw, bytes):
raw = raw.decode()
try:
data = json.loads(raw)
except json.JSONDecodeError:
logger.warning("batch-agent: invalid JSON on channel %s", channel)
continue
# Dispatch in a separate task to avoid blocking the consumer
asyncio.create_task(_dispatch(user_id, data))
except asyncio.CancelledError:
logger.info("batch-agent: consumer shutting down")
finally:
await pubsub.punsubscribe("batch:request:*")

View File

@@ -0,0 +1,208 @@
"""Agent REST routes — catalog, billing checks, trigger.
Adapted for Batch Agent Service: uses shared.db, shared.models, shared.schemas.
Agent trigger dispatches via Redis to the consumer instead of spawning
an in-process background task.
"""
from __future__ import annotations
import json
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Header, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from shared.db import async_session
from shared.models import AgentRunLog
from shared.redis import redis_client, batch_request_channel
from app.agent_runner import is_agent_running
router = APIRouter(prefix="/agents", tags=["agents"])
# ── Tier feature limits ───────────────────────────────────────────────
# Mirrors app/billing/tier_manager.py FEATURES dict.
FEATURES: dict[str, dict] = {
"free": {"batch_active": 1, "batch_runs_per_day": 3},
"pro": {"batch_active": 5, "batch_runs_per_day": 20},
"power": {"batch_active": 20, "batch_runs_per_day": 100},
"team": {"batch_active": -1, "batch_runs_per_day": -1},
}
def _dt_ms(dt: datetime) -> int:
return int(dt.timestamp() * 1000)
def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None
def _to_data_types(values: list[str]) -> list[str]:
normalize = {
"task": "tasks", "tasks": "tasks",
"note": "notes", "notes": "notes",
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
"project": "projects", "projects": "projects",
}
seen: set[str] = set()
result: list[str] = []
for v in values:
mapped = normalize.get(v)
if mapped and mapped not in seen:
seen.add(mapped)
result.append(mapped)
return result
def _enforce_agent_limit(tier: str, current_count: int) -> int:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
)
return limit
async def _enforce_run_frequency(tier: str, user_id: str) -> None:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
if limit == -1:
return
today_start = datetime.now(timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0
)
async with async_session() as db:
result = await db.execute(
select(func.count(AgentRunLog.id)).where(
AgentRunLog.user_id == user_id,
AgentRunLog.started_at >= today_start,
)
)
runs_today: int = result.scalar_one()
if runs_today >= limit:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Daily batch run limit ({limit}) reached for your tier.",
)
# ── Catalog ───────────────────────────────────────────────────────────
@router.get("/catalog")
async def get_agent_catalog(
x_user_id: str = Header(..., alias="X-User-Id"),
) -> list[dict]:
return [
{
"type": "local_directory",
"name": "Local Directory Monitor",
"description": "Watches local directories, extracts data from files using AI",
},
{
"type": "gmail",
"name": "Gmail Connector",
"description": "Scans Gmail inbox, extracts tasks/notes from emails",
},
{
"type": "teams",
"name": "Microsoft Teams Connector",
"description": "Monitors Teams messages, extracts action items",
},
{
"type": "outlook",
"name": "Outlook Connector",
"description": "Scans Outlook inbox, extracts tasks/notes",
},
]
# ── Can-create check ─────────────────────────────────────────────────
@router.post("/can-create")
async def can_create_agent(
body: dict,
x_user_id: str = Header(..., alias="X-User-Id"),
x_user_tier: str = Header("free", alias="X-User-Tier"),
) -> dict:
active_agents = body.get("active_agents", 0)
limit: int = FEATURES.get(x_user_tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or active_agents < limit
return {
"allowed": allowed,
"tier": x_user_tier,
"active_agents": active_agents,
"limit": limit,
}
# ── Trigger ──────────────────────────────────────────────────────────
@router.post("/trigger", status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
body: dict,
x_user_id: str = Header(..., alias="X-User-Id"),
x_user_tier: str = Header("free", alias="X-User-Tier"),
) -> dict:
"""Trigger a local agent run — creates run log and dispatches via Redis."""
active_agents = body.get("active_agents", 0)
_enforce_agent_limit(x_user_tier, active_agents)
await _enforce_run_frequency(x_user_tier, x_user_id)
stable_agent_id = body.get("agent_id") or str(uuid.uuid4())
if is_agent_running(stable_agent_id):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Agent is already running.",
)
# Create run log in DB
async with async_session() as db:
run_log = AgentRunLog(
agent_id=stable_agent_id,
agent_type="local",
user_id=x_user_id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_log_id = run_log.id
run_context = {
"type": "agent_batch",
"run_id": run_log_id,
"agent_id": stable_agent_id,
}
# Dispatch to the Redis consumer for processing
trigger_data = {
"type": "agent_trigger",
"directory": body.get("directory", ""),
"directory_paths": [body.get("directory", "")] if body.get("directory") else [],
"data_types": _to_data_types(body.get("what_to_extract", [])),
"file_extensions": body.get("file_extensions", []),
"prompt_template": body.get("custom_agent_prompt", ""),
"device_id": body.get("device_id", ""),
"run_context": run_context,
}
channel = batch_request_channel(x_user_id)
await redis_client.publish(channel, json.dumps(trigger_data))
return {
"id": run_log_id,
"agent_id": stable_agent_id,
"agent_type": "local",
"status": "running",
"items_processed": 0,
"items_created": 0,
"errors": [],
"started_at": _dt_ms(run_log.started_at),
"completed_at": None,
}

View File

@@ -0,0 +1,135 @@
"""WebSocket context for Batch Agent Service — Redis-based tool call round-trip.
Same pattern as services/chat/app/ws_context.py: publishes tool_call frames
to Redis ws:out:{user_id} and awaits BRPOP on tool:result:{call_id}.
Additionally provides set_client_executor / clear_client_executor stubs
for backward compatibility with the agent_runner code (which originally
used a DeviceConnectionManager callback). In the microservice world these
are no-ops — execute_on_client() always uses the Redis path.
"""
from __future__ import annotations
import json
import logging
from contextvars import ContextVar
from typing import Any, Callable, Coroutine
from uuid import uuid4
from shared.redis import redis_client, tool_result_key, ws_out_channel
logger = logging.getLogger(__name__)
_TOOL_CALL_TIMEOUT = 30 # seconds — BRPOP timeout
# Per-request user_id context var (set before agent run)
_current_user_id: ContextVar[str | None] = ContextVar("_current_user_id", default=None)
# Optional collector for debug / logging
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
"_tool_result_collector", default=None
)
def set_current_user(user_id: str) -> None:
_current_user_id.set(user_id)
def clear_current_user() -> None:
_current_user_id.set(None)
def set_tool_result_collector(lst: list[dict]) -> None:
_tool_result_collector.set(lst)
def clear_tool_result_collector() -> None:
_tool_result_collector.set(None)
# ── Compatibility shims ──────────────────────────────────────────────────
# agent_runner.py originally called set_client_executor / clear_client_executor
# with a DeviceConnectionManager callback. In the microservice world the
# Redis-based execute_on_client replaces this, so these are no-ops that
# keep the agent_runner code unchanged.
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]] | None) -> None:
"""No-op — kept for agent_runner compatibility."""
pass
def clear_client_executor() -> None:
"""No-op — kept for agent_runner compatibility."""
pass
async def execute_on_client(
action: str,
table: str | None = None,
data: dict[str, Any] | None = None,
filters: dict[str, Any] | None = None,
vector: list[float] | None = None,
limit: int | None = None,
) -> dict[str, Any]:
"""Send a tool_call to Electron via Redis and await the result.
1. Build tool_call payload
2. Publish to ws:out:{user_id} (WS Gateway forwards to Electron)
3. BRPOP on tool:result:{call_id} (WS Gateway pushes when Electron replies)
4. Return result dict
Raises RuntimeError if no user_id is set or if the call times out.
"""
user_id = _current_user_id.get()
if not user_id:
raise RuntimeError(
"execute_on_client() called without a user_id — "
"set_current_user() must be called first."
)
call_id = str(uuid4())
payload: dict[str, Any] = {
"type": "tool_call",
"id": call_id,
"action": action,
}
if table is not None:
payload["table"] = table
if data is not None:
payload["data"] = data
if filters is not None:
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
if vector is not None:
payload["vector"] = vector
if limit is not None:
payload["limit"] = limit
# Publish tool_call to WS Gateway → Electron
channel = ws_out_channel(user_id)
await redis_client.publish(channel, json.dumps(payload))
# Wait for Electron's tool_result
result_key = tool_result_key(call_id)
response = await redis_client.brpop(result_key, timeout=_TOOL_CALL_TIMEOUT)
if response is None:
raise RuntimeError(
f"Tool call {call_id} timed out after {_TOOL_CALL_TIMEOUT}s — "
f"device may be offline or unresponsive."
)
# response is (key, value) tuple
_, raw = response
result = json.loads(raw)
# Collect for debug if requested
collector = _tool_result_collector.get(None)
if collector is not None:
collector.append({
"action": action,
"table": table,
"data": result,
})
return result

View File

@@ -0,0 +1,20 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
sqlalchemy>=2.0.0
asyncpg>=0.30.0
redis>=5.0.0
cryptography>=42.0.0
python-dotenv>=1.0.0
langchain-core>=0.3.0
langchain-openai>=0.3.0
langchain-litellm>=0.3.0
litellm>=1.50.0
openai>=1.50.0
httpx>=0.27.0
croniter>=2.0.0
google-api-python-client>=2.130.0
google-auth>=2.30.0
msal>=1.28.0