WsDeviceHello.agent_ids → scout_ids in Pydantic schema, device_ws.py handler, and all test fixtures (test_device_ws, test_ws_unified, test_memory_middleware). Also fixes stale CloudAgentConfig reference in gmail.py docstring. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
336 lines
13 KiB
Python
336 lines
13 KiB
Python
"""Gmail API client for cloud agent integration.
|
|
|
|
Wraps the Google Gmail REST API to fetch email messages matching a
|
|
``filter_config`` dict. Uses the official ``google-api-python-client``
|
|
library (synchronous) wrapped in ``asyncio.to_thread()`` to avoid
|
|
blocking the event loop.
|
|
|
|
Token refresh is handled transparently: when the stored access token has
|
|
expired, ``google.auth.transport.requests.Request`` will use the refresh
|
|
token to obtain a fresh one. The caller is responsible for persisting
|
|
any refreshed credentials back to ``CloudScoutConfig.oauth_token_encrypted``
|
|
(see ``agent_runner.run_cloud_agent``).
|
|
|
|
Credential dict shape (Google OAuth2):
|
|
{
|
|
"token": "<access_token>",
|
|
"refresh_token": "<refresh_token>",
|
|
"token_uri": "https://oauth2.googleapis.com/token",
|
|
"client_id": "<client_id>",
|
|
"client_secret": "<client_secret>",
|
|
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
|
|
"expiry": "2025-01-01T00:00:00Z" # optional ISO-8601
|
|
}
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import email
|
|
import html
|
|
import logging
|
|
import re
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from app.integrations import EmailMessage
|
|
|
|
# Module-level logger shared by every helper and by ``GmailClient``.
logger = logging.getLogger(__name__)

# ``strftime`` pattern for the dates placed after Gmail's ``after:`` and
# ``before:`` search operators — e.g. "after:2025/01/01".
_GMAIL_DATE_FMT = "%Y/%m/%d"

# Maximum characters of body text forwarded to the LLM; applied as a slice
# to each extracted message body.
_BODY_TRUNCATE = 8_000

# Maximum messages retrieved per run (prevents runaway quota usage); caps
# the number of message IDs collected while paging through search results.
_MAX_MESSAGES = 200
|
|
|
|
|
|
def _build_gmail_query(
    filter_config: dict[str, Any] | None,
    since: datetime | None,
) -> str:
    """Build a Gmail search query string from *filter_config* and *since*.

    Supported ``filter_config`` keys (each may be missing or ``None``):
        labels (list[str]):  Gmail label names, e.g. ``["INBOX", "work"]``
        senders (list[str]): Sender addresses or domains to include
        date_range (dict):   ``{from: "<YYYY-MM-DD>", to: "<YYYY-MM-DD>"}``

    Gmail ANDs space-separated terms, so alternatives within one key are
    grouped as ``(a OR b)``: a message matches when it carries *any* of the
    labels and comes from *any* of the senders.

    The lower date bound is the later of *since* (the last run) and
    ``date_range.from``.  A naive *since* is interpreted as UTC so it can be
    compared with the timezone-aware configured date.

    Returns the query string; it is empty when nothing constrains the search.
    """
    cfg = filter_config or {}
    parts: list[str] = []

    # Labels and senders: one term stands alone, several are OR-grouped.
    for operator, key in (("label", "labels"), ("from", "senders")):
        terms = [f"{operator}:{value}" for value in cfg.get(key) or []]
        if len(terms) == 1:
            parts.append(terms[0])
        elif terms:
            parts.append("(" + " OR ".join(terms) + ")")

    # Date range.
    date_range: dict = cfg.get("date_range") or {}
    from_str: str | None = date_range.get("from")
    to_str: str | None = date_range.get("to")

    # Effective lower bound: most recent of date_range.from and since.
    effective_since: datetime | None = since
    if effective_since is not None and effective_since.tzinfo is None:
        # Avoid "can't compare offset-naive and offset-aware datetimes".
        effective_since = effective_since.replace(tzinfo=timezone.utc)

    if from_str:
        try:
            cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
        except ValueError:
            logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
        else:
            if cfg_since.tzinfo is None:
                cfg_since = cfg_since.replace(tzinfo=timezone.utc)
            if effective_since is None or cfg_since > effective_since:
                effective_since = cfg_since

    if effective_since:
        parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")

    if to_str:
        try:
            to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
        except ValueError:
            logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
        else:
            parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")

    return " ".join(parts)
|
|
|
|
|
|
def _strip_html(raw_html: str) -> str:
    """Reduce an HTML document to whitespace-normalised plain text.

    Processing order:
      1. HTML comments are removed.
      2. ``<script>`` and ``<style>`` elements are removed *including their
         contents*, so CSS rules and JavaScript do not end up in the text.
      3. Every remaining tag is replaced by a space.
      4. Character entities (``&amp;``, ``&#39;`` …) are decoded.
      5. Runs of whitespace collapse to one space and the ends are trimmed.

    Returns an empty string for empty input.
    """
    without_comments = re.sub(r"<!--.*?-->", " ", raw_html, flags=re.DOTALL)
    # The back-reference makes the closing tag match the opening one
    # (case-insensitively), so a <style> block is not ended by </script>.
    without_code = re.sub(
        r"<(script|style)\b[^>]*>.*?</\1\s*>",
        " ",
        without_comments,
        flags=re.DOTALL | re.IGNORECASE,
    )
    no_tags = re.sub(r"<[^>]+>", " ", without_code)
    decoded = html.unescape(no_tags)
    return re.sub(r"\s+", " ", decoded).strip()
|
|
|
|
|
|
def _decode_body_data(data: str) -> str:
    """Decode a base64url ``body.data`` string into text ("" when empty)."""
    if not data:
        return ""
    # The data may arrive without '=' padding; restore exactly what is missing.
    padding = "=" * (-len(data) % 4)
    return base64.urlsafe_b64decode(data + padding).decode("utf-8", errors="replace")


def _parse_body(payload: dict[str, Any]) -> str:
    """Recursively extract the plain-text body from a Gmail message payload.

    Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
    For a multipart payload the parts are scanned in order:

    * the first ``text/plain`` part with non-empty content wins — an empty
      one is skipped so the HTML fallback can still apply;
    * the first ``text/html`` part is remembered as the fallback;
    * nested ``multipart/*`` parts are searched recursively and win as soon
      as they yield any text.

    Returns an empty string if no body can be extracted.
    """
    mime_type: str = payload.get("mimeType", "")
    data: str = (payload.get("body") or {}).get("data", "")

    if mime_type == "text/plain":
        return _decode_body_data(data)

    if mime_type == "text/html":
        raw = _decode_body_data(data)
        return _strip_html(raw) if raw else ""

    # Multipart (or any other container): walk the child parts.
    html_fallback = ""
    for part in payload.get("parts") or []:
        part_mime: str = part.get("mimeType", "")
        if part_mime == "text/plain":
            text = _parse_body(part)
            if text:
                return text
        elif part_mime == "text/html":
            if not html_fallback:
                html_fallback = _parse_body(part)
        elif part_mime.startswith("multipart/"):
            nested = _parse_body(part)
            if nested:
                return nested
    return html_fallback
|
|
|
|
|
|
def _parse_date(raw: str) -> datetime:
    """Parse an RFC 2822 email date header into a UTC ``datetime``.

    A header without usable zone information (parsed as a naive datetime)
    is assumed to already be in UTC.  If the header cannot be parsed at all,
    the current UTC time is returned instead so that one malformed header
    does not cause the message to be dropped.
    """
    # Import the function explicitly: a bare ``import email`` does not
    # guarantee that the ``email.utils`` submodule has been loaded.
    from email.utils import parsedate_to_datetime

    try:
        parsed = parsedate_to_datetime(raw)
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed.astimezone(timezone.utc)
    except (TypeError, ValueError, IndexError, OverflowError) as exc:
        # Best-effort: malformed or out-of-range dates fall back to "now".
        logger.debug("gmail: unparseable date header %r (%s) — using now", raw, exc)
        return datetime.now(timezone.utc)
|
|
|
|
|
|
class GmailClient:
    """Fetch email messages from a Gmail account via the Gmail REST API.

    Parameters
    ----------
    credentials_info:
        Decrypted OAuth2 credential dict.  Must contain at minimum
        ``token`` (access token) or ``refresh_token`` + ``token_uri`` +
        ``client_id`` + ``client_secret``.
    """

    def __init__(self, credentials_info: dict[str, Any]) -> None:
        # Imported lazily so the module can be imported without google-auth.
        from google.oauth2.credentials import Credentials

        # Kept so ``refreshed_credentials`` can tell whether the token changed.
        self._credentials_info = credentials_info
        self._credentials = Credentials(
            token=credentials_info.get("token"),
            refresh_token=credentials_info.get("refresh_token"),
            token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
            client_id=credentials_info.get("client_id"),
            client_secret=credentials_info.get("client_secret"),
            scopes=credentials_info.get("scopes"),
            expiry=self._parse_expiry(credentials_info.get("expiry")),
        )

    @staticmethod
    def _parse_expiry(expiry_str: str | None) -> datetime | None:
        """Convert a stored ISO-8601 expiry into a *naive UTC* ``datetime``.

        google-auth compares ``Credentials.expiry`` against a naive UTC
        "now", so an aware value would raise ``TypeError`` when the
        credentials are checked for expiry.  Aware inputs are therefore
        converted to UTC before the zone is dropped; naive inputs are taken
        to be UTC already.

        Returns ``None`` when *expiry_str* is empty or cannot be parsed.
        """
        if not expiry_str:
            return None
        try:
            parsed = datetime.fromisoformat(expiry_str.replace("Z", "+00:00"))
        except ValueError:
            logger.warning("gmail: invalid credential expiry %r — ignoring", expiry_str)
            return None
        if parsed.tzinfo is not None:
            parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
        return parsed

    # ── Public API ─────────────────────────────────────────────────────────

    async def fetch_messages(
        self,
        filter_config: dict[str, Any] | None = None,
        since: datetime | None = None,
    ) -> list[EmailMessage]:
        """Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.

        Runs the synchronous Google API calls inside ``asyncio.to_thread()``
        to avoid blocking the async event loop.

        Token refresh is performed automatically when the access token is
        no longer valid.  After the call, ``self.refreshed_credentials`` may
        be consulted to detect whether new credentials should be persisted.

        Raises ``RuntimeError`` when the token refresh or the message
        listing fails; failures on individual messages are logged and the
        message is skipped.
        """
        query = _build_gmail_query(filter_config, since)
        logger.debug("gmail: executing search query %r", query)
        return await asyncio.to_thread(self._fetch_sync, query)

    @property
    def refreshed_credentials(self) -> dict[str, Any] | None:
        """Return updated credential dict if the access token was refreshed.

        If the credentials were refreshed during ``fetch_messages()``, returns
        a new dict that should be re-encrypted and written back to the DB.
        Its ``expiry`` entry, when present, is the ISO-8601 form of the
        credentials' expiry.  Returns ``None`` if no refresh occurred or the
        credentials are currently expired.
        """
        creds = self._credentials
        if not creds.valid and creds.expired:
            return None
        # Unchanged token means nothing new to persist.
        if creds.token == self._credentials_info.get("token"):
            return None

        result: dict[str, Any] = {
            "token": creds.token,
            "refresh_token": creds.refresh_token,
            "token_uri": creds.token_uri,
            "client_id": creds.client_id,
            "client_secret": creds.client_secret,
            "scopes": list(creds.scopes or []),
        }
        if creds.expiry:
            result["expiry"] = creds.expiry.isoformat()
        return result

    # ── Internal sync worker ───────────────────────────────────────────────

    def _fetch_sync(self, query: str) -> list[EmailMessage]:
        """Synchronous worker — called inside ``asyncio.to_thread()``.

        Refreshes the token if necessary, lists the matching message IDs,
        then downloads and converts each message.
        """
        import googleapiclient.discovery
        import googleapiclient.errors
        from google.auth.transport.requests import Request

        # Refresh before building the service whenever the credentials are
        # not usable as-is (expired, or no access token at all) and a
        # refresh token is available.
        creds = self._credentials
        if not creds.valid and creds.refresh_token:
            try:
                creds.refresh(Request())
            except Exception as exc:
                raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc

        service = googleapiclient.discovery.build(
            "gmail", "v1", credentials=creds, cache_discovery=False
        )
        user_api = service.users()  # type: ignore[attr-defined]

        ids = self._list_message_ids(user_api, query)
        if not ids:
            logger.debug("gmail: no messages matched query %r", query)
            return []

        logger.info("gmail: fetching %d message(s)", len(ids))

        messages: list[EmailMessage] = []
        for msg_id in ids:
            # Best-effort per message: one bad message must not abort the run.
            try:
                msg = user_api.messages().get(
                    userId="me", id=msg_id, format="full"
                ).execute()
                messages.append(self._to_email_message(msg_id, msg))
            except googleapiclient.errors.HttpError as exc:
                logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
            except Exception as exc:
                logger.warning(
                    "gmail: skipping message %s — unexpected error: %s",
                    msg_id,
                    exc,
                    exc_info=True,
                )

        logger.info("gmail: returned %d message(s)", len(messages))
        return messages

    @staticmethod
    def _list_message_ids(user_api: Any, query: str) -> list[str]:
        """Page through ``messages.list`` and collect up to ``_MAX_MESSAGES`` IDs.

        Raises ``RuntimeError`` if any list request fails with an HTTP error.
        """
        from googleapiclient.errors import HttpError

        ids: list[str] = []
        page_token: str | None = None
        while len(ids) < _MAX_MESSAGES:
            kwargs: dict[str, Any] = {
                "userId": "me",
                # Never request more than is still needed to reach the cap.
                "maxResults": min(100, _MAX_MESSAGES - len(ids)),
            }
            if query:
                kwargs["q"] = query
            if page_token:
                kwargs["pageToken"] = page_token

            try:
                resp = user_api.messages().list(**kwargs).execute()
            except HttpError as exc:
                raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc

            ids.extend(msg["id"] for msg in resp.get("messages", []))

            page_token = resp.get("nextPageToken")
            if not page_token:
                break
        return ids

    @staticmethod
    def _to_email_message(msg_id: str, msg: dict[str, Any]) -> EmailMessage:
        """Convert a ``format="full"`` Gmail message resource to ``EmailMessage``."""
        payload: dict[str, Any] = msg.get("payload", {})
        # Header names are case-insensitive; on duplicates the last one wins.
        headers: dict[str, str] = {
            h["name"].lower(): h["value"] for h in payload.get("headers", [])
        }
        date_raw = headers.get("date", "")
        return EmailMessage(
            id=msg_id,
            subject=headers.get("subject", "(no subject)"),
            sender=headers.get("from", "unknown"),
            body_text=_parse_body(payload)[:_BODY_TRUNCATE],
            date=_parse_date(date_raw) if date_raw else datetime.now(timezone.utc),
            labels=msg.get("labelIds", []),
        )
|