feat(step-3.6): cloud provider integrations (Gmail, Outlook, Teams)
- Add app/integrations/__init__.py: Fernet token encryption helpers, EmailMessage/ChatMessage dataclasses, get_provider() factory - Add app/integrations/gmail.py: GmailClient with async fetch_messages(), token refresh, configurable label/sender/date filters - Add app/integrations/ms_graph.py: MSGraphClient with fetch_emails() (Outlook) and fetch_messages() (Teams), MSAL token refresh, OData filters - Update app/core/agent_runner.py: replace run_cloud_agent() stub with full 8-step implementation; extend _finalize_run() for cloud config type - Update app/config/settings.py: add OAuth + Fernet encryption settings - Update requirements.txt: google-api-python-client, google-auth-*, msal, cryptography - Add tests/test_integrations.py: 47 tests covering all integration code - Update tests/test_agent_runner.py: replace stub test with 7 real tests All 76 new/updated tests pass.
This commit is contained in:
335
app/integrations/gmail.py
Normal file
335
app/integrations/gmail.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""Gmail API client for cloud agent integration.
|
||||
|
||||
Wraps the Google Gmail REST API to fetch email messages matching a
|
||||
``filter_config`` dict. Uses the official ``google-api-python-client``
|
||||
library (synchronous) wrapped in ``asyncio.to_thread()`` to avoid
|
||||
blocking the event loop.
|
||||
|
||||
Token refresh is handled transparently: when the stored access token has
|
||||
expired, ``google.auth.transport.requests.Request`` will use the refresh
|
||||
token to obtain a fresh one. The caller is responsible for persisting
|
||||
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted``
|
||||
(see ``agent_runner.run_cloud_agent``).
|
||||
|
||||
Credential dict shape (Google OAuth2):
|
||||
{
|
||||
"token": "<access_token>",
|
||||
"refresh_token": "<refresh_token>",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"client_id": "<client_id>",
|
||||
"client_secret": "<client_secret>",
|
||||
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
"expiry": "2025-01-01T00:00:00Z" # optional ISO-8601
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import email
|
||||
import html
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from app.integrations import EmailMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Gmail search date format — e.g. "after:2025/01/01"
|
||||
_GMAIL_DATE_FMT = "%Y/%m/%d"
|
||||
|
||||
# Maximum characters of body text forwarded to the LLM.
|
||||
_BODY_TRUNCATE = 8_000
|
||||
|
||||
# Maximum messages retrieved per run (prevents runaway quota usage).
|
||||
_MAX_MESSAGES = 200
|
||||
|
||||
|
||||
def _build_gmail_query(
|
||||
filter_config: dict[str, Any] | None,
|
||||
since: datetime | None,
|
||||
) -> str:
|
||||
"""Build a Gmail search query string from *filter_config* and *since*.
|
||||
|
||||
Supported ``filter_config`` keys:
|
||||
labels (list[str]): Gmail label names, e.g. ``["INBOX", "work"]``
|
||||
senders (list[str]): Sender addresses or domains to include
|
||||
date_range (dict): ``{from: "<YYYY-MM-DD>", to: "<YYYY-MM-DD>"}``
|
||||
|
||||
A hard ``since`` date (from last run) always overrides ``date_range.from``
|
||||
when it is earlier.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
cfg = filter_config or {}
|
||||
|
||||
# Labels — joined with OR when multiple given.
|
||||
labels: list[str] = cfg.get("labels", [])
|
||||
if labels:
|
||||
if len(labels) == 1:
|
||||
parts.append(f"label:{labels[0]}")
|
||||
else:
|
||||
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
|
||||
parts.append(f"({label_expr})")
|
||||
|
||||
# Senders — each prefixed with "from:".
|
||||
senders: list[str] = cfg.get("senders", [])
|
||||
for sender in senders:
|
||||
parts.append(f"from:{sender}")
|
||||
|
||||
# Date range.
|
||||
date_range: dict = cfg.get("date_range", {})
|
||||
from_str: str | None = date_range.get("from")
|
||||
to_str: str | None = date_range.get("to")
|
||||
|
||||
# Determine effective "from" date: most recent of filter_config.date_range.from and since.
|
||||
effective_since: datetime | None = since
|
||||
if from_str:
|
||||
try:
|
||||
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||
if cfg_since.tzinfo is None:
|
||||
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||
if effective_since is None or cfg_since > effective_since:
|
||||
effective_since = cfg_since
|
||||
except ValueError:
|
||||
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
|
||||
|
||||
if effective_since:
|
||||
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
|
||||
|
||||
if to_str:
|
||||
try:
|
||||
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
|
||||
except ValueError:
|
||||
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _strip_html(raw_html: str) -> str:
|
||||
"""Remove HTML tags and decode entities to get plain text."""
|
||||
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
|
||||
decoded = html.unescape(no_tags)
|
||||
return re.sub(r"\s+", " ", decoded).strip()
|
||||
|
||||
|
||||
def _parse_body(payload: dict[str, Any]) -> str:
|
||||
"""Recursively extract the plain-text body from a Gmail message payload.
|
||||
|
||||
Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
|
||||
Returns an empty string if no body can be extracted.
|
||||
"""
|
||||
mime_type: str = payload.get("mimeType", "")
|
||||
body: dict = payload.get("body", {})
|
||||
parts: list[dict] = payload.get("parts", [])
|
||||
|
||||
if mime_type == "text/plain":
|
||||
data = body.get("data", "")
|
||||
if data:
|
||||
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||
return ""
|
||||
|
||||
if mime_type == "text/html":
|
||||
data = body.get("data", "")
|
||||
if data:
|
||||
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||
return _strip_html(raw)
|
||||
return ""
|
||||
|
||||
# Multipart — prefer text/plain part, fall back to text/html.
|
||||
plain_fallback = ""
|
||||
for part in parts:
|
||||
part_mime = part.get("mimeType", "")
|
||||
if part_mime == "text/plain":
|
||||
return _parse_body(part)
|
||||
if part_mime == "text/html" and not plain_fallback:
|
||||
plain_fallback = _parse_body(part)
|
||||
if part_mime.startswith("multipart/"):
|
||||
nested = _parse_body(part)
|
||||
if nested:
|
||||
return nested
|
||||
return plain_fallback
|
||||
|
||||
|
||||
def _parse_date(raw: str) -> datetime:
|
||||
"""Parse an RFC 2822 email date header into a UTC ``datetime``."""
|
||||
try:
|
||||
parsed = email.utils.parsedate_to_datetime(raw)
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed.astimezone(timezone.utc)
|
||||
except Exception:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class GmailClient:
|
||||
"""Fetch email messages from a Gmail account via the Gmail REST API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
credentials_info:
|
||||
Decrypted OAuth2 credential dict. Must contain at minimum
|
||||
``token`` (access token) or ``refresh_token`` + ``token_uri`` +
|
||||
``client_id`` + ``client_secret``.
|
||||
"""
|
||||
|
||||
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
self._credentials_info = credentials_info
|
||||
expiry_str: str | None = credentials_info.get("expiry")
|
||||
expiry: datetime | None = None
|
||||
if expiry_str:
|
||||
try:
|
||||
expiry = datetime.fromisoformat(
|
||||
expiry_str.replace("Z", "+00:00")
|
||||
).replace(tzinfo=timezone.utc)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
self._credentials = Credentials(
|
||||
token=credentials_info.get("token"),
|
||||
refresh_token=credentials_info.get("refresh_token"),
|
||||
token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
|
||||
client_id=credentials_info.get("client_id"),
|
||||
client_secret=credentials_info.get("client_secret"),
|
||||
scopes=credentials_info.get("scopes"),
|
||||
expiry=expiry,
|
||||
)
|
||||
|
||||
# ── Public API ─────────────────────────────────────────────────────────
|
||||
|
||||
async def fetch_messages(
|
||||
self,
|
||||
filter_config: dict[str, Any] | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> list[EmailMessage]:
|
||||
"""Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.
|
||||
|
||||
Runs the synchronous Google API calls inside ``asyncio.to_thread()``
|
||||
to avoid blocking the async event loop.
|
||||
|
||||
Token refresh is performed automatically when the access token has
|
||||
expired. After the call, ``self.refreshed_credentials`` may be
|
||||
consulted to detect whether new credentials should be persisted.
|
||||
"""
|
||||
query = _build_gmail_query(filter_config, since)
|
||||
logger.debug("gmail: executing search query %r", query)
|
||||
return await asyncio.to_thread(self._fetch_sync, query)
|
||||
|
||||
@property
|
||||
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||
"""Return updated credential dict if the access token was refreshed.
|
||||
|
||||
If the credentials were refreshed during ``fetch_messages()``, returns
|
||||
a new dict that should be re-encrypted and written back to the DB.
|
||||
Returns ``None`` if no refresh occurred.
|
||||
"""
|
||||
creds = self._credentials
|
||||
if not creds.valid and creds.expired:
|
||||
return None
|
||||
# Check whether the token changed from what was stored.
|
||||
if creds.token != self._credentials_info.get("token"):
|
||||
result = {
|
||||
"token": creds.token,
|
||||
"refresh_token": creds.refresh_token,
|
||||
"token_uri": creds.token_uri,
|
||||
"client_id": creds.client_id,
|
||||
"client_secret": creds.client_secret,
|
||||
"scopes": list(creds.scopes or []),
|
||||
}
|
||||
if creds.expiry:
|
||||
result["expiry"] = creds.expiry.isoformat()
|
||||
return result
|
||||
return None
|
||||
|
||||
# ── Internal sync worker ───────────────────────────────────────────────
|
||||
|
||||
def _fetch_sync(self, query: str) -> list[EmailMessage]:
|
||||
"""Synchronous worker — called inside ``asyncio.to_thread()``."""
|
||||
import googleapiclient.discovery
|
||||
import googleapiclient.errors
|
||||
from google.auth.transport.requests import Request
|
||||
|
||||
# Refresh token if needed before building the service.
|
||||
if self._credentials.expired and self._credentials.refresh_token:
|
||||
try:
|
||||
self._credentials.refresh(Request())
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc
|
||||
|
||||
service = googleapiclient.discovery.build(
|
||||
"gmail", "v1", credentials=self._credentials, cache_discovery=False
|
||||
)
|
||||
user_api = service.users() # type: ignore[attr-defined]
|
||||
|
||||
# ── List matching message IDs ──────────────────────────────────────
|
||||
ids: list[str] = []
|
||||
page_token: str | None = None
|
||||
while len(ids) < _MAX_MESSAGES:
|
||||
batch_size = min(100, _MAX_MESSAGES - len(ids))
|
||||
kwargs: dict[str, Any] = {
|
||||
"userId": "me",
|
||||
"maxResults": batch_size,
|
||||
}
|
||||
if query:
|
||||
kwargs["q"] = query
|
||||
if page_token:
|
||||
kwargs["pageToken"] = page_token
|
||||
|
||||
try:
|
||||
resp = user_api.messages().list(**kwargs).execute()
|
||||
except googleapiclient.errors.HttpError as exc:
|
||||
raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc
|
||||
|
||||
for msg in resp.get("messages", []):
|
||||
ids.append(msg["id"])
|
||||
|
||||
page_token = resp.get("nextPageToken")
|
||||
if not page_token:
|
||||
break
|
||||
|
||||
if not ids:
|
||||
logger.debug("gmail: no messages matched query %r", query)
|
||||
return []
|
||||
|
||||
logger.info("gmail: fetching %d message(s)", len(ids))
|
||||
|
||||
# ── Fetch individual message details ──────────────────────────────
|
||||
messages: list[EmailMessage] = []
|
||||
for msg_id in ids:
|
||||
try:
|
||||
msg = user_api.messages().get(
|
||||
userId="me", id=msg_id, format="full"
|
||||
).execute()
|
||||
|
||||
headers: dict[str, str] = {
|
||||
h["name"].lower(): h["value"]
|
||||
for h in msg.get("payload", {}).get("headers", [])
|
||||
}
|
||||
subject = headers.get("subject", "(no subject)")
|
||||
sender = headers.get("from", "unknown")
|
||||
date_raw = headers.get("date", "")
|
||||
date = _parse_date(date_raw) if date_raw else datetime.now(timezone.utc)
|
||||
|
||||
body_text = _parse_body(msg.get("payload", {}))[:_BODY_TRUNCATE]
|
||||
labels = msg.get("labelIds", [])
|
||||
|
||||
messages.append(EmailMessage(
|
||||
id=msg_id,
|
||||
subject=subject,
|
||||
sender=sender,
|
||||
body_text=body_text,
|
||||
date=date,
|
||||
labels=labels,
|
||||
))
|
||||
except googleapiclient.errors.HttpError as exc:
|
||||
logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
|
||||
except Exception as exc:
|
||||
logger.warning("gmail: skipping message %s — unexpected error: %s", msg_id, exc)
|
||||
|
||||
logger.info("gmail: returned %d message(s)", len(messages))
|
||||
return messages
|
||||
Reference in New Issue
Block a user