feat(batch-agent): extract Batch Agent Service (Step 3)
- agent_runner: local directory + cloud agent orchestration via Redis - 5 domain agents: filesystem, task, note, project, timeline - integrations: Gmail, MS Graph (Outlook + Teams) - journey: guided chatbot conversation to build prompt_template - routes: REST endpoints (catalog, can-create, trigger) - redis_consumer: subscribes to batch:request:* pattern - ws_context: Redis-based execute_on_client for tool round-trip - Dockerfile with 300s timeout for long-running batch jobs
This commit is contained in:
36
services/batch-agent/Dockerfile
Normal file
36
services/batch-agent/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# ── builder ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY services/batch-agent/requirements.txt ./requirements.txt
|
||||||
|
RUN pip install --upgrade pip && \
|
||||||
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
|
# Shared module
|
||||||
|
COPY shared/ shared/
|
||||||
|
|
||||||
|
# Service source
|
||||||
|
COPY services/batch-agent/app/ app/
|
||||||
|
|
||||||
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Batch runs are long-lived — use a longer timeout than chat (300s vs 120s)
|
||||||
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "2", \
|
||||||
|
"--timeout", "300"]
|
||||||
884
services/batch-agent/app/agent_runner.py
Normal file
884
services/batch-agent/app/agent_runner.py
Normal file
@@ -0,0 +1,884 @@
|
|||||||
|
"""Agent run orchestrator — adapted for Batch Agent Service.
|
||||||
|
|
||||||
|
Key changes from monolith app/core/agent_runner.py:
|
||||||
|
- No DeviceConnectionManager — tool calls go through Redis ws_context.
|
||||||
|
- set_current_user / clear_current_user replace set_client_executor.
|
||||||
|
- run_local_agent accepts a serialized dict (from Redis / REST) instead
|
||||||
|
of SQLAlchemy model objects.
|
||||||
|
- _finalize_run writes to PostgreSQL via shared.db.async_session.
|
||||||
|
- Cloud agent import path changed to app.integrations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||||
|
from app.agents.note_agent import NOTE_TOOLS
|
||||||
|
from app.agents.project_agent import PROJECT_TOOLS
|
||||||
|
from app.agents.task_agent import TASK_TOOLS
|
||||||
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
|
from app.llm import get_llm
|
||||||
|
from app.ws_context import execute_on_client, set_current_user, clear_current_user
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
|
from shared.redis import redis_client, ws_out_channel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Concurrency guard ─────────────────────────────────────────────────────
|
||||||
|
_running_agents: set[str] = set()
|
||||||
|
|
||||||
|
|
||||||
|
def is_agent_running(agent_id: str) -> bool:
|
||||||
|
return agent_id in _running_agents
|
||||||
|
|
||||||
|
|
||||||
|
# ── Timeouts ───────────────────────────────────────────────────────────────
|
||||||
|
_TOOL_CALL_TIMEOUT: int = 30
|
||||||
|
_MAX_PROCESSING_STEPS: int = 12
|
||||||
|
_MAX_SCAN_DEPTH: int = 5
|
||||||
|
|
||||||
|
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
||||||
|
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
|
||||||
|
"tasks": TASK_TOOLS,
|
||||||
|
"notes": NOTE_TOOLS,
|
||||||
|
"timelines": TIMELINE_TOOLS,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Step 1: Classification prompt ─────────────────────────────────────────
|
||||||
|
|
||||||
|
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
|
||||||
|
"tasks": (
|
||||||
|
"Action items, to-dos, deliverables — anything that describes work to be done, "
|
||||||
|
"assigned to someone, or tracked with a due date or status."
|
||||||
|
),
|
||||||
|
"notes": (
|
||||||
|
"Documentation, meeting notes, summaries, reference material — "
|
||||||
|
"written content meant to be read and referenced rather than acted on."
|
||||||
|
),
|
||||||
|
"timelines": (
|
||||||
|
"Project milestones, deadlines, scheduled events — "
|
||||||
|
"specific dates that mark a point in the progress of a project."
|
||||||
|
),
|
||||||
|
"projects": (
|
||||||
|
"High-level project entities — only relevant if the file clearly introduces "
|
||||||
|
"a new project or updates the scope of an existing one."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
_STEP1_SYSTEM_PROMPT = """\
|
||||||
|
You are a file classifier for a freelance project management tool.
|
||||||
|
|
||||||
|
Your job is to match a file to an existing project and identify which data domains to extract.
|
||||||
|
|
||||||
|
## Project matching rules (STRICT — follow in order)
|
||||||
|
|
||||||
|
1. Search the file content for any mention of a project name, client name, acronym, or topic
|
||||||
|
that overlaps with the existing projects listed below.
|
||||||
|
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
|
||||||
|
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
|
||||||
|
when the file has zero meaningful connection to any listed project.
|
||||||
|
4. When in doubt, pick the closest match from the list.
|
||||||
|
|
||||||
|
## Response format
|
||||||
|
|
||||||
|
Respond ONLY with a JSON object — no markdown, no explanation:
|
||||||
|
|
||||||
|
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
|
||||||
|
|
||||||
|
## Domain definitions (only consider domains in the allowed list)
|
||||||
|
|
||||||
|
{domain_definitions}
|
||||||
|
|
||||||
|
## Existing projects
|
||||||
|
|
||||||
|
{projects_list}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Step 2: Processing prompt ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
_PROCESSING_SYSTEM_PROMPT = """\
|
||||||
|
You are a data extraction assistant for a freelance project management tool.
|
||||||
|
|
||||||
|
Your task: extract structured data from the file content and persist it using the available tools.
|
||||||
|
|
||||||
|
## Mandatory process — follow this order for EVERY item you extract
|
||||||
|
|
||||||
|
1. READ the existing records listed below for the relevant domain.
|
||||||
|
2. SEARCH for a match by title, topic, or semantic similarity.
|
||||||
|
3. If a match exists → call the update_* tool with the existing record's id.
|
||||||
|
4. If no match exists → call the create_* tool and set isAiSuggested=1.
|
||||||
|
|
||||||
|
NEVER call create_* without first checking the existing records.
|
||||||
|
NEVER duplicate a record that already exists under a different wording.
|
||||||
|
|
||||||
|
## Existing records (source of truth)
|
||||||
|
|
||||||
|
{existing_context}
|
||||||
|
|
||||||
|
## Context
|
||||||
|
|
||||||
|
Project: {project_context}
|
||||||
|
Domains to extract: {data_types}
|
||||||
|
|
||||||
|
{custom_prompt_section}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Cloud processing prompt ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
_CLOUD_PROCESSING_PROMPT = """\
|
||||||
|
You are a data extraction and management assistant for a freelance project
|
||||||
|
management tool.
|
||||||
|
|
||||||
|
Available tools:
|
||||||
|
Filesystem : read_file_content, list_directory, get_file_metadata
|
||||||
|
Tasks : list_tasks, create_task, update_task, add_task_comment
|
||||||
|
Notes : list_notes, get_note, create_note, update_note
|
||||||
|
Timelines : list_timelines, create_timeline, update_timeline
|
||||||
|
Projects : list_all_projects, get_project, create_project, update_project
|
||||||
|
|
||||||
|
Your task:
|
||||||
|
1. Read the full content of each file below using read_file_content.
|
||||||
|
2. For each piece of information found, ALWAYS try to match and update an
|
||||||
|
existing record before creating a new one.
|
||||||
|
3. ONLY act on these entity types: {data_types}.
|
||||||
|
4. Do NOT invent data. Only extract what is clearly present in the files.
|
||||||
|
5. If a file contains no relevant data for the target entity types, skip it.
|
||||||
|
|
||||||
|
{project_context}
|
||||||
|
|
||||||
|
Files to process:
|
||||||
|
{file_list}
|
||||||
|
|
||||||
|
{custom_prompt_section}
|
||||||
|
|
||||||
|
After processing all files, respond with a brief summary of what you updated
|
||||||
|
and what you created.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM tool-calling loop ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _as_text(content: Any) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_agent_with_tools(
|
||||||
|
*,
|
||||||
|
system_prompt: str,
|
||||||
|
user_message: str,
|
||||||
|
tools: list[Any],
|
||||||
|
max_steps: int,
|
||||||
|
) -> str:
|
||||||
|
"""Run an LLM agent with tool-calling, returning the final text response."""
|
||||||
|
llm = get_llm()
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(content=user_message),
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
for _ in range(max_steps):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:200],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tool list builder ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _build_processing_tools(data_types: list[str]) -> list[Any]:
|
||||||
|
tools: list[Any] = list(FILESYSTEM_TOOLS)
|
||||||
|
for dt in data_types:
|
||||||
|
dt_tools = _DATA_TYPE_TOOLS.get(dt)
|
||||||
|
if dt_tools:
|
||||||
|
tools.extend(dt_tools)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
# ── Code-based directory scanner ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _scan_directories(
|
||||||
|
paths: list[str],
|
||||||
|
extensions: list[str],
|
||||||
|
last_run_at: datetime | None,
|
||||||
|
) -> list[str]:
|
||||||
|
all_files: list[str] = []
|
||||||
|
ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set()
|
||||||
|
|
||||||
|
async def _walk(path: str, depth: int) -> None:
|
||||||
|
if depth > _MAX_SCAN_DEPTH:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(action="list_directory", data={"path": path})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
|
||||||
|
return
|
||||||
|
for entry in result.get("entries", []):
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
if not entry_path:
|
||||||
|
continue
|
||||||
|
if entry.get("type") == "directory":
|
||||||
|
await _walk(entry_path, depth + 1)
|
||||||
|
elif entry.get("type") == "file":
|
||||||
|
if ext_set:
|
||||||
|
dot_pos = entry_path.rfind(".")
|
||||||
|
file_ext = entry_path[dot_pos + 1:].lower() if dot_pos != -1 else ""
|
||||||
|
if file_ext not in ext_set:
|
||||||
|
continue
|
||||||
|
all_files.append(entry_path)
|
||||||
|
|
||||||
|
for root in paths:
|
||||||
|
await _walk(root, depth=0)
|
||||||
|
|
||||||
|
if last_run_at is None:
|
||||||
|
return all_files
|
||||||
|
|
||||||
|
last_run_ms = int(last_run_at.timestamp() * 1000)
|
||||||
|
filtered: list[str] = []
|
||||||
|
for file_path in all_files:
|
||||||
|
try:
|
||||||
|
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
|
||||||
|
modified_at = meta.get("modifiedAt")
|
||||||
|
if modified_at is None:
|
||||||
|
filtered.append(file_path)
|
||||||
|
continue
|
||||||
|
if isinstance(modified_at, (int, float)):
|
||||||
|
mod_ms = int(modified_at)
|
||||||
|
else:
|
||||||
|
mod_ms = int(datetime.fromisoformat(str(modified_at)).timestamp() * 1000)
|
||||||
|
if mod_ms > last_run_ms:
|
||||||
|
filtered.append(file_path)
|
||||||
|
except Exception:
|
||||||
|
filtered.append(file_path)
|
||||||
|
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
# ── Code-based entity fetchers ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_projects() -> list[dict]:
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(action="select", table="projects")
|
||||||
|
return result.get("rows", [])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to fetch projects: %s", exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
_DOMAIN_TABLE: dict[str, str] = {
|
||||||
|
"tasks": "tasks",
|
||||||
|
"notes": "notes",
|
||||||
|
"timelines": "timelines",
|
||||||
|
"projects": "projects",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
|
||||||
|
table = _DOMAIN_TABLE.get(domain)
|
||||||
|
if not table:
|
||||||
|
return []
|
||||||
|
filters: dict[str, Any] = {}
|
||||||
|
if project_id != "standalone" and domain != "projects":
|
||||||
|
filters["projectId"] = project_id
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table=table,
|
||||||
|
filters=filters if filters else None,
|
||||||
|
)
|
||||||
|
return result.get("rows", [])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
|
||||||
|
if not rows:
|
||||||
|
return f"No existing {domain}."
|
||||||
|
lines: list[str] = []
|
||||||
|
for r in rows:
|
||||||
|
if domain == "tasks":
|
||||||
|
desc = r.get("description") or ""
|
||||||
|
desc_part = f" — {desc[:120]}" if desc else ""
|
||||||
|
assignee = r.get("assignee") or r.get("assignees") or ""
|
||||||
|
due = r.get("dueDate") or r.get("due_date") or ""
|
||||||
|
meta = ", ".join(filter(None, [
|
||||||
|
f"priority: {r.get('priority', '')}" if r.get("priority") else "",
|
||||||
|
f"assignee: {assignee}" if assignee else "",
|
||||||
|
f"due: {due}" if due else "",
|
||||||
|
]))
|
||||||
|
lines.append(
|
||||||
|
f" - [{r.get('status', '?')}] {r.get('title', '')}{desc_part}"
|
||||||
|
f" ({meta}, id: {r['id']})"
|
||||||
|
)
|
||||||
|
elif domain == "notes":
|
||||||
|
snippet = (r.get("content") or "")[:200].replace("\n", " ")
|
||||||
|
snippet_part = f"\n Preview: {snippet}" if snippet else ""
|
||||||
|
lines.append(
|
||||||
|
f" - {r.get('title', '')} (id: {r['id']}){snippet_part}"
|
||||||
|
)
|
||||||
|
elif domain == "timelines":
|
||||||
|
lines.append(
|
||||||
|
f" - {r.get('title', '')} date={r.get('date', '')} (id: {r['id']})"
|
||||||
|
)
|
||||||
|
elif domain == "projects":
|
||||||
|
summary = (r.get("aiSummary") or r.get("ai_summary") or "")[:120]
|
||||||
|
summary_part = f" — {summary}" if summary else ""
|
||||||
|
lines.append(
|
||||||
|
f" - {r.get('name', '')} [{r.get('status', '')}]{summary_part}"
|
||||||
|
f" (id: {r['id']})"
|
||||||
|
)
|
||||||
|
return f"Existing {domain}:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 1: LLM file classifier ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _classify_file(
|
||||||
|
file_path: str,
|
||||||
|
file_content: str,
|
||||||
|
projects: list[dict],
|
||||||
|
config_data_types: list[str],
|
||||||
|
) -> tuple[str, list[str], str | None]:
|
||||||
|
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
|
||||||
|
|
||||||
|
if not file_content.strip():
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
valid_project_ids = {p["id"] for p in projects}
|
||||||
|
|
||||||
|
def _fmt_project(p: dict) -> str:
|
||||||
|
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
|
||||||
|
summary_part = f" — {summary[:100]}" if summary else ""
|
||||||
|
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
|
||||||
|
|
||||||
|
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
|
||||||
|
|
||||||
|
domain_definitions = "\n".join(
|
||||||
|
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
|
||||||
|
for d in config_data_types
|
||||||
|
if d in _DOMAIN_DESCRIPTIONS
|
||||||
|
)
|
||||||
|
|
||||||
|
system = _STEP1_SYSTEM_PROMPT.format(
|
||||||
|
domain_definitions=domain_definitions,
|
||||||
|
projects_list=projects_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = get_llm()
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke([
|
||||||
|
SystemMessage(content=system),
|
||||||
|
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
|
||||||
|
])
|
||||||
|
raw = _as_text(response.content).strip()
|
||||||
|
if raw.startswith("```"):
|
||||||
|
raw = raw.split("```")[1]
|
||||||
|
if raw.startswith("json"):
|
||||||
|
raw = raw[4:]
|
||||||
|
parsed = json.loads(raw.strip())
|
||||||
|
raw_project_id: str = str(parsed.get("project_id") or "new")
|
||||||
|
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
|
||||||
|
new_project_name: str | None = (
|
||||||
|
str(parsed["new_project_name"]).strip() or None
|
||||||
|
if project_id == "new" and parsed.get("new_project_name")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
domains: list[str] = [
|
||||||
|
d for d in parsed.get("domains", [])
|
||||||
|
if d in config_data_types
|
||||||
|
]
|
||||||
|
if not domains:
|
||||||
|
domains = list(config_data_types)
|
||||||
|
return project_id, domains, new_project_name
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"agent_runner: step1 classification failed for %r: %s", file_path, exc
|
||||||
|
)
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local agent runner (two-step per file) ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_local_agent(user_id: str, trigger_data: dict[str, Any]) -> None:
|
||||||
|
"""Execute a local directory agent run.
|
||||||
|
|
||||||
|
In the microservice world, trigger_data is a serialized dict from
|
||||||
|
the REST route (forwarded via Redis), containing the agent config
|
||||||
|
fields and run_context.
|
||||||
|
|
||||||
|
set_current_user() must be called BEFORE this function.
|
||||||
|
"""
|
||||||
|
run_context: dict = trigger_data.get("run_context", {})
|
||||||
|
agent_id = run_context.get("agent_id", str(uuid.uuid4()))
|
||||||
|
run_id = run_context.get("run_id")
|
||||||
|
|
||||||
|
_running_agents.add(agent_id)
|
||||||
|
|
||||||
|
# Extract config from trigger payload
|
||||||
|
directory_paths: list[str] = trigger_data.get("directory_paths", [])
|
||||||
|
if not directory_paths:
|
||||||
|
directory = trigger_data.get("directory", "")
|
||||||
|
if directory:
|
||||||
|
directory_paths = [directory]
|
||||||
|
|
||||||
|
data_types: list[str] = trigger_data.get("data_types", [])
|
||||||
|
file_extensions: list[str] = trigger_data.get("file_extensions", [])
|
||||||
|
prompt_template: str = trigger_data.get("prompt_template", "")
|
||||||
|
last_run_at_raw = trigger_data.get("last_run_at")
|
||||||
|
last_run_at: datetime | None = None
|
||||||
|
if last_run_at_raw:
|
||||||
|
if isinstance(last_run_at_raw, str):
|
||||||
|
last_run_at = datetime.fromisoformat(last_run_at_raw)
|
||||||
|
elif isinstance(last_run_at_raw, (int, float)):
|
||||||
|
last_run_at = datetime.fromtimestamp(last_run_at_raw / 1000, tz=timezone.utc)
|
||||||
|
|
||||||
|
errors: list[str] = []
|
||||||
|
items_processed = 0
|
||||||
|
items_created = 0
|
||||||
|
|
||||||
|
custom_section = (
|
||||||
|
f"User instructions:\n{prompt_template}"
|
||||||
|
if prompt_template
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create or load run log
|
||||||
|
run_log_id = run_id
|
||||||
|
if not run_log_id:
|
||||||
|
async with async_session() as db:
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
run_log_id = run_log.id
|
||||||
|
|
||||||
|
try:
|
||||||
|
# ── Scan directories ─────────────────────────────────────────
|
||||||
|
logger.info("agent_runner: run=%s scanning directories user=%s", run_log_id, user_id)
|
||||||
|
file_paths = await _scan_directories(
|
||||||
|
paths=directory_paths,
|
||||||
|
extensions=file_extensions,
|
||||||
|
last_run_at=last_run_at,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s found %d file(s) after filtering", run_log_id, len(file_paths)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not file_paths:
|
||||||
|
await _finalize_run(run_log_id, status="success", items_processed=0, items_created=0)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Fetch all projects once ──────────────────────────────────
|
||||||
|
projects = await _fetch_projects()
|
||||||
|
|
||||||
|
for file_path in file_paths:
|
||||||
|
try:
|
||||||
|
file_result = await execute_on_client(
|
||||||
|
action="read_file_content", data={"path": file_path}
|
||||||
|
)
|
||||||
|
file_content: str = file_result.get("content", "")
|
||||||
|
if not file_content:
|
||||||
|
continue
|
||||||
|
|
||||||
|
items_processed += 1
|
||||||
|
|
||||||
|
# Step 1 — classify file
|
||||||
|
project_id, domains, new_project_name = await _classify_file(
|
||||||
|
file_path=file_path,
|
||||||
|
file_content=file_content,
|
||||||
|
projects=projects,
|
||||||
|
config_data_types=data_types,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2 — resolve project_id, fetch entities, process
|
||||||
|
if project_id == "new":
|
||||||
|
proj_name = new_project_name or "Untitled Project"
|
||||||
|
try:
|
||||||
|
proj_result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="projects",
|
||||||
|
data={"name": proj_name, "clientId": None},
|
||||||
|
)
|
||||||
|
created = proj_result.get("row", {})
|
||||||
|
effective_project_id = created.get("id", "standalone")
|
||||||
|
if "id" in created:
|
||||||
|
projects.append(created)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: run=%s create project failed: %s", run_log_id, exc)
|
||||||
|
effective_project_id = "standalone"
|
||||||
|
proj_name = "unknown"
|
||||||
|
project_context = (
|
||||||
|
f"Project: {proj_name} (id: {effective_project_id}). "
|
||||||
|
"Always set projectId to this id on every record you create."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
effective_project_id = project_id
|
||||||
|
proj = next((p for p in projects if p["id"] == project_id), None)
|
||||||
|
proj_name = proj.get("name", project_id) if proj else project_id
|
||||||
|
project_context = (
|
||||||
|
f"Project: {proj_name} (id: {project_id}). "
|
||||||
|
"Always set projectId to this id on every record you create."
|
||||||
|
)
|
||||||
|
|
||||||
|
domains = [d for d in domains if d != "projects"]
|
||||||
|
|
||||||
|
existing_blocks: list[str] = []
|
||||||
|
for domain in domains:
|
||||||
|
rows = await _fetch_domain_entities(domain, effective_project_id)
|
||||||
|
existing_blocks.append(_format_entities_for_context(domain, rows))
|
||||||
|
|
||||||
|
existing_context = "\n\n".join(existing_blocks)
|
||||||
|
|
||||||
|
system_prompt = _PROCESSING_SYSTEM_PROMPT.format(
|
||||||
|
existing_context=existing_context,
|
||||||
|
project_context=project_context,
|
||||||
|
data_types=", ".join(domains),
|
||||||
|
custom_prompt_section=custom_section,
|
||||||
|
)
|
||||||
|
|
||||||
|
processing_tools = _build_processing_tools(domains)
|
||||||
|
|
||||||
|
result_text = await _run_agent_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_message=(
|
||||||
|
f"Process this file and extract relevant information.\n\n"
|
||||||
|
f"File: {file_path}\n\nContent:\n{file_content}"
|
||||||
|
),
|
||||||
|
tools=processing_tools,
|
||||||
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s file=%r result=%s",
|
||||||
|
run_log_id, file_path, result_text[:200],
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Error processing '{file_path}': {exc}")
|
||||||
|
logger.error("agent_runner: run=%s file=%r failed: %s", run_log_id, file_path, exc)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Agent run failed: {exc}")
|
||||||
|
logger.error("agent_runner: run=%s failed: %s", run_log_id, exc)
|
||||||
|
finally:
|
||||||
|
_running_agents.discard(agent_id)
|
||||||
|
|
||||||
|
# ── Finalise ────────────────────────────────────────────────────
|
||||||
|
if errors and items_processed == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status=final_status,
|
||||||
|
items_processed=items_processed,
|
||||||
|
items_created=items_created,
|
||||||
|
errors=errors,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Notify Electron that the run is complete via Redis
|
||||||
|
if run_context:
|
||||||
|
try:
|
||||||
|
channel = ws_out_channel(user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps({
|
||||||
|
"type": "run_complete",
|
||||||
|
"run_context": run_context,
|
||||||
|
"status": final_status,
|
||||||
|
}))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_log_id, exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
|
||||||
|
|
||||||
|
|
||||||
|
async def run_cloud_agent(user_id: str, config_id: str) -> None:
|
||||||
|
"""Execute a cloud connector agent run.
|
||||||
|
|
||||||
|
Loads the CloudAgentConfig from DB, decrypts OAuth tokens, fetches
|
||||||
|
messages from the provider, and runs LLM extraction.
|
||||||
|
|
||||||
|
set_current_user() must be called BEFORE this function.
|
||||||
|
"""
|
||||||
|
from app.integrations import decrypt_token, encrypt_token, get_provider
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
config = result.scalar_one_or_none()
|
||||||
|
if config is None:
|
||||||
|
logger.error("agent_runner: cloud config %s not found", config_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create run log
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=config.id,
|
||||||
|
agent_type="cloud",
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
run_log_id = run_log.id
|
||||||
|
|
||||||
|
# ── Decrypt OAuth token ────────────────────────────────────────
|
||||||
|
if not config.oauth_token_encrypted:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status="error",
|
||||||
|
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
credentials_info = decrypt_token(config.oauth_token_encrypted)
|
||||||
|
except ValueError as exc:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Failed to decrypt OAuth token: {exc}"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Instantiate provider ──────────────────────────────────────
|
||||||
|
try:
|
||||||
|
provider = get_provider(config.provider, credentials_info)
|
||||||
|
except ValueError as exc:
|
||||||
|
await _finalize_run(run_log_id, status="error", errors=[str(exc)])
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Fetch messages ────────────────────────────────────────────
|
||||||
|
since: datetime | None = config.last_run_at
|
||||||
|
if since is None:
|
||||||
|
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
|
||||||
|
if since.tzinfo is None:
|
||||||
|
since = since.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
errors: list[str] = []
|
||||||
|
items_processed = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if config.provider == "gmail":
|
||||||
|
raw_messages = await provider.fetch_messages(
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "outlook":
|
||||||
|
raw_messages = await provider.fetch_emails(
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "teams":
|
||||||
|
raw_messages = await provider.fetch_messages(
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raw_messages = []
|
||||||
|
except RuntimeError as exc:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Provider fetch failed: {exc}"],
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: cloud agent %s fetched %d item(s) from %s",
|
||||||
|
config.id, len(raw_messages), config.provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Extract + insert via LLM ─────────────────────────────────
|
||||||
|
try:
|
||||||
|
processing_tools = _build_processing_tools(config.data_types)
|
||||||
|
custom_section = (
|
||||||
|
f"User instructions:\n{config.prompt_template}"
|
||||||
|
if config.prompt_template
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
for msg in raw_messages:
|
||||||
|
content_text = msg.as_text
|
||||||
|
if not content_text:
|
||||||
|
continue
|
||||||
|
items_processed += 1
|
||||||
|
|
||||||
|
processing_prompt = _CLOUD_PROCESSING_PROMPT.format(
|
||||||
|
data_types=", ".join(config.data_types),
|
||||||
|
project_context="Determine the appropriate project from the message context.",
|
||||||
|
file_list=f"Message from {config.provider} (id: {msg.id})",
|
||||||
|
custom_prompt_section=custom_section,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await _run_agent_with_tools(
|
||||||
|
system_prompt=processing_prompt,
|
||||||
|
user_message=f"Process this message content:\n\n{content_text[:8000]}",
|
||||||
|
tools=processing_tools,
|
||||||
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Agent run failed: {exc}")
|
||||||
|
|
||||||
|
# ── Persist refreshed token ───────────────────────────────────
|
||||||
|
refreshed = getattr(provider, "refreshed_credentials", None)
|
||||||
|
if refreshed:
|
||||||
|
try:
|
||||||
|
new_encrypted = encrypt_token(refreshed)
|
||||||
|
async with async_session() as db:
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
|
||||||
|
)
|
||||||
|
cfg_row = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg_row:
|
||||||
|
cfg_row.oauth_token_encrypted = new_encrypted
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to persist refreshed token: %s", exc)
|
||||||
|
|
||||||
|
# ── Finalise ──────────────────────────────────────────────────
|
||||||
|
if errors and items_processed == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status=final_status,
|
||||||
|
items_processed=items_processed,
|
||||||
|
items_created=0,
|
||||||
|
errors=errors,
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helper ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _finalize_run(
|
||||||
|
run_log_id: int | str,
|
||||||
|
*,
|
||||||
|
status: str,
|
||||||
|
items_processed: int = 0,
|
||||||
|
items_created: int = 0,
|
||||||
|
errors: list[str] | None = None,
|
||||||
|
update_config_last_run: bool = False,
|
||||||
|
config_id: str | None = None,
|
||||||
|
config_type: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Persist the run outcome and optionally update last_run_at on the config."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(AgentRunLog).where(AgentRunLog.id == run_log_id)
|
||||||
|
)
|
||||||
|
managed = result.scalar_one_or_none()
|
||||||
|
if managed is None:
|
||||||
|
logger.warning("agent_runner: run_log %s not found for finalization", run_log_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
managed.status = status
|
||||||
|
managed.items_processed = items_processed
|
||||||
|
managed.items_created = items_created
|
||||||
|
managed.errors = errors or []
|
||||||
|
managed.completed_at = now
|
||||||
|
|
||||||
|
if update_config_last_run and config_id:
|
||||||
|
if config_type == "local":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
elif config_type == "cloud":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("agent_runner: failed to finalize run_log=%s: %s", run_log_id, exc)
|
||||||
1
services/batch-agent/app/agents/__init__.py
Normal file
1
services/batch-agent/app/agents/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Batch Agent Service domain agents and filesystem tools."""
|
||||||
83
services/batch-agent/app/agents/filesystem_agent.py
Normal file
83
services/batch-agent/app/agents/filesystem_agent.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""Filesystem agent — tools for reading local directories and files on Electron.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.ws_context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.ws_context import execute_on_client
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_directory(path: str) -> str:
|
||||||
|
"""List files and folders in a local directory on the user's device.
|
||||||
|
|
||||||
|
Returns a formatted listing of entries with name, type (file/directory),
|
||||||
|
and full path.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="list_directory",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||||
|
if not entries:
|
||||||
|
return f"Directory '{path}' is empty or does not exist."
|
||||||
|
lines: list[str] = []
|
||||||
|
for entry in entries:
|
||||||
|
entry_type = entry.get("type", "unknown")
|
||||||
|
entry_name = entry.get("name", "")
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||||
|
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def read_file_content(path: str) -> str:
|
||||||
|
"""Read the text content of a local file on the user's device.
|
||||||
|
|
||||||
|
Returns the file content as a string. Large files may be truncated
|
||||||
|
by the Electron client.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
content: str = result.get("content", "")
|
||||||
|
if not content:
|
||||||
|
return f"File '{path}' is empty or could not be read."
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_file_metadata(path: str) -> str:
|
||||||
|
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||||
|
|
||||||
|
Returns a formatted summary of the file's metadata.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="get_file_metadata",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
size = result.get("size", "unknown")
|
||||||
|
created = result.get("createdAt", "unknown")
|
||||||
|
modified = result.get("modifiedAt", "unknown")
|
||||||
|
extension = result.get("extension", "unknown")
|
||||||
|
name = result.get("name", path)
|
||||||
|
return (
|
||||||
|
f"File: {name}\n"
|
||||||
|
f" Extension: {extension}\n"
|
||||||
|
f" Size: {size} bytes\n"
|
||||||
|
f" Created: {created}\n"
|
||||||
|
f" Modified: {modified}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FILESYSTEM_TOOLS: list[Any] = [
|
||||||
|
list_directory,
|
||||||
|
read_file_content,
|
||||||
|
get_file_metadata,
|
||||||
|
]
|
||||||
110
services/batch-agent/app/agents/note_agent.py
Normal file
110
services/batch-agent/app/agents/note_agent.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""Note agent — Markdown note management.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.ws_context and app.llm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.llm import embed
|
||||||
|
from app.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_notes(project_id: str = "") -> str:
|
||||||
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="notes",
|
||||||
|
filters={"projectId": normalized_project_id or None},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No notes found."
|
||||||
|
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_note(note_id: str) -> str:
|
||||||
|
"""Fetch a single note by its UUID to read its full Markdown content."""
|
||||||
|
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
|
||||||
|
row = result.get("row")
|
||||||
|
if not row:
|
||||||
|
return f"Note {note_id} not found."
|
||||||
|
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_note(title: str, content: str, project_id: str = "") -> str:
|
||||||
|
"""Create a new note."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="notes",
|
||||||
|
data={
|
||||||
|
"title": title,
|
||||||
|
"content": content,
|
||||||
|
"projectId": project_id or None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
vector = await embed(content)
|
||||||
|
await execute_on_client(
|
||||||
|
action="vector_upsert",
|
||||||
|
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
|
||||||
|
vector=vector,
|
||||||
|
)
|
||||||
|
return f"Note created: '{row['title']}' (id: {row['id']})."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_note(note_id: str, title: str = "", content: str = "") -> str:
|
||||||
|
"""Update an existing note. Only pass fields that should change."""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if title:
|
||||||
|
updates["title"] = title
|
||||||
|
if content:
|
||||||
|
updates["content"] = content
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="notes",
|
||||||
|
data={"id": note_id, "updates": updates},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
if content:
|
||||||
|
vector = await embed(content)
|
||||||
|
await execute_on_client(
|
||||||
|
action="vector_upsert",
|
||||||
|
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
|
||||||
|
vector=vector,
|
||||||
|
)
|
||||||
|
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_note(note_id: str) -> str:
|
||||||
|
"""Delete a note permanently by its UUID."""
|
||||||
|
await execute_on_client(action="delete", table="notes", data={"id": note_id})
|
||||||
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
NOTE_TOOLS: list[Any] = [
|
||||||
|
list_notes,
|
||||||
|
get_note,
|
||||||
|
create_note,
|
||||||
|
update_note,
|
||||||
|
delete_note,
|
||||||
|
]
|
||||||
110
services/batch-agent/app/agents/project_agent.py
Normal file
110
services/batch-agent/app/agents/project_agent.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""Project agent — full lifecycle management.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.ws_context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.ws_context import execute_on_client
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_projects(client_id: str = "", include_archived: int = 0) -> str:
|
||||||
|
"""List projects, optionally filtered by client_id."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="projects",
|
||||||
|
filters={
|
||||||
|
"clientId": client_id or None,
|
||||||
|
"includeArchived": bool(include_archived),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No projects found."
|
||||||
|
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_all_projects() -> str:
|
||||||
|
"""List every project regardless of client or status."""
|
||||||
|
result = await execute_on_client(action="select", table="projects")
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No projects found."
|
||||||
|
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||||
|
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_project(project_id: str) -> str:
|
||||||
|
"""Fetch a single project by its UUID."""
|
||||||
|
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
|
||||||
|
row = result.get("row")
|
||||||
|
if not row:
|
||||||
|
return f"Project {project_id} not found."
|
||||||
|
return (
|
||||||
|
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
|
||||||
|
f"clientId: {row.get('clientId', 'none')})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_project(name: str, client_id: str = "") -> str:
|
||||||
|
"""Create a new project."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="projects",
|
||||||
|
data={"name": name, "clientId": client_id or None},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Project created: '{row['name']}' (id: {row['id']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_project(
|
||||||
|
project_id: str,
|
||||||
|
name: str = "",
|
||||||
|
client_id: str = "",
|
||||||
|
status: str = "",
|
||||||
|
ai_summary: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Update a project. Only pass fields that should change."""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if name:
|
||||||
|
updates["name"] = name
|
||||||
|
if client_id:
|
||||||
|
updates["clientId"] = client_id
|
||||||
|
if status:
|
||||||
|
updates["status"] = status
|
||||||
|
if ai_summary:
|
||||||
|
updates["aiSummary"] = ai_summary
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="projects",
|
||||||
|
data={"id": project_id, "updates": updates},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_project(project_id: str) -> str:
|
||||||
|
"""Permanently delete a project."""
|
||||||
|
await execute_on_client(action="delete", table="projects", data={"id": project_id})
|
||||||
|
return f"Project {project_id} permanently deleted."
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_TOOLS: list[Any] = [
|
||||||
|
list_projects,
|
||||||
|
list_all_projects,
|
||||||
|
get_project,
|
||||||
|
create_project,
|
||||||
|
update_project,
|
||||||
|
delete_project,
|
||||||
|
]
|
||||||
197
services/batch-agent/app/agents/task_agent.py
Normal file
197
services/batch-agent/app/agents/task_agent.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
"""Task agent — full CRUD for tasks and task comments.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.ws_context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_tasks(
|
||||||
|
project_id: str = "",
|
||||||
|
status: str = "",
|
||||||
|
search: str = "",
|
||||||
|
order_by: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""List tasks, optionally filtered by project_id, status, search, or order_by."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="tasks",
|
||||||
|
filters={
|
||||||
|
"projectId": normalized_project_id or None,
|
||||||
|
"status": status or None,
|
||||||
|
"search": search or None,
|
||||||
|
"orderBy": order_by or None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No tasks found matching the given filters."
|
||||||
|
lines = [
|
||||||
|
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_task(
|
||||||
|
title: str,
|
||||||
|
description: str = "",
|
||||||
|
status: str = "todo",
|
||||||
|
priority: str = "medium",
|
||||||
|
assignees: str = "[]",
|
||||||
|
due_date: int = 0,
|
||||||
|
project_id: str = "",
|
||||||
|
is_ai_suggested: int = 0,
|
||||||
|
) -> str:
|
||||||
|
"""Create a new task."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="tasks",
|
||||||
|
data={
|
||||||
|
"title": title,
|
||||||
|
"description": description or None,
|
||||||
|
"status": status,
|
||||||
|
"priority": priority,
|
||||||
|
"assignee": assignees,
|
||||||
|
"dueDate": due_date or None,
|
||||||
|
"projectId": project_id or None,
|
||||||
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return (
|
||||||
|
f"Task created: '{row['title']}' "
|
||||||
|
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_task(
|
||||||
|
task_id: str,
|
||||||
|
title: str = "",
|
||||||
|
description: str = "",
|
||||||
|
status: str = "",
|
||||||
|
priority: str = "",
|
||||||
|
assignees: str = "",
|
||||||
|
due_date: int = -1,
|
||||||
|
project_id: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Update fields on an existing task. Only pass fields you want to change."""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if title:
|
||||||
|
updates["title"] = title
|
||||||
|
if description:
|
||||||
|
updates["description"] = description
|
||||||
|
if status:
|
||||||
|
updates["status"] = status
|
||||||
|
if priority:
|
||||||
|
updates["priority"] = priority
|
||||||
|
if assignees:
|
||||||
|
updates["assignee"] = assignees
|
||||||
|
if due_date != -1:
|
||||||
|
updates["dueDate"] = due_date or None
|
||||||
|
if project_id:
|
||||||
|
updates["projectId"] = project_id
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="tasks",
|
||||||
|
data={"id": task_id, "updates": updates},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_task(task_id: str) -> str:
|
||||||
|
"""Delete a task permanently by its UUID."""
|
||||||
|
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
|
||||||
|
return f"Task {task_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_tasks_due_today() -> str:
|
||||||
|
"""List all tasks whose due date falls on today's date."""
|
||||||
|
now = datetime.now(tz=timezone.utc)
|
||||||
|
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
|
||||||
|
end_ms = start_ms + 86_400_000 - 1
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="tasks",
|
||||||
|
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No tasks are due today."
|
||||||
|
lines = [
|
||||||
|
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_task_comments(task_id: str) -> str:
|
||||||
|
"""List all comments on a task by its UUID."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="taskComments",
|
||||||
|
filters={"taskId": task_id},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return f"No comments found for task {task_id}."
|
||||||
|
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
||||||
|
"""Add a comment to a task."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="taskComments",
|
||||||
|
data={"taskId": task_id, "author": author, "content": content},
|
||||||
|
)
|
||||||
|
row = result.get("row", {})
|
||||||
|
row_author = row.get("author", author)
|
||||||
|
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||||
|
row_comment_id = row.get("id", "unknown")
|
||||||
|
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_task_comment(comment_id: str) -> str:
|
||||||
|
"""Delete a task comment by its UUID."""
|
||||||
|
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
|
||||||
|
return f"Comment {comment_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
TASK_TOOLS: list[Any] = [
|
||||||
|
list_tasks,
|
||||||
|
create_task,
|
||||||
|
update_task,
|
||||||
|
delete_task,
|
||||||
|
list_tasks_due_today,
|
||||||
|
list_task_comments,
|
||||||
|
add_task_comment,
|
||||||
|
delete_task_comment,
|
||||||
|
]
|
||||||
88
services/batch-agent/app/agents/timeline_agent.py
Normal file
88
services/batch-agent/app/agents/timeline_agent.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""Timeline agent — project milestone management.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.ws_context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
|
"""List timelines. Provide project_id to scope to a specific project."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="timelines",
|
||||||
|
filters={"projectId": normalized_project_id or None},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No timelines found."
|
||||||
|
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_timeline(
|
||||||
|
project_id: str, title: str, date: int, is_ai_suggested: int = 0,
|
||||||
|
) -> str:
|
||||||
|
"""Create a project timeline (milestone)."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="timelines",
|
||||||
|
data={
|
||||||
|
"projectId": project_id,
|
||||||
|
"title": title,
|
||||||
|
"date": date,
|
||||||
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_timeline(timeline_id: str, title: str = "", date: int = -1) -> str:
|
||||||
|
"""Update a timeline. Only pass fields that should change."""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if title:
|
||||||
|
updates["title"] = title
|
||||||
|
if date != -1:
|
||||||
|
updates["date"] = date
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="timelines",
|
||||||
|
data={"id": timeline_id, "updates": updates},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_timeline(timeline_id: str) -> str:
|
||||||
|
"""Delete a timeline permanently by its UUID."""
|
||||||
|
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
|
||||||
|
return f"Timeline {timeline_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
TIMELINE_TOOLS: list[Any] = [
|
||||||
|
list_timelines,
|
||||||
|
create_timeline,
|
||||||
|
update_timeline,
|
||||||
|
delete_timeline,
|
||||||
|
]
|
||||||
108
services/batch-agent/app/integrations/__init__.py
Normal file
108
services/batch-agent/app/integrations/__init__.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""Cloud provider integration utilities.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from shared.config instead of app.config.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
* Shared message dataclasses (EmailMessage, ChatMessage)
|
||||||
|
* get_provider() — factory for Gmail/MS Graph clients
|
||||||
|
* encrypt_token() / decrypt_token() — Fernet-based OAuth token encryption
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmailMessage:
|
||||||
|
id: str
|
||||||
|
subject: str
|
||||||
|
sender: str
|
||||||
|
body_text: str
|
||||||
|
date: datetime
|
||||||
|
labels: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_text(self) -> str:
|
||||||
|
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||||
|
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
|
||||||
|
return (
|
||||||
|
f"From: {self.sender}\n"
|
||||||
|
f"Date: {date_str}{labels_str}\n"
|
||||||
|
f"Subject: {self.subject}\n\n"
|
||||||
|
f"{self.body_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatMessage:
|
||||||
|
id: str
|
||||||
|
content: str
|
||||||
|
sender: str
|
||||||
|
channel: str | None
|
||||||
|
date: datetime
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_text(self) -> str:
|
||||||
|
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||||
|
channel_str = f" [channel: {self.channel}]" if self.channel else ""
|
||||||
|
return (
|
||||||
|
f"From: {self.sender}\n"
|
||||||
|
f"Date: {date_str}{channel_str}\n\n"
|
||||||
|
f"{self.content}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_fernet() -> Fernet:
|
||||||
|
key = settings.OAUTH_ENCRYPTION_KEY
|
||||||
|
if not key:
|
||||||
|
raise RuntimeError(
|
||||||
|
"OAUTH_ENCRYPTION_KEY is not set. "
|
||||||
|
"Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
|
||||||
|
)
|
||||||
|
return Fernet(key.encode() if isinstance(key, str) else key)
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_token(token_info: dict) -> str:
|
||||||
|
if not isinstance(token_info, dict) or not token_info:
|
||||||
|
raise ValueError("token_info must be a non-empty dict")
|
||||||
|
plaintext = json.dumps(token_info).encode("utf-8")
|
||||||
|
return _get_fernet().encrypt(plaintext).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_token(encrypted: str) -> dict:
|
||||||
|
try:
|
||||||
|
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
|
||||||
|
return json.loads(plaintext)
|
||||||
|
except (InvalidToken, json.JSONDecodeError) as exc:
|
||||||
|
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider(
|
||||||
|
provider: str,
|
||||||
|
credentials_info: dict,
|
||||||
|
) -> "GmailClient | MSGraphClient":
|
||||||
|
if provider == "gmail":
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
return GmailClient(credentials_info)
|
||||||
|
if provider in {"outlook", "teams"}:
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
return MSGraphClient(credentials_info)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown cloud provider {provider!r}. "
|
||||||
|
"Supported: 'gmail', 'outlook', 'teams'."
|
||||||
|
)
|
||||||
252
services/batch-agent/app/integrations/gmail.py
Normal file
252
services/batch-agent/app/integrations/gmail.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
"""Gmail API client for cloud agent integration.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.integrations instead of
|
||||||
|
app.integrations (same relative path within the service).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import email
|
||||||
|
import html
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.integrations import EmailMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_GMAIL_DATE_FMT = "%Y/%m/%d"
|
||||||
|
_BODY_TRUNCATE = 8_000
|
||||||
|
_MAX_MESSAGES = 200
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gmail_query(
|
||||||
|
filter_config: dict[str, Any] | None,
|
||||||
|
since: datetime | None,
|
||||||
|
) -> str:
|
||||||
|
parts: list[str] = []
|
||||||
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
labels: list[str] = cfg.get("labels", [])
|
||||||
|
if labels:
|
||||||
|
if len(labels) == 1:
|
||||||
|
parts.append(f"label:{labels[0]}")
|
||||||
|
else:
|
||||||
|
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
|
||||||
|
parts.append(f"({label_expr})")
|
||||||
|
|
||||||
|
senders: list[str] = cfg.get("senders", [])
|
||||||
|
for sender in senders:
|
||||||
|
parts.append(f"from:{sender}")
|
||||||
|
|
||||||
|
date_range: dict = cfg.get("date_range", {})
|
||||||
|
from_str: str | None = date_range.get("from")
|
||||||
|
to_str: str | None = date_range.get("to")
|
||||||
|
|
||||||
|
effective_since: datetime | None = since
|
||||||
|
if from_str:
|
||||||
|
try:
|
||||||
|
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||||
|
if cfg_since.tzinfo is None:
|
||||||
|
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||||
|
if effective_since is None or cfg_since > effective_since:
|
||||||
|
effective_since = cfg_since
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
|
||||||
|
|
||||||
|
if effective_since:
|
||||||
|
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
|
||||||
|
|
||||||
|
if to_str:
|
||||||
|
try:
|
||||||
|
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||||
|
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
|
||||||
|
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(raw_html: str) -> str:
|
||||||
|
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
|
||||||
|
decoded = html.unescape(no_tags)
|
||||||
|
return re.sub(r"\s+", " ", decoded).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_body(payload: dict[str, Any]) -> str:
|
||||||
|
mime_type: str = payload.get("mimeType", "")
|
||||||
|
body: dict = payload.get("body", {})
|
||||||
|
parts: list[dict] = payload.get("parts", [])
|
||||||
|
|
||||||
|
if mime_type == "text/plain":
|
||||||
|
data = body.get("data", "")
|
||||||
|
if data:
|
||||||
|
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if mime_type == "text/html":
|
||||||
|
data = body.get("data", "")
|
||||||
|
if data:
|
||||||
|
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return _strip_html(raw)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
plain_fallback = ""
|
||||||
|
for part in parts:
|
||||||
|
part_mime = part.get("mimeType", "")
|
||||||
|
if part_mime == "text/plain":
|
||||||
|
return _parse_body(part)
|
||||||
|
if part_mime == "text/html" and not plain_fallback:
|
||||||
|
plain_fallback = _parse_body(part)
|
||||||
|
if part_mime.startswith("multipart/"):
|
||||||
|
nested = _parse_body(part)
|
||||||
|
if nested:
|
||||||
|
return nested
|
||||||
|
return plain_fallback
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_date(raw: str) -> datetime:
|
||||||
|
try:
|
||||||
|
parsed = email.utils.parsedate_to_datetime(raw)
|
||||||
|
if parsed.tzinfo is None:
|
||||||
|
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||||
|
return parsed.astimezone(timezone.utc)
|
||||||
|
except Exception:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
class GmailClient:
|
||||||
|
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||||
|
from google.oauth2.credentials import Credentials
|
||||||
|
|
||||||
|
self._credentials_info = credentials_info
|
||||||
|
expiry_str: str | None = credentials_info.get("expiry")
|
||||||
|
expiry: datetime | None = None
|
||||||
|
if expiry_str:
|
||||||
|
try:
|
||||||
|
expiry = datetime.fromisoformat(
|
||||||
|
expiry_str.replace("Z", "+00:00")
|
||||||
|
).replace(tzinfo=timezone.utc)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._credentials = Credentials(
|
||||||
|
token=credentials_info.get("token"),
|
||||||
|
refresh_token=credentials_info.get("refresh_token"),
|
||||||
|
token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
|
||||||
|
client_id=credentials_info.get("client_id"),
|
||||||
|
client_secret=credentials_info.get("client_secret"),
|
||||||
|
scopes=credentials_info.get("scopes"),
|
||||||
|
expiry=expiry,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fetch_messages(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[EmailMessage]:
|
||||||
|
query = _build_gmail_query(filter_config, since)
|
||||||
|
logger.debug("gmail: executing search query %r", query)
|
||||||
|
return await asyncio.to_thread(self._fetch_sync, query)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||||
|
creds = self._credentials
|
||||||
|
if not creds.valid and creds.expired:
|
||||||
|
return None
|
||||||
|
if creds.token != self._credentials_info.get("token"):
|
||||||
|
result = {
|
||||||
|
"token": creds.token,
|
||||||
|
"refresh_token": creds.refresh_token,
|
||||||
|
"token_uri": creds.token_uri,
|
||||||
|
"client_id": creds.client_id,
|
||||||
|
"client_secret": creds.client_secret,
|
||||||
|
"scopes": list(creds.scopes or []),
|
||||||
|
}
|
||||||
|
if creds.expiry:
|
||||||
|
result["expiry"] = creds.expiry.isoformat()
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _fetch_sync(self, query: str) -> list[EmailMessage]:
|
||||||
|
import googleapiclient.discovery
|
||||||
|
import googleapiclient.errors
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
|
||||||
|
if self._credentials.expired and self._credentials.refresh_token:
|
||||||
|
try:
|
||||||
|
self._credentials.refresh(Request())
|
||||||
|
except Exception as exc:
|
||||||
|
raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc
|
||||||
|
|
||||||
|
service = googleapiclient.discovery.build(
|
||||||
|
"gmail", "v1", credentials=self._credentials, cache_discovery=False
|
||||||
|
)
|
||||||
|
user_api = service.users()
|
||||||
|
|
||||||
|
ids: list[str] = []
|
||||||
|
page_token: str | None = None
|
||||||
|
while len(ids) < _MAX_MESSAGES:
|
||||||
|
batch_size = min(100, _MAX_MESSAGES - len(ids))
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"userId": "me",
|
||||||
|
"maxResults": batch_size,
|
||||||
|
}
|
||||||
|
if query:
|
||||||
|
kwargs["q"] = query
|
||||||
|
if page_token:
|
||||||
|
kwargs["pageToken"] = page_token
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = user_api.messages().list(**kwargs).execute()
|
||||||
|
except googleapiclient.errors.HttpError as exc:
|
||||||
|
raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc
|
||||||
|
|
||||||
|
for msg in resp.get("messages", []):
|
||||||
|
ids.append(msg["id"])
|
||||||
|
|
||||||
|
page_token = resp.get("nextPageToken")
|
||||||
|
if not page_token:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info("gmail: fetching %d message(s)", len(ids))
|
||||||
|
|
||||||
|
messages: list[EmailMessage] = []
|
||||||
|
for msg_id in ids:
|
||||||
|
try:
|
||||||
|
msg = user_api.messages().get(
|
||||||
|
userId="me", id=msg_id, format="full"
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
headers: dict[str, str] = {
|
||||||
|
h["name"].lower(): h["value"]
|
||||||
|
for h in msg.get("payload", {}).get("headers", [])
|
||||||
|
}
|
||||||
|
subject = headers.get("subject", "(no subject)")
|
||||||
|
sender = headers.get("from", "unknown")
|
||||||
|
date_raw = headers.get("date", "")
|
||||||
|
date = _parse_date(date_raw) if date_raw else datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_text = _parse_body(msg.get("payload", {}))[:_BODY_TRUNCATE]
|
||||||
|
labels = msg.get("labelIds", [])
|
||||||
|
|
||||||
|
messages.append(EmailMessage(
|
||||||
|
id=msg_id,
|
||||||
|
subject=subject,
|
||||||
|
sender=sender,
|
||||||
|
body_text=body_text,
|
||||||
|
date=date,
|
||||||
|
labels=labels,
|
||||||
|
))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("gmail: skipping message %s: %s", msg_id, exc)
|
||||||
|
|
||||||
|
logger.info("gmail: returned %d message(s)", len(messages))
|
||||||
|
return messages
|
||||||
266
services/batch-agent/app/integrations/ms_graph.py
Normal file
266
services/batch-agent/app/integrations/ms_graph.py
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
"""Microsoft Graph API client for Outlook and Teams.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import settings from shared.config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from app.integrations import ChatMessage, EmailMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
|
||||||
|
|
||||||
|
_MAX_EMAILS = 200
|
||||||
|
_MAX_MESSAGES = 200
|
||||||
|
_BODY_TRUNCATE = 8_000
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(raw: str) -> str:
|
||||||
|
no_tags = re.sub(r"<[^>]+>", " ", raw)
|
||||||
|
import html as _html
|
||||||
|
decoded = _html.unescape(no_tags)
|
||||||
|
return re.sub(r"\s+", " ", decoded).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _odata_datetime(dt: datetime) -> str:
|
||||||
|
utc = dt.astimezone(timezone.utc)
|
||||||
|
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_email_filter(
|
||||||
|
filter_config: dict[str, Any] | None,
|
||||||
|
since: datetime | None,
|
||||||
|
) -> str:
|
||||||
|
clauses: list[str] = []
|
||||||
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
senders: list[str] = cfg.get("senders", [])
|
||||||
|
if senders:
|
||||||
|
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
|
||||||
|
clauses.append("(" + " or ".join(sender_clauses) + ")")
|
||||||
|
|
||||||
|
date_range: dict = cfg.get("date_range", {})
|
||||||
|
from_str: str | None = date_range.get("from")
|
||||||
|
|
||||||
|
effective_since: datetime | None = since
|
||||||
|
if from_str:
|
||||||
|
try:
|
||||||
|
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||||
|
if cfg_since.tzinfo is None:
|
||||||
|
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||||
|
if effective_since is None or cfg_since > effective_since:
|
||||||
|
effective_since = cfg_since
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
|
||||||
|
|
||||||
|
if effective_since:
|
||||||
|
clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")
|
||||||
|
|
||||||
|
to_str: str | None = date_range.get("to")
|
||||||
|
if to_str:
|
||||||
|
try:
|
||||||
|
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||||
|
if to_dt.tzinfo is None:
|
||||||
|
to_dt = to_dt.replace(tzinfo=timezone.utc)
|
||||||
|
clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
|
||||||
|
|
||||||
|
return " and ".join(clauses)
|
||||||
|
|
||||||
|
|
||||||
|
class MSGraphClient:
|
||||||
|
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||||
|
self._credentials_info = credentials_info
|
||||||
|
self._access_token: str = credentials_info.get("access_token", "")
|
||||||
|
self._original_access_token: str = self._access_token
|
||||||
|
self._refresh_token: str | None = credentials_info.get("refresh_token")
|
||||||
|
|
||||||
|
def _auth_headers(self) -> dict[str, str]:
|
||||||
|
return {"Authorization": f"Bearer {self._access_token}"}
|
||||||
|
|
||||||
|
async def _refresh_access_token(self) -> None:
|
||||||
|
import msal
|
||||||
|
|
||||||
|
app = msal.ConfidentialClientApplication(
|
||||||
|
client_id=settings.MS_CLIENT_ID,
|
||||||
|
client_credential=settings.MS_CLIENT_SECRET,
|
||||||
|
authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}",
|
||||||
|
)
|
||||||
|
scopes: list[str] = self._credentials_info.get("scope", "").split()
|
||||||
|
if not scopes:
|
||||||
|
scopes = ["https://graph.microsoft.com/.default"]
|
||||||
|
|
||||||
|
result = app.acquire_token_by_refresh_token(
|
||||||
|
self._refresh_token,
|
||||||
|
scopes=scopes,
|
||||||
|
)
|
||||||
|
if "access_token" not in result:
|
||||||
|
error = result.get("error_description", result.get("error", "unknown"))
|
||||||
|
raise RuntimeError(f"MS Graph token refresh failed: {error}")
|
||||||
|
|
||||||
|
self._access_token = result["access_token"]
|
||||||
|
if "refresh_token" in result:
|
||||||
|
self._refresh_token = result["refresh_token"]
|
||||||
|
self._credentials_info["refresh_token"] = result["refresh_token"]
|
||||||
|
self._credentials_info["access_token"] = self._access_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||||
|
if self._access_token != self._original_access_token:
|
||||||
|
return {**self._credentials_info, "access_token": self._access_token}
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get(
|
||||||
|
self,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
url: str,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
*,
|
||||||
|
retry_on_401: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||||
|
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
|
||||||
|
await self._refresh_access_token()
|
||||||
|
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||||
|
if resp.status_code == 429:
|
||||||
|
raise RuntimeError("MS Graph rate limit hit (429). Try again later.")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
async def fetch_emails(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[EmailMessage]:
|
||||||
|
odata_filter = _build_email_filter(filter_config, since)
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"$top": 50,
|
||||||
|
"$select": "id,subject,from,receivedDateTime,body,bodyPreview",
|
||||||
|
"$orderby": "receivedDateTime desc",
|
||||||
|
}
|
||||||
|
if odata_filter:
|
||||||
|
params["$filter"] = odata_filter
|
||||||
|
|
||||||
|
emails: list[EmailMessage] = []
|
||||||
|
url = f"{_GRAPH_BASE}/me/messages"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
while url and len(emails) < _MAX_EMAILS:
|
||||||
|
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||||
|
for item in data.get("value", []):
|
||||||
|
emails.append(self._parse_email(item))
|
||||||
|
if len(emails) >= _MAX_EMAILS:
|
||||||
|
break
|
||||||
|
url = data.get("@odata.nextLink", "")
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
|
||||||
|
return emails
|
||||||
|
|
||||||
|
async def fetch_messages(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[ChatMessage]:
|
||||||
|
cfg = filter_config or {}
|
||||||
|
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
|
||||||
|
params: dict[str, Any] = {"$top": 50}
|
||||||
|
if since:
|
||||||
|
params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}"
|
||||||
|
|
||||||
|
messages: list[ChatMessage] = []
|
||||||
|
url = f"{_GRAPH_BASE}/me/chats/getAllMessages"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
while url and len(messages) < _MAX_MESSAGES:
|
||||||
|
try:
|
||||||
|
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
if exc.response.status_code in (403, 404):
|
||||||
|
logger.warning(
|
||||||
|
"ms_graph: /me/chats/getAllMessages not available (%d)",
|
||||||
|
exc.response.status_code,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
raise
|
||||||
|
|
||||||
|
for item in data.get("value", []):
|
||||||
|
msg = self._parse_teams_message(item)
|
||||||
|
if channel_filter and msg.channel:
|
||||||
|
if not any(c in msg.channel.lower() for c in channel_filter):
|
||||||
|
continue
|
||||||
|
messages.append(msg)
|
||||||
|
if len(messages) >= _MAX_MESSAGES:
|
||||||
|
break
|
||||||
|
url = data.get("@odata.nextLink", "")
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_email(item: dict[str, Any]) -> EmailMessage:
|
||||||
|
subject: str = item.get("subject", "(no subject)") or "(no subject)"
|
||||||
|
sender_block = item.get("from", {}) or {}
|
||||||
|
sender_addr = (
|
||||||
|
(sender_block.get("emailAddress") or {}).get("address", "unknown")
|
||||||
|
)
|
||||||
|
date_str: str = item.get("receivedDateTime", "")
|
||||||
|
try:
|
||||||
|
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||||
|
except Exception:
|
||||||
|
date = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_block = item.get("body", {}) or {}
|
||||||
|
content_type: str = body_block.get("contentType", "text")
|
||||||
|
raw_body: str = body_block.get("content", "")
|
||||||
|
if content_type == "html":
|
||||||
|
body_text = _strip_html(raw_body)
|
||||||
|
else:
|
||||||
|
body_text = raw_body or item.get("bodyPreview", "")
|
||||||
|
body_text = body_text[:_BODY_TRUNCATE]
|
||||||
|
|
||||||
|
return EmailMessage(
|
||||||
|
id=item.get("id", ""),
|
||||||
|
subject=subject,
|
||||||
|
sender=sender_addr,
|
||||||
|
body_text=body_text,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_teams_message(item: dict[str, Any]) -> ChatMessage:
|
||||||
|
msg_id: str = item.get("id", "")
|
||||||
|
sender_block = (item.get("from") or {}).get("user") or {}
|
||||||
|
sender: str = sender_block.get("displayName", "unknown")
|
||||||
|
channel: str | None = (item.get("channelIdentity") or {}).get("channelId")
|
||||||
|
|
||||||
|
date_str: str = item.get("createdDateTime", "")
|
||||||
|
try:
|
||||||
|
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||||
|
except Exception:
|
||||||
|
date = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_block = item.get("body", {}) or {}
|
||||||
|
content_type: str = body_block.get("contentType", "text")
|
||||||
|
raw_content: str = body_block.get("content", "")
|
||||||
|
content = _strip_html(raw_content) if content_type == "html" else raw_content
|
||||||
|
content = content[:_BODY_TRUNCATE]
|
||||||
|
|
||||||
|
return ChatMessage(
|
||||||
|
id=msg_id,
|
||||||
|
content=content,
|
||||||
|
sender=sender,
|
||||||
|
channel=channel,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
385
services/batch-agent/app/journey.py
Normal file
385
services/batch-agent/app/journey.py
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
"""Chatbot Journey — guided conversation to build an agent prompt_template.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: imports from app.agents.filesystem_agent
|
||||||
|
and app.llm instead of monolith paths. Session state is in-memory (could
|
||||||
|
be moved to Redis for horizontal scaling in the future).
|
||||||
|
|
||||||
|
Journey flow:
|
||||||
|
1. Redis consumer dispatches ``journey_start`` with basic agent config.
|
||||||
|
2. Server creates an in-memory session, runs the setup LLM with
|
||||||
|
file-system tools to explore the directory, returns first question.
|
||||||
|
3. ``journey_message`` frames drive the conversation.
|
||||||
|
4. After 3-5 turns the LLM emits PROMPT_TEMPLATE_START / _END block.
|
||||||
|
5. Server parses the block and returns ``journey_reply`` with ``done=True``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||||
|
from app.llm import get_llm
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
|
|
||||||
|
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
||||||
|
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
||||||
|
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
||||||
|
|
||||||
|
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
||||||
|
_MAX_TURNS: int = 15
|
||||||
|
_MAX_TOOL_STEPS: int = 6
|
||||||
|
|
||||||
|
# ── In-memory session store ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class JourneySession:
|
||||||
|
session_id: str
|
||||||
|
user_id: str
|
||||||
|
agent_type: str # "local" | "cloud"
|
||||||
|
directory: str
|
||||||
|
data_types: list[str]
|
||||||
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
system_prompt: str = ""
|
||||||
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
|
||||||
|
|
||||||
|
|
||||||
|
# session_id → session
|
||||||
|
_sessions: dict[str, JourneySession] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
||||||
|
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
||||||
|
s = _sessions.get(session_id)
|
||||||
|
if s is None or s.is_expired():
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
return None
|
||||||
|
if s.user_id != user_id:
|
||||||
|
return None
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
# ── System prompt builder ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
|
Your job is to understand exactly what data the user wants to extract from their
|
||||||
|
local directory and produce a detailed prompt_template that a separate AI will use
|
||||||
|
as its instruction set.
|
||||||
|
|
||||||
|
The extraction agent already has this base behaviour built in:
|
||||||
|
- Reads each file using file-system tools.
|
||||||
|
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
|
||||||
|
- Sets isAiSuggested=1 on every new record.
|
||||||
|
- Only extracts data explicitly present in the files — it never invents information.
|
||||||
|
The user's custom prompt is appended AFTER this base behaviour, so focus on
|
||||||
|
what to look for and how to map it — not on the general extraction mechanics.
|
||||||
|
|
||||||
|
You have access to file-system tools to explore the user's directory:
|
||||||
|
- list_directory: to see folder structure
|
||||||
|
- read_file_content: to peek at file contents
|
||||||
|
- get_file_metadata: to check file info
|
||||||
|
|
||||||
|
The user's configured directory is: {directory}
|
||||||
|
Target data types: {data_types}
|
||||||
|
|
||||||
|
IMPORTANT — project assignment is handled automatically by the main agent runner
|
||||||
|
before the custom prompt is ever used. You MUST NOT ask the user about projects,
|
||||||
|
projectId, or how to link records to projects. Never include projectId logic or
|
||||||
|
project creation instructions in the generated prompt_template.
|
||||||
|
|
||||||
|
Start by exploring the directory to understand its structure. Then ask concise,
|
||||||
|
focused questions one at a time. Cover these topics (not necessarily in this order):
|
||||||
|
1. The type and format of the source content (confirmed by your exploration).
|
||||||
|
2. How fields should be mapped (e.g. filename → task title).
|
||||||
|
3. Priority or status rules (e.g. "urgent" keyword → high priority).
|
||||||
|
4. Any special handling, date extraction, or exclusions.
|
||||||
|
|
||||||
|
Once you reach 90% confidence, output the final prompt_template between these exact
|
||||||
|
markers on their own lines:
|
||||||
|
|
||||||
|
{template_start}
|
||||||
|
<the complete extraction prompt here>
|
||||||
|
{template_end}
|
||||||
|
|
||||||
|
The prompt_template must be a self-contained instruction for an AI that reads files
|
||||||
|
and must perform CRUD operations using tools to create records. It should specify:
|
||||||
|
- What entity types to create (tasks, notes, timelines) — never projects.
|
||||||
|
- How to map file content to record fields (camelCase: title, status, priority,
|
||||||
|
dueDate, content, etc.) — never include projectId.
|
||||||
|
- That isAiSuggested must be set to 1 on every new record.
|
||||||
|
- Concrete examples of mappings based on what you discovered in the directory.
|
||||||
|
|
||||||
|
{existing_section}\
|
||||||
|
Keep asking clarifying questions until you are at least 90% confident you have
|
||||||
|
enough information to generate an accurate prompt_template. Once you reach that
|
||||||
|
confidence level, stop asking and produce the final template immediately.
|
||||||
|
Begin by exploring the directory, then ask your first question.\
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _build_system_prompt(
|
||||||
|
directory: str,
|
||||||
|
data_types: list[str],
|
||||||
|
existing_template: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
existing_section = (
|
||||||
|
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||||
|
f"---\n{existing_template}\n---\n"
|
||||||
|
if existing_template
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
return _SYSTEM_PROMPT_TEMPLATE.format(
|
||||||
|
directory=directory,
|
||||||
|
data_types=", ".join(data_types),
|
||||||
|
template_start=_TEMPLATE_START,
|
||||||
|
template_end=_TEMPLATE_END,
|
||||||
|
existing_section=existing_section,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Template extraction ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_template(text: str) -> str | None:
|
||||||
|
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
||||||
|
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
||||||
|
return None
|
||||||
|
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
||||||
|
end_idx = text.index(_TEMPLATE_END)
|
||||||
|
return text[start_idx:end_idx].strip() or None
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM call with tool support ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _as_text(content: Any) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_llm_with_tools(
|
||||||
|
system_prompt: str,
|
||||||
|
history: list[dict[str, Any]],
|
||||||
|
tools: list[Any],
|
||||||
|
) -> str:
|
||||||
|
"""Build LangChain messages from history and invoke the LLM with tools.
|
||||||
|
|
||||||
|
Handles tool-calling loops: if the LLM calls tools, execute them and
|
||||||
|
continue until a final text response is produced.
|
||||||
|
"""
|
||||||
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
|
for turn in history:
|
||||||
|
if turn["role"] == "user":
|
||||||
|
messages.append(HumanMessage(content=turn["content"]))
|
||||||
|
else:
|
||||||
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
|
llm = get_llm(model=None, temperature=0.4)
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
for _ in range(_MAX_TOOL_STEPS):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"journey: tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:500],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey: tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:800],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
# Fallback: exceeded max tool steps.
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Journey handlers (called from redis_consumer) ────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_journey_start(
|
||||||
|
user_id: str,
|
||||||
|
frame: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle a ``journey_start`` request.
|
||||||
|
|
||||||
|
Creates a session, runs the setup LLM with directory exploration,
|
||||||
|
and returns the ``journey_reply`` payload.
|
||||||
|
"""
|
||||||
|
agent_type = frame.get("agent_type", "local")
|
||||||
|
directory = frame.get("directory", "")
|
||||||
|
data_types = frame.get("data_types", [])
|
||||||
|
existing_template = frame.get("existing_template")
|
||||||
|
|
||||||
|
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||||
|
system_prompt = _build_system_prompt(directory, data_types, existing_template)
|
||||||
|
|
||||||
|
session = JourneySession(
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
directory=directory,
|
||||||
|
data_types=data_types,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
seed_history: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||||
|
]
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history=seed_history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.extend(seed_history)
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
_sessions[session_id] = session
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey: session %s started for user %s (directory=%s)",
|
||||||
|
session_id,
|
||||||
|
user_id,
|
||||||
|
directory,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_template = _extract_template(ai_reply)
|
||||||
|
done = prompt_template is not None
|
||||||
|
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": display_message,
|
||||||
|
"done": done,
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_journey_message(
|
||||||
|
user_id: str,
|
||||||
|
frame: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle a ``journey_message`` request.
|
||||||
|
|
||||||
|
Appends the user message, calls the LLM, and returns the
|
||||||
|
``journey_reply`` payload.
|
||||||
|
"""
|
||||||
|
session_id = frame.get("session_id", "")
|
||||||
|
message = frame.get("message", "")
|
||||||
|
|
||||||
|
session = get_journey_session(session_id, user_id)
|
||||||
|
if session is None:
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": "Journey session not found or expired. Please start a new setup.",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
session.history.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
|
prompt_template = _extract_template(ai_reply)
|
||||||
|
done = prompt_template is not None
|
||||||
|
|
||||||
|
if not done:
|
||||||
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
|
if turns >= _MAX_TURNS:
|
||||||
|
nudge_content = (
|
||||||
|
"[System: You have enough information. Please generate the final "
|
||||||
|
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||||
|
)
|
||||||
|
session.history.append({"role": "user", "content": nudge_content})
|
||||||
|
|
||||||
|
nudge_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||||
|
|
||||||
|
prompt_template = _extract_template(nudge_reply)
|
||||||
|
if prompt_template is not None:
|
||||||
|
done = True
|
||||||
|
ai_reply = nudge_reply
|
||||||
|
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
|
if _TEMPLATE_START in ai_reply
|
||||||
|
else "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
logger.info("journey: session %s completed for user %s", session_id, user_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": display_message,
|
||||||
|
"done": done,
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
76
services/batch-agent/app/llm.py
Normal file
76
services/batch-agent/app/llm.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
|
Identical to services/chat/app/llm.py. Uses shared.config.settings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_litellm import ChatLiteLLM
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
|
if model.startswith("anthropic/"):
|
||||||
|
return settings.ANTHROPIC_API_KEY or None
|
||||||
|
if model.startswith("gemini/") or model.startswith("google/"):
|
||||||
|
return settings.GOOGLE_API_KEY or None
|
||||||
|
if model.startswith("cerebras/"):
|
||||||
|
return settings.CEREBRAS_API_KEY or None
|
||||||
|
if model.startswith("github_copilot/"):
|
||||||
|
return None
|
||||||
|
return settings.OPENAI_API_KEY or None
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm(
|
||||||
|
*,
|
||||||
|
model: str | None = None,
|
||||||
|
temperature: float = 0,
|
||||||
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
|
model = model or settings.LLM_MODEL
|
||||||
|
|
||||||
|
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||||
|
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||||
|
|
||||||
|
if "/" in model:
|
||||||
|
return ChatLiteLLM(model=model, temperature=temperature)
|
||||||
|
|
||||||
|
return ChatOpenAI(
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
api_key=_api_key_for_model(model),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_router_llm(
|
||||||
|
*,
|
||||||
|
temperature: float = 0,
|
||||||
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
|
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
|
||||||
|
|
||||||
|
|
||||||
|
async def embed(text: str) -> list[float]:
|
||||||
|
model = settings.LLM_EMBED_MODEL
|
||||||
|
|
||||||
|
if model.startswith("github_copilot/") or "/" in model:
|
||||||
|
response = await litellm.aembedding(model=model, input=[text])
|
||||||
|
return response.data[0]["embedding"]
|
||||||
|
|
||||||
|
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||||
|
response = await client.embeddings.create(model=model, input=text)
|
||||||
|
return response.data[0].embedding
|
||||||
57
services/batch-agent/app/main.py
Normal file
57
services/batch-agent/app/main.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""Batch Agent Service — FastAPI application.
|
||||||
|
|
||||||
|
Owns: agent_runner (local directory + cloud connectors), journey builder,
|
||||||
|
filesystem_agent, integrations (Gmail, MS Graph).
|
||||||
|
|
||||||
|
Communicates with WS Gateway via Redis:
|
||||||
|
- Subscribes to batch:request:{user_id} (journey_start, journey_message)
|
||||||
|
- Publishes to ws:out:{user_id} (journey replies + tool calls)
|
||||||
|
- BRPOP on tool:result:{call_id} (tool-call round-trip, 30s timeout)
|
||||||
|
- SET+EX on journey:{user_id} (journey session state, TTL 1800s)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.redis_consumer import start_consumer
|
||||||
|
from app.routes import router
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
|
logger.info("batch-agent: starting Redis consumer")
|
||||||
|
task = asyncio.create_task(start_consumer())
|
||||||
|
yield
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
logger.info("batch-agent: Redis consumer stopped")
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="Adiuva Batch Agent Service", lifespan=lifespan)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_methods=["GET", "POST"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health() -> dict[str, str]:
|
||||||
|
return {"status": "ok", "service": "batch-agent"}
|
||||||
141
services/batch-agent/app/redis_consumer.py
Normal file
141
services/batch-agent/app/redis_consumer.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""Redis consumer for the Batch Agent Service.
|
||||||
|
|
||||||
|
Subscribes to batch:request:* (pattern) and dispatches:
|
||||||
|
- journey_start → handle_journey_start
|
||||||
|
- journey_message → handle_journey_message
|
||||||
|
- agent_trigger → run_local_agent / run_cloud_agent
|
||||||
|
|
||||||
|
Results are published back to ws:out:{user_id} via Redis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from shared.redis import redis_client, batch_request_channel, ws_out_channel
|
||||||
|
|
||||||
|
from app.ws_context import set_current_user, clear_current_user
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _publish_to_user(user_id: str, payload: dict[str, Any]) -> None:
|
||||||
|
"""Publish a frame to the user's WS outbound channel."""
|
||||||
|
channel = ws_out_channel(user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps(payload))
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_start(user_id: str, data: dict[str, Any]) -> None:
|
||||||
|
"""Handle a journey_start request from WS Gateway."""
|
||||||
|
from app.journey import handle_journey_start
|
||||||
|
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
reply = await handle_journey_start(user_id, data)
|
||||||
|
await _publish_to_user(user_id, reply)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("batch-agent: journey_start failed user=%s: %s", user_id, exc)
|
||||||
|
await _publish_to_user(user_id, {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": data.get("session_id", ""),
|
||||||
|
"message": f"Journey setup failed: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_message(user_id: str, data: dict[str, Any]) -> None:
|
||||||
|
"""Handle a journey_message from WS Gateway."""
|
||||||
|
from app.journey import handle_journey_message
|
||||||
|
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
reply = await handle_journey_message(user_id, data)
|
||||||
|
await _publish_to_user(user_id, reply)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("batch-agent: journey_message failed user=%s: %s", user_id, exc)
|
||||||
|
await _publish_to_user(user_id, {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": data.get("session_id", ""),
|
||||||
|
"message": f"Journey processing failed: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_agent_trigger(user_id: str, data: dict[str, Any]) -> None:
|
||||||
|
"""Handle an agent_trigger request from the REST route (forwarded via Redis)."""
|
||||||
|
from app.agent_runner import run_local_agent
|
||||||
|
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
await run_local_agent(user_id, data)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("batch-agent: agent_trigger failed user=%s: %s", user_id, exc)
|
||||||
|
await _publish_to_user(user_id, {
|
||||||
|
"type": "run_complete",
|
||||||
|
"status": "error",
|
||||||
|
"run_context": data.get("run_context", {}),
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
async def _dispatch(user_id: str, message_data: dict[str, Any]) -> None:
|
||||||
|
"""Route a batch request to the correct handler."""
|
||||||
|
msg_type = message_data.get("type", "")
|
||||||
|
|
||||||
|
if msg_type == "journey_start":
|
||||||
|
await _handle_journey_start(user_id, message_data)
|
||||||
|
elif msg_type == "journey_message":
|
||||||
|
await _handle_journey_message(user_id, message_data)
|
||||||
|
elif msg_type == "agent_trigger":
|
||||||
|
await _handle_agent_trigger(user_id, message_data)
|
||||||
|
else:
|
||||||
|
logger.warning("batch-agent: unknown message type %r from user=%s", msg_type, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def start_consumer() -> None:
|
||||||
|
"""Subscribe to batch:request:* and dispatch incoming frames."""
|
||||||
|
pubsub = redis_client.pubsub()
|
||||||
|
await pubsub.psubscribe("batch:request:*")
|
||||||
|
logger.info("batch-agent: subscribed to batch:request:*")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for message in pubsub.listen():
|
||||||
|
if message["type"] != "pmessage":
|
||||||
|
continue
|
||||||
|
|
||||||
|
channel: str = message["channel"]
|
||||||
|
if isinstance(channel, bytes):
|
||||||
|
channel = channel.decode()
|
||||||
|
|
||||||
|
# Extract user_id from channel: batch:request:{user_id}
|
||||||
|
parts = channel.split(":", 2)
|
||||||
|
if len(parts) < 3:
|
||||||
|
continue
|
||||||
|
user_id = parts[2]
|
||||||
|
|
||||||
|
raw = message["data"]
|
||||||
|
if isinstance(raw, bytes):
|
||||||
|
raw = raw.decode()
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("batch-agent: invalid JSON on channel %s", channel)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Dispatch in a separate task to avoid blocking the consumer
|
||||||
|
asyncio.create_task(_dispatch(user_id, data))
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("batch-agent: consumer shutting down")
|
||||||
|
finally:
|
||||||
|
await pubsub.punsubscribe("batch:request:*")
|
||||||
208
services/batch-agent/app/routes.py
Normal file
208
services/batch-agent/app/routes.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
"""Agent REST routes — catalog, billing checks, trigger.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: uses shared.db, shared.models, shared.schemas.
|
||||||
|
Agent trigger dispatches via Redis to the consumer instead of spawning
|
||||||
|
an in-process background task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Header, HTTPException, status
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.models import AgentRunLog
|
||||||
|
from shared.redis import redis_client, batch_request_channel
|
||||||
|
|
||||||
|
from app.agent_runner import is_agent_running
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||||
|
|
||||||
|
# ── Tier feature limits ───────────────────────────────────────────────
|
||||||
|
# Mirrors app/billing/tier_manager.py FEATURES dict.
|
||||||
|
FEATURES: dict[str, dict] = {
|
||||||
|
"free": {"batch_active": 1, "batch_runs_per_day": 3},
|
||||||
|
"pro": {"batch_active": 5, "batch_runs_per_day": 20},
|
||||||
|
"power": {"batch_active": 20, "batch_runs_per_day": 100},
|
||||||
|
"team": {"batch_active": -1, "batch_runs_per_day": -1},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _dt_ms(dt: datetime) -> int:
|
||||||
|
return int(dt.timestamp() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
def _dt_ms_opt(dt: datetime | None) -> int | None:
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
|
def _to_data_types(values: list[str]) -> list[str]:
|
||||||
|
normalize = {
|
||||||
|
"task": "tasks", "tasks": "tasks",
|
||||||
|
"note": "notes", "notes": "notes",
|
||||||
|
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
||||||
|
"project": "projects", "projects": "projects",
|
||||||
|
}
|
||||||
|
seen: set[str] = set()
|
||||||
|
result: list[str] = []
|
||||||
|
for v in values:
|
||||||
|
mapped = normalize.get(v)
|
||||||
|
if mapped and mapped not in seen:
|
||||||
|
seen.add(mapped)
|
||||||
|
result.append(mapped)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
|
if limit != -1 and current_count >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
|
)
|
||||||
|
return limit
|
||||||
|
|
||||||
|
|
||||||
|
async def _enforce_run_frequency(tier: str, user_id: str) -> None:
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
||||||
|
if limit == -1:
|
||||||
|
return
|
||||||
|
today_start = datetime.now(timezone.utc).replace(
|
||||||
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
|
)
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(func.count(AgentRunLog.id)).where(
|
||||||
|
AgentRunLog.user_id == user_id,
|
||||||
|
AgentRunLog.started_at >= today_start,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
runs_today: int = result.scalar_one()
|
||||||
|
|
||||||
|
if runs_today >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Daily batch run limit ({limit}) reached for your tier.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/catalog")
|
||||||
|
async def get_agent_catalog(
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
) -> list[dict]:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "local_directory",
|
||||||
|
"name": "Local Directory Monitor",
|
||||||
|
"description": "Watches local directories, extracts data from files using AI",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "gmail",
|
||||||
|
"name": "Gmail Connector",
|
||||||
|
"description": "Scans Gmail inbox, extracts tasks/notes from emails",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "teams",
|
||||||
|
"name": "Microsoft Teams Connector",
|
||||||
|
"description": "Monitors Teams messages, extracts action items",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "outlook",
|
||||||
|
"name": "Outlook Connector",
|
||||||
|
"description": "Scans Outlook inbox, extracts tasks/notes",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Can-create check ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/can-create")
|
||||||
|
async def can_create_agent(
|
||||||
|
body: dict,
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
||||||
|
) -> dict:
|
||||||
|
active_agents = body.get("active_agents", 0)
|
||||||
|
limit: int = FEATURES.get(x_user_tier, FEATURES["free"])["batch_active"]
|
||||||
|
allowed = limit == -1 or active_agents < limit
|
||||||
|
return {
|
||||||
|
"allowed": allowed,
|
||||||
|
"tier": x_user_tier,
|
||||||
|
"active_agents": active_agents,
|
||||||
|
"limit": limit,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Trigger ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/trigger", status_code=status.HTTP_202_ACCEPTED)
|
||||||
|
async def trigger_agent_run(
|
||||||
|
body: dict,
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
||||||
|
) -> dict:
|
||||||
|
"""Trigger a local agent run — creates run log and dispatches via Redis."""
|
||||||
|
active_agents = body.get("active_agents", 0)
|
||||||
|
_enforce_agent_limit(x_user_tier, active_agents)
|
||||||
|
await _enforce_run_frequency(x_user_tier, x_user_id)
|
||||||
|
|
||||||
|
stable_agent_id = body.get("agent_id") or str(uuid.uuid4())
|
||||||
|
|
||||||
|
if is_agent_running(stable_agent_id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Agent is already running.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create run log in DB
|
||||||
|
async with async_session() as db:
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=stable_agent_id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=x_user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
run_log_id = run_log.id
|
||||||
|
|
||||||
|
run_context = {
|
||||||
|
"type": "agent_batch",
|
||||||
|
"run_id": run_log_id,
|
||||||
|
"agent_id": stable_agent_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Dispatch to the Redis consumer for processing
|
||||||
|
trigger_data = {
|
||||||
|
"type": "agent_trigger",
|
||||||
|
"directory": body.get("directory", ""),
|
||||||
|
"directory_paths": [body.get("directory", "")] if body.get("directory") else [],
|
||||||
|
"data_types": _to_data_types(body.get("what_to_extract", [])),
|
||||||
|
"file_extensions": body.get("file_extensions", []),
|
||||||
|
"prompt_template": body.get("custom_agent_prompt", ""),
|
||||||
|
"device_id": body.get("device_id", ""),
|
||||||
|
"run_context": run_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
channel = batch_request_channel(x_user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps(trigger_data))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": run_log_id,
|
||||||
|
"agent_id": stable_agent_id,
|
||||||
|
"agent_type": "local",
|
||||||
|
"status": "running",
|
||||||
|
"items_processed": 0,
|
||||||
|
"items_created": 0,
|
||||||
|
"errors": [],
|
||||||
|
"started_at": _dt_ms(run_log.started_at),
|
||||||
|
"completed_at": None,
|
||||||
|
}
|
||||||
135
services/batch-agent/app/ws_context.py
Normal file
135
services/batch-agent/app/ws_context.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""WebSocket context for Batch Agent Service — Redis-based tool call round-trip.
|
||||||
|
|
||||||
|
Same pattern as services/chat/app/ws_context.py: publishes tool_call frames
|
||||||
|
to Redis ws:out:{user_id} and awaits BRPOP on tool:result:{call_id}.
|
||||||
|
|
||||||
|
Additionally provides set_client_executor / clear_client_executor stubs
|
||||||
|
for backward compatibility with the agent_runner code (which originally
|
||||||
|
used a DeviceConnectionManager callback). In the microservice world these
|
||||||
|
are no-ops — execute_on_client() always uses the Redis path.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any, Callable, Coroutine
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from shared.redis import redis_client, tool_result_key, ws_out_channel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_TOOL_CALL_TIMEOUT = 30 # seconds — BRPOP timeout
|
||||||
|
|
||||||
|
# Per-request user_id context var (set before agent run)
|
||||||
|
_current_user_id: ContextVar[str | None] = ContextVar("_current_user_id", default=None)
|
||||||
|
|
||||||
|
# Optional collector for debug / logging
|
||||||
|
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||||
|
"_tool_result_collector", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_current_user(user_id: str) -> None:
|
||||||
|
_current_user_id.set(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_current_user() -> None:
|
||||||
|
_current_user_id.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
def set_tool_result_collector(lst: list[dict]) -> None:
|
||||||
|
_tool_result_collector.set(lst)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_tool_result_collector() -> None:
|
||||||
|
_tool_result_collector.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Compatibility shims ──────────────────────────────────────────────────
|
||||||
|
# agent_runner.py originally called set_client_executor / clear_client_executor
|
||||||
|
# with a DeviceConnectionManager callback. In the microservice world the
|
||||||
|
# Redis-based execute_on_client replaces this, so these are no-ops that
|
||||||
|
# keep the agent_runner code unchanged.
|
||||||
|
|
||||||
|
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]] | None) -> None:
|
||||||
|
"""No-op — kept for agent_runner compatibility."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def clear_client_executor() -> None:
|
||||||
|
"""No-op — kept for agent_runner compatibility."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_on_client(
|
||||||
|
action: str,
|
||||||
|
table: str | None = None,
|
||||||
|
data: dict[str, Any] | None = None,
|
||||||
|
filters: dict[str, Any] | None = None,
|
||||||
|
vector: list[float] | None = None,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send a tool_call to Electron via Redis and await the result.
|
||||||
|
|
||||||
|
1. Build tool_call payload
|
||||||
|
2. Publish to ws:out:{user_id} (WS Gateway forwards to Electron)
|
||||||
|
3. BRPOP on tool:result:{call_id} (WS Gateway pushes when Electron replies)
|
||||||
|
4. Return result dict
|
||||||
|
|
||||||
|
Raises RuntimeError if no user_id is set or if the call times out.
|
||||||
|
"""
|
||||||
|
user_id = _current_user_id.get()
|
||||||
|
if not user_id:
|
||||||
|
raise RuntimeError(
|
||||||
|
"execute_on_client() called without a user_id — "
|
||||||
|
"set_current_user() must be called first."
|
||||||
|
)
|
||||||
|
|
||||||
|
call_id = str(uuid4())
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"type": "tool_call",
|
||||||
|
"id": call_id,
|
||||||
|
"action": action,
|
||||||
|
}
|
||||||
|
if table is not None:
|
||||||
|
payload["table"] = table
|
||||||
|
if data is not None:
|
||||||
|
payload["data"] = data
|
||||||
|
if filters is not None:
|
||||||
|
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
|
||||||
|
if vector is not None:
|
||||||
|
payload["vector"] = vector
|
||||||
|
if limit is not None:
|
||||||
|
payload["limit"] = limit
|
||||||
|
|
||||||
|
# Publish tool_call to WS Gateway → Electron
|
||||||
|
channel = ws_out_channel(user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps(payload))
|
||||||
|
|
||||||
|
# Wait for Electron's tool_result
|
||||||
|
result_key = tool_result_key(call_id)
|
||||||
|
response = await redis_client.brpop(result_key, timeout=_TOOL_CALL_TIMEOUT)
|
||||||
|
|
||||||
|
if response is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Tool call {call_id} timed out after {_TOOL_CALL_TIMEOUT}s — "
|
||||||
|
f"device may be offline or unresponsive."
|
||||||
|
)
|
||||||
|
|
||||||
|
# response is (key, value) tuple
|
||||||
|
_, raw = response
|
||||||
|
result = json.loads(raw)
|
||||||
|
|
||||||
|
# Collect for debug if requested
|
||||||
|
collector = _tool_result_collector.get(None)
|
||||||
|
if collector is not None:
|
||||||
|
collector.append({
|
||||||
|
"action": action,
|
||||||
|
"table": table,
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
20
services/batch-agent/requirements.txt
Normal file
20
services/batch-agent/requirements.txt
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.34.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
|
pydantic>=2.10.0
|
||||||
|
pydantic-settings>=2.7.0
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
asyncpg>=0.30.0
|
||||||
|
redis>=5.0.0
|
||||||
|
cryptography>=42.0.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
langchain-core>=0.3.0
|
||||||
|
langchain-openai>=0.3.0
|
||||||
|
langchain-litellm>=0.3.0
|
||||||
|
litellm>=1.50.0
|
||||||
|
openai>=1.50.0
|
||||||
|
httpx>=0.27.0
|
||||||
|
croniter>=2.0.0
|
||||||
|
google-api-python-client>=2.130.0
|
||||||
|
google-auth>=2.30.0
|
||||||
|
msal>=1.28.0
|
||||||
Reference in New Issue
Block a user