feat(step-3.6): cloud provider integrations (Gmail, Outlook, Teams)

- Add app/integrations/__init__.py: Fernet token encryption helpers,
  EmailMessage/ChatMessage dataclasses, get_provider() factory
- Add app/integrations/gmail.py: GmailClient with async fetch_messages(),
  token refresh, configurable label/sender/date filters
- Add app/integrations/ms_graph.py: MSGraphClient with fetch_emails()
  (Outlook) and fetch_messages() (Teams), MSAL token refresh, OData filters
- Update app/core/agent_runner.py: replace run_cloud_agent() stub with
  full 8-step implementation; extend _finalize_run() for cloud config type
- Update app/config/settings.py: add OAuth + Fernet encryption settings
- Update requirements.txt: google-api-python-client, google-auth-*,
  msal, cryptography
- Add tests/test_integrations.py: 47 tests covering all integration code
- Update tests/test_agent_runner.py: replace stub test with 7 real tests

All 76 new/updated tests pass.
This commit is contained in:
2026-03-05 18:05:07 +01:00
parent 24772f2b67
commit a775a2da18
11 changed files with 2063 additions and 35 deletions

View File

@@ -437,21 +437,21 @@ Cloud Agent:
- **Outcome:** Users configure AI prompts through guided conversation. Journey can refine an existing config when `agent_id` is provided. ✅ - **Outcome:** Users configure AI prompts through guided conversation. Journey can refine an existing config when `agent_id` is provided. ✅
### Step 3.6 — Cloud provider integrations ### Step 3.6 — Cloud provider integrations
- [ ] Create `app/integrations/gmail.py`: - [x] Create `app/integrations/gmail.py`:
- `GmailClient`: - `GmailClient`:
- `__init__(oauth_token)` — initializes Google API client - `__init__(oauth_token)` — initializes Google API client
- `async fetch_messages(filter_config, since: datetime) -> list[EmailMessage]` - `async fetch_messages(filter_config, since: datetime) -> list[EmailMessage]`
- `EmailMessage`: `{ id, subject, sender, body_text, date, labels }` - `EmailMessage`: `{ id, subject, sender, body_text, date, labels }`
- Handles token refresh via Google OAuth2 refresh flow - Handles token refresh via Google OAuth2 refresh flow
- Respects `filter_config.labels`, `filter_config.date_range`, `filter_config.senders` - Respects `filter_config.labels`, `filter_config.date_range`, `filter_config.senders`
- [ ] Create `app/integrations/ms_graph.py`: - [x] Create `app/integrations/ms_graph.py`:
- `MSGraphClient`: - `MSGraphClient`:
- `__init__(oauth_token)` — initializes MS Graph client - `__init__(oauth_token)` — initializes MS Graph client
- `async fetch_emails(filter_config, since: datetime) -> list[EmailMessage]` (Outlook) - `async fetch_emails(filter_config, since: datetime) -> list[EmailMessage]` (Outlook)
- `async fetch_messages(filter_config, since: datetime) -> list[ChatMessage]` (Teams) - `async fetch_messages(filter_config, since: datetime) -> list[ChatMessage]` (Teams)
- `ChatMessage`: `{ id, content, sender, channel, date }` - `ChatMessage`: `{ id, content, sender, channel, date }`
- Handles token refresh via MSAL - Handles token refresh via MSAL
- [ ] Create `app/integrations/__init__.py` — factory: `get_provider(provider_name) -> GmailClient | MSGraphClient` - [x] Create `app/integrations/__init__.py` — factory: `get_provider(provider_name) -> GmailClient | MSGraphClient`
- **Dependencies:** `google-api-python-client`, `google-auth-oauthlib`, `msgraph-sdk`, `msal` - **Dependencies:** `google-api-python-client`, `google-auth-oauthlib`, `msgraph-sdk`, `msal`
- **Files:** `app/integrations/gmail.py`, `app/integrations/ms_graph.py`, `app/integrations/__init__.py` - **Files:** `app/integrations/gmail.py`, `app/integrations/ms_graph.py`, `app/integrations/__init__.py`
- **Outcome:** Backend can fetch emails/messages from Gmail, Outlook, and Teams. - **Outcome:** Backend can fetch emails/messages from Gmail, Outlook, and Teams.

View File

@@ -29,6 +29,25 @@ class Settings(BaseSettings):
LLM_MODEL: str = "gpt-4o" LLM_MODEL: str = "gpt-4o"
LLM_ROUTER_MODEL: str = "gpt-4o-mini" LLM_ROUTER_MODEL: str = "gpt-4o-mini"
LLM_EMBED_MODEL: str = "text-embedding-3-small"
# GitHub Copilot OAuth token storage directory.
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
GITHUB_COPILOT_TOKEN_DIR: str = ""
# OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows.
GMAIL_CLIENT_ID: str = ""
GMAIL_CLIENT_SECRET: str = ""
MS_CLIENT_ID: str = ""
MS_CLIENT_SECRET: str = ""
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
MS_TENANT_ID: str = "common"
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
OAUTH_ENCRYPTION_KEY: str = ""
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"] CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]

View File

@@ -29,7 +29,7 @@ import asyncio
import json import json
import logging import logging
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timedelta, timezone
from typing import Any from typing import Any
from croniter import croniter from croniter import croniter
@@ -383,7 +383,10 @@ async def run_local_agent(
) )
# ── Cloud agent runner (stub) ─────────────────────────────────────────────── # ── Cloud agent runner ─────────────────────────────────────────────────────
# Default lookback window when an agent has never run before.
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
async def run_cloud_agent( async def run_cloud_agent(
@@ -392,26 +395,199 @@ async def run_cloud_agent(
run_log: AgentRunLog, run_log: AgentRunLog,
device_mgr: DeviceConnectionManager, device_mgr: DeviceConnectionManager,
) -> None: ) -> None:
"""Execute a cloud connector agent run. """Execute a cloud connector agent run end-to-end.
.. note:: Steps:
This is a **stub** — provider integrations (Gmail, Teams, Outlook)
are implemented in Step 3.6. The run is immediately marked as an 1. Verify the user's device is online — results are pushed to Electron
error with an informative message. via WS tool-call frames. If no device is connected, abort.
2. Decrypt the stored OAuth token from ``config.oauth_token_encrypted``.
3. Instantiate the provider client (Gmail or MS Graph).
4. Fetch messages/emails since ``config.last_run_at`` (or 7 days ago for
the first run) applying ``config.filter_config`` filters.
5. For each message/email call ``_extract_items_from_content`` with
``config.prompt_template`` to get structured ``{table, data}`` items.
6. Push each item to Electron as an ``insert`` tool-call.
7. If the provider refreshed its access token, re-encrypt and write it
back to ``config.oauth_token_encrypted``.
8. Persist the run outcome via ``_finalize_run``.
""" """
run_id = run_log.id
# ── 1. Device online check ─────────────────────────────────────────
if not device_mgr.is_online(user_id):
logger.info(
"agent_runner: skip cloud run=%s — no device online for user=%s",
run_id,
user_id,
)
await _finalize_run(
run_log,
status="error",
errors=["No connected device — cloud agent results cannot be delivered"],
)
return
# ── 2. Decrypt OAuth token ─────────────────────────────────────────
from app.integrations import decrypt_token, encrypt_token, get_provider
if not config.oauth_token_encrypted:
await _finalize_run(
run_log,
status="error",
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
)
return
try:
credentials_info = decrypt_token(config.oauth_token_encrypted)
except ValueError as exc:
logger.error("agent_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc)
await _finalize_run(
run_log,
status="error",
errors=[f"Failed to decrypt OAuth token: {exc}"],
)
return
# ── 3. Instantiate provider client ────────────────────────────────
try:
provider = get_provider(config.provider, credentials_info)
except ValueError as exc:
await _finalize_run(
run_log,
status="error",
errors=[str(exc)],
)
return
# ── 4. Fetch messages ─────────────────────────────────────────────
since: datetime | None = config.last_run_at
if since is None:
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
if since.tzinfo is None:
since = since.replace(tzinfo=timezone.utc)
errors: list[str] = []
items_processed = 0
items_created = 0
try:
if config.provider == "gmail":
raw_messages = await provider.fetch_messages( # type: ignore[union-attr]
filter_config=config.filter_config,
since=since,
)
elif config.provider == "outlook":
raw_messages = await provider.fetch_emails( # type: ignore[union-attr]
filter_config=config.filter_config,
since=since,
)
elif config.provider == "teams":
raw_messages = await provider.fetch_messages( # type: ignore[union-attr]
filter_config=config.filter_config,
since=since,
)
else:
raw_messages = []
except RuntimeError as exc:
logger.error(
"agent_runner: provider fetch failed for cloud agent %s: %s",
config.id,
exc,
)
await _finalize_run(
run_log,
status="error",
errors=[f"Provider fetch failed: {exc}"],
update_config_last_run=True,
config_id=config.id,
config_type="cloud",
)
return
logger.info( logger.info(
"agent_runner: cloud agent %s (provider=%s) for user=%s — pending Step 3.6", "agent_runner: cloud agent %s fetched %d item(s) from %s for user=%s",
config.id, config.id,
len(raw_messages),
config.provider, config.provider,
user_id, user_id,
) )
# ── 5–6. Extract + insert ────────────────────────────────────────
for msg in raw_messages:
content_text = msg.as_text
if not content_text:
continue
items_processed += 1
try:
extracted = await _extract_items_from_content(
config.prompt_template, content_text, config.data_types
)
except Exception as exc:
errors.append(f"LLM extraction error for message {msg.id!r}: {exc}")
continue
for item in extracted:
try:
result = await _send_insert_to_client(
user_id, item["table"], item["data"], device_mgr
)
if result.get("error"):
errors.append(
f"Insert failed ({item['table']}, msg={msg.id!r}): {result['error']}"
)
else:
items_created += 1
except asyncio.TimeoutError:
errors.append(
f"Timed out awaiting insert ack ({item['table']}, msg={msg.id!r})"
)
except RuntimeError as exc:
errors.append(f"Insert error ({item['table']}, msg={msg.id!r}): {exc}")
# ── 7. Persist refreshed token (if any) ───────────────────────────
refreshed = getattr(provider, "refreshed_credentials", None)
if refreshed:
try:
new_encrypted = encrypt_token(refreshed)
async with async_session() as db:
cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
)
cfg_row = cfg_result.scalar_one_or_none()
if cfg_row:
cfg_row.oauth_token_encrypted = new_encrypted
await db.commit()
logger.debug("agent_runner: refreshed OAuth token persisted for agent %s", config.id)
except Exception as exc:
logger.warning("agent_runner: failed to persist refreshed token for agent %s: %s", config.id, exc)
# ── 8. Finalise ────────────────────────────────────────────────────
if errors and items_created == 0:
final_status = "error"
elif errors:
final_status = "partial"
else:
final_status = "success"
await _finalize_run( await _finalize_run(
run_log, run_log,
status="error", status=final_status,
errors=[ items_processed=items_processed,
f"Cloud provider integrations for '{config.provider}' are not yet " items_created=items_created,
"implemented. This feature arrives in Step 3.6." errors=errors,
], update_config_last_run=True,
config_id=config.id,
config_type="cloud",
)
logger.info(
"agent_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d",
run_id,
final_status,
items_processed,
items_created,
len(errors),
) )
@@ -519,13 +695,21 @@ async def _finalize_run(
managed.errors = errors or [] managed.errors = errors or []
managed.completed_at = now managed.completed_at = now
if update_config_last_run and config_id and config_type == "local": if update_config_last_run and config_id:
cfg_result = await db.execute( if config_type == "local":
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id) cfg_result = await db.execute(
) select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
cfg = cfg_result.scalar_one_or_none() )
if cfg: cfg = cfg_result.scalar_one_or_none()
cfg.last_run_at = now if cfg:
cfg.last_run_at = now
elif config_type == "cloud":
cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
)
cfg = cfg_result.scalar_one_or_none()
if cfg:
cfg.last_run_at = now
await db.commit() await db.commit()
except Exception as exc: except Exception as exc:

View File

@@ -17,7 +17,10 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
from __future__ import annotations from __future__ import annotations
import os
from openai import AsyncOpenAI from openai import AsyncOpenAI
import litellm
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from litellm import get_supported_openai_params # noqa: F401 validates install from litellm import get_supported_openai_params # noqa: F401 validates install
@@ -31,6 +34,10 @@ def _api_key_for_model(model: str) -> str | None:
return settings.ANTHROPIC_API_KEY or None return settings.ANTHROPIC_API_KEY or None
if model.startswith("gemini/") or model.startswith("google/"): if model.startswith("gemini/") or model.startswith("google/"):
return settings.GOOGLE_API_KEY or None return settings.GOOGLE_API_KEY or None
if model.startswith("github_copilot/"):
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
# No API key is required; returning None lets LiteLLM handle auth.
return None
# Default: OpenAI-compatible (covers plain model names like "gpt-4o") # Default: OpenAI-compatible (covers plain model names like "gpt-4o")
return settings.OPENAI_API_KEY or None return settings.OPENAI_API_KEY or None
@@ -55,6 +62,11 @@ def get_llm(
Sampling temperature. ``0`` = deterministic. Sampling temperature. ``0`` = deterministic.
""" """
model = model or settings.LLM_MODEL model = model or settings.LLM_MODEL
# Point LiteLLM to the custom token directory when configured.
if settings.GITHUB_COPILOT_TOKEN_DIR:
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
return ChatOpenAI( return ChatOpenAI(
model=model, model=model,
temperature=temperature, temperature=temperature,
@@ -71,10 +83,22 @@ def get_router_llm(
async def embed(text: str) -> list[float]: async def embed(text: str) -> list[float]:
"""Return a 1536-dim embedding vector for *text* using text-embedding-3-small.""" """Return an embedding vector for *text*.
Uses ``settings.LLM_EMBED_MODEL`` so the same provider switch in ``.env``
(e.g. ``github_copilot/text-embedding-3-small``) applies here without any
code changes. Falls back to the raw AsyncOpenAI client for plain OpenAI
model names to preserve existing behaviour.
"""
model = settings.LLM_EMBED_MODEL
if model.startswith("github_copilot/") or "/" in model:
# Use LiteLLM for all provider-prefixed models (Copilot, Bedrock, etc.)
# so the provider's auth mechanism is applied correctly.
response = await litellm.aembedding(model=model, input=[text])
return response.data[0]["embedding"]
# Plain OpenAI model name — use the raw AsyncOpenAI client (existing path).
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY) client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
response = await client.embeddings.create( response = await client.embeddings.create(model=model, input=text)
model="text-embedding-3-small",
input=text,
)
return response.data[0].embedding return response.data[0].embedding

View File

@@ -0,0 +1,164 @@
"""Cloud provider integration utilities.
Provides:
* Shared message dataclasses (``EmailMessage``, ``ChatMessage``) used by
both the Gmail and MS Graph clients and consumed by ``agent_runner``.
* ``get_provider()`` — factory that returns the correct client given a
provider name and decrypted OAuth credentials dict.
* ``encrypt_token()`` / ``decrypt_token()`` — Fernet-based at-rest
encryption for OAuth tokens stored in ``cloud_agent_configs``.
Encryption rationale
--------------------
Unlike user content (which is E2E-encrypted client-side and **never**
decrypted server-side), OAuth tokens *must* be decrypted server-side
because the backend makes provider API calls on behalf of the user.
The Fernet key lives solely in ``OAUTH_ENCRYPTION_KEY`` env var — it
is never returned to clients.
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING
from cryptography.fernet import Fernet, InvalidToken
from app.config.settings import settings
if TYPE_CHECKING:
from app.integrations.gmail import GmailClient
from app.integrations.ms_graph import MSGraphClient
logger = logging.getLogger(__name__)
# ── Shared message types ──────────────────────────────────────────────────
@dataclass
class EmailMessage:
    """One email retrieved from a mail provider (Gmail or Outlook).

    ``labels`` defaults to a fresh empty list per instance.
    """

    id: str
    subject: str
    sender: str
    body_text: str
    date: datetime
    labels: list[str] = field(default_factory=list)

    @property
    def as_text(self) -> str:
        """Render the message as a plain-text block for LLM extraction.

        Layout: ``From``, ``Date`` (with the labels in brackets when any are
        present) and ``Subject`` header lines, a blank line, then the body.
        """
        stamp = self.date.strftime("%Y-%m-%d %H:%M")
        if self.labels:
            tags = f" [{', '.join(self.labels)}]"
        else:
            tags = ""
        lines = [
            f"From: {self.sender}",
            f"Date: {stamp}{tags}",
            f"Subject: {self.subject}",
            "",  # blank separator between headers and body
            f"{self.body_text}",
        ]
        return "\n".join(lines)
@dataclass
class ChatMessage:
    """One Teams chat or channel message retrieved from MS Graph."""

    id: str
    content: str
    sender: str
    channel: str | None
    date: datetime

    @property
    def as_text(self) -> str:
        """Render the message as a plain-text block for LLM extraction.

        Layout: ``From`` and ``Date`` header lines (the channel is appended
        to the date line when set), a blank line, then the content.
        """
        stamp = self.date.strftime("%Y-%m-%d %H:%M")
        if self.channel:
            origin = f" [channel: {self.channel}]"
        else:
            origin = ""
        lines = [
            f"From: {self.sender}",
            f"Date: {stamp}{origin}",
            "",  # blank separator between headers and content
            f"{self.content}",
        ]
        return "\n".join(lines)
# ── Fernet helpers ────────────────────────────────────────────────────────
def _get_fernet() -> Fernet:
    """Build a ``Fernet`` cipher from ``settings.OAUTH_ENCRYPTION_KEY``.

    Raises:
        RuntimeError: ``OAUTH_ENCRYPTION_KEY`` is empty/unset. The message
            includes a one-liner for generating a key.
    """
    raw_key = settings.OAUTH_ENCRYPTION_KEY
    if raw_key:
        # Fernet takes the key as bytes; the setting is normally a str.
        key_bytes = raw_key.encode() if isinstance(raw_key, str) else raw_key
        return Fernet(key_bytes)
    raise RuntimeError(
        "OAUTH_ENCRYPTION_KEY is not set. "
        "Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
    )
def encrypt_token(token_info: dict) -> str:
    """Serialize an OAuth credential dict to JSON and Fernet-encrypt it.

    The whole dict is stored (e.g. ``{access_token, refresh_token, token_uri,
    client_id, client_secret, scopes, expiry}`` or the equivalent MSAL
    shape), so it can be restored verbatim by ``decrypt_token``.

    Returns:
        The Fernet token decoded to a ``str``.

    Raises:
        ValueError: ``token_info`` is not a non-empty dict.
        RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
    """
    usable = isinstance(token_info, dict) and bool(token_info)
    if not usable:
        raise ValueError("token_info must be a non-empty dict")
    serialized = json.dumps(token_info)
    cipher = _get_fernet()
    return cipher.encrypt(serialized.encode("utf-8")).decode("utf-8")
def decrypt_token(encrypted: str) -> dict:
    """Reverse ``encrypt_token``: Fernet-decrypt *encrypted* and parse the JSON.

    Raises:
        RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
        ValueError: The token fails Fernet verification (corrupt, or
            produced with a different key) or the decrypted payload is not
            valid JSON.
    """
    cipher = _get_fernet()
    ciphertext = encrypted.encode("utf-8")
    try:
        decoded = json.loads(cipher.decrypt(ciphertext))
    except (InvalidToken, json.JSONDecodeError) as exc:
        raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
    else:
        return decoded
# ── Provider factory ──────────────────────────────────────────────────────
def get_provider(
    provider: str,
    credentials_info: dict,
) -> "GmailClient | MSGraphClient":
    """Construct the client that serves *provider*.

    Parameters
    ----------
    provider:
        ``"gmail"`` selects ``GmailClient``; ``"outlook"`` and ``"teams"``
        both select ``MSGraphClient``.
    credentials_info:
        Decrypted OAuth credential dict, passed straight to the client
        constructor.

    Raises:
        ValueError: *provider* is not one of the supported names.
    """
    # Client modules are imported only on the branch that needs them.
    if provider == "gmail":
        from app.integrations.gmail import GmailClient as client_cls
    elif provider in {"outlook", "teams"}:
        from app.integrations.ms_graph import MSGraphClient as client_cls
    else:
        raise ValueError(
            f"Unknown cloud provider {provider!r}. "
            "Supported: 'gmail', 'outlook', 'teams'."
        )
    return client_cls(credentials_info)

335
app/integrations/gmail.py Normal file
View File

@@ -0,0 +1,335 @@
"""Gmail API client for cloud agent integration.
Wraps the Google Gmail REST API to fetch email messages matching a
``filter_config`` dict. Uses the official ``google-api-python-client``
library (synchronous) wrapped in ``asyncio.to_thread()`` to avoid
blocking the event loop.
Token refresh is handled transparently: when the stored access token has
expired, ``google.auth.transport.requests.Request`` will use the refresh
token to obtain a fresh one. The caller is responsible for persisting
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted``
(see ``agent_runner.run_cloud_agent``).
Credential dict shape (Google OAuth2):
{
"token": "<access_token>",
"refresh_token": "<refresh_token>",
"token_uri": "https://oauth2.googleapis.com/token",
"client_id": "<client_id>",
"client_secret": "<client_secret>",
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
"expiry": "2025-01-01T00:00:00Z" # optional ISO-8601
}
"""
from __future__ import annotations
import asyncio
import base64
import email
import html
import logging
import re
from datetime import datetime, timezone
from typing import Any
from app.integrations import EmailMessage
logger = logging.getLogger(__name__)
# Gmail search date format — e.g. "after:2025/01/01"
_GMAIL_DATE_FMT = "%Y/%m/%d"
# Maximum characters of body text forwarded to the LLM.
_BODY_TRUNCATE = 8_000
# Maximum messages retrieved per run (prevents runaway quota usage).
_MAX_MESSAGES = 200
def _build_gmail_query(
filter_config: dict[str, Any] | None,
since: datetime | None,
) -> str:
"""Build a Gmail search query string from *filter_config* and *since*.
Supported ``filter_config`` keys:
labels (list[str]): Gmail label names, e.g. ``["INBOX", "work"]``
senders (list[str]): Sender addresses or domains to include
date_range (dict): ``{from: "<YYYY-MM-DD>", to: "<YYYY-MM-DD>"}``
A hard ``since`` date (from last run) always overrides ``date_range.from``
when it is earlier.
"""
parts: list[str] = []
cfg = filter_config or {}
# Labels — joined with OR when multiple given.
labels: list[str] = cfg.get("labels", [])
if labels:
if len(labels) == 1:
parts.append(f"label:{labels[0]}")
else:
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
parts.append(f"({label_expr})")
# Senders — each prefixed with "from:".
senders: list[str] = cfg.get("senders", [])
for sender in senders:
parts.append(f"from:{sender}")
# Date range.
date_range: dict = cfg.get("date_range", {})
from_str: str | None = date_range.get("from")
to_str: str | None = date_range.get("to")
# Determine effective "from" date: most recent of filter_config.date_range.from and since.
effective_since: datetime | None = since
if from_str:
try:
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
if cfg_since.tzinfo is None:
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
if effective_since is None or cfg_since > effective_since:
effective_since = cfg_since
except ValueError:
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
if effective_since:
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
if to_str:
try:
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
except ValueError:
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
return " ".join(parts)
def _strip_html(raw_html: str) -> str:
"""Remove HTML tags and decode entities to get plain text."""
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
decoded = html.unescape(no_tags)
return re.sub(r"\s+", " ", decoded).strip()
def _parse_body(payload: dict[str, Any]) -> str:
"""Recursively extract the plain-text body from a Gmail message payload.
Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
Returns an empty string if no body can be extracted.
"""
mime_type: str = payload.get("mimeType", "")
body: dict = payload.get("body", {})
parts: list[dict] = payload.get("parts", [])
if mime_type == "text/plain":
data = body.get("data", "")
if data:
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
return ""
if mime_type == "text/html":
data = body.get("data", "")
if data:
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
return _strip_html(raw)
return ""
# Multipart — prefer text/plain part, fall back to text/html.
plain_fallback = ""
for part in parts:
part_mime = part.get("mimeType", "")
if part_mime == "text/plain":
return _parse_body(part)
if part_mime == "text/html" and not plain_fallback:
plain_fallback = _parse_body(part)
if part_mime.startswith("multipart/"):
nested = _parse_body(part)
if nested:
return nested
return plain_fallback
def _parse_date(raw: str) -> datetime:
"""Parse an RFC 2822 email date header into a UTC ``datetime``."""
try:
parsed = email.utils.parsedate_to_datetime(raw)
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed.astimezone(timezone.utc)
except Exception:
return datetime.now(timezone.utc)
class GmailClient:
    """Fetch email messages from a Gmail account via the Gmail REST API.

    Parameters
    ----------
    credentials_info:
        Decrypted OAuth2 credential dict. Must contain at minimum
        ``token`` (access token) or ``refresh_token`` + ``token_uri`` +
        ``client_id`` + ``client_secret``. An optional ``expiry`` ISO-8601
        string is normalised to naive UTC before being handed to
        google-auth.
    """

    def __init__(self, credentials_info: dict[str, Any]) -> None:
        # Local import: google-auth is only needed once a client is built.
        from google.oauth2.credentials import Credentials

        # Kept so ``refreshed_credentials`` can tell whether the token changed.
        self._credentials_info = credentials_info
        self._credentials = Credentials(
            token=credentials_info.get("token"),
            refresh_token=credentials_info.get("refresh_token"),
            token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
            client_id=credentials_info.get("client_id"),
            client_secret=credentials_info.get("client_secret"),
            scopes=credentials_info.get("scopes"),
            expiry=self._parse_expiry(credentials_info.get("expiry")),
        )

    @staticmethod
    def _parse_expiry(expiry_str: str | None) -> datetime | None:
        """Convert a stored ISO-8601 expiry into a *naive UTC* ``datetime``.

        google-auth compares ``Credentials.expiry`` against a naive UTC
        "now", so a timezone-aware value makes ``.expired`` raise
        ``TypeError``. Aware inputs are therefore converted to UTC and the
        tzinfo dropped; naive inputs are assumed to be UTC already.

        Returns ``None`` when *expiry_str* is empty or cannot be parsed.
        """
        if not expiry_str:
            return None
        try:
            parsed = datetime.fromisoformat(expiry_str.replace("Z", "+00:00"))
        except ValueError:
            logger.warning("gmail: invalid credential expiry %r — ignoring", expiry_str)
            return None
        if parsed.tzinfo is not None:
            parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
        return parsed

    # ── Public API ─────────────────────────────────────────────────────────

    async def fetch_messages(
        self,
        filter_config: dict[str, Any] | None = None,
        since: datetime | None = None,
    ) -> list[EmailMessage]:
        """Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.

        The synchronous Google API calls run inside ``asyncio.to_thread()``
        so the event loop is not blocked.

        After the call, ``self.refreshed_credentials`` may be consulted to
        detect whether new credentials should be persisted.

        Raises:
            RuntimeError: Token refresh or the ``messages.list`` call failed.
        """
        query = _build_gmail_query(filter_config, since)
        logger.debug("gmail: executing search query %r", query)
        return await asyncio.to_thread(self._fetch_sync, query)

    @property
    def refreshed_credentials(self) -> dict[str, Any] | None:
        """Return an updated credential dict if the access token changed.

        The result has the same shape as the dict given to ``__init__`` and
        is meant to be re-encrypted and written back to the DB. Returns
        ``None`` when the credentials are expired/invalid or the token is
        the same one that was originally supplied.
        """
        creds = self._credentials
        if not creds.valid and creds.expired:
            return None

        # Unchanged token → nothing new to persist.
        if creds.token == self._credentials_info.get("token"):
            return None

        result: dict[str, Any] = {
            "token": creds.token,
            "refresh_token": creds.refresh_token,
            "token_uri": creds.token_uri,
            "client_id": creds.client_id,
            "client_secret": creds.client_secret,
            "scopes": list(creds.scopes or []),
        }
        if creds.expiry:
            result["expiry"] = creds.expiry.isoformat()
        return result

    # ── Internal sync worker ───────────────────────────────────────────────

    def _fetch_sync(self, query: str) -> list[EmailMessage]:
        """Synchronous worker — called inside ``asyncio.to_thread()``.

        Lists matching message ids (paginated, capped at ``_MAX_MESSAGES``),
        then fetches each message in full. A message whose fetch or parse
        fails is logged and skipped.

        Raises:
            RuntimeError: Token refresh or ``messages.list`` failed.
        """
        import googleapiclient.discovery
        import googleapiclient.errors
        from google.auth.transport.requests import Request

        # Refresh up-front whenever the credentials are not usable. This
        # covers an expired token and also a dict that holds only a refresh
        # token, so any refresh failure surfaces as RuntimeError.
        if not self._credentials.valid and self._credentials.refresh_token:
            try:
                self._credentials.refresh(Request())
            except Exception as exc:
                raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc

        service = googleapiclient.discovery.build(
            "gmail", "v1", credentials=self._credentials, cache_discovery=False
        )
        user_api = service.users()  # type: ignore[attr-defined]

        # ── List matching message IDs ──────────────────────────────────────
        ids: list[str] = []
        page_token: str | None = None

        while len(ids) < _MAX_MESSAGES:
            # Never request more than the remaining budget for this run.
            batch_size = min(100, _MAX_MESSAGES - len(ids))
            kwargs: dict[str, Any] = {
                "userId": "me",
                "maxResults": batch_size,
            }
            if query:
                kwargs["q"] = query
            if page_token:
                kwargs["pageToken"] = page_token

            try:
                resp = user_api.messages().list(**kwargs).execute()
            except googleapiclient.errors.HttpError as exc:
                raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc

            ids.extend(msg["id"] for msg in resp.get("messages", []))

            page_token = resp.get("nextPageToken")
            if not page_token:
                break

        if not ids:
            logger.debug("gmail: no messages matched query %r", query)
            return []

        logger.info("gmail: fetching %d message(s)", len(ids))

        # ── Fetch individual message details ──────────────────────────────
        messages: list[EmailMessage] = []
        for msg_id in ids:
            try:
                msg = user_api.messages().get(
                    userId="me", id=msg_id, format="full"
                ).execute()
                payload = msg.get("payload", {})

                # Header names are lower-cased for case-insensitive lookup.
                headers: dict[str, str] = {
                    h["name"].lower(): h["value"]
                    for h in payload.get("headers", [])
                }
                date_raw = headers.get("date", "")

                messages.append(EmailMessage(
                    id=msg_id,
                    subject=headers.get("subject", "(no subject)"),
                    sender=headers.get("from", "unknown"),
                    body_text=_parse_body(payload)[:_BODY_TRUNCATE],
                    date=_parse_date(date_raw) if date_raw else datetime.now(timezone.utc),
                    labels=msg.get("labelIds", []),
                ))
            except googleapiclient.errors.HttpError as exc:
                logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
            except Exception as exc:
                # Best effort: one malformed message must not abort the run.
                logger.warning("gmail: skipping message %s — unexpected error: %s", msg_id, exc)

        logger.info("gmail: returned %d message(s)", len(messages))
        return messages

View File

@@ -0,0 +1,352 @@
"""Microsoft Graph API client for Outlook and Teams cloud agent integration.
Handles two data sources:
* **Outlook email** (``provider="outlook"``) — ``fetch_emails()`` calls
``/me/messages`` with an OData ``$filter`` built from ``filter_config``.
* **Teams messages** (``provider="teams"``) — ``fetch_messages()`` calls
``/me/chats/getAllMessages`` filtered by date.
Authentication uses MSAL ``PublicClientApplication`` to acquire a token
from a stored refresh token. The ``httpx.AsyncClient`` (already a project
dependency) is used for all API calls.
Credential dict shape (Microsoft OAuth2 / MSAL):
{
"access_token": "<access_token>",
"refresh_token": "<refresh_token>",
"token_type": "Bearer",
"scope": "Mail.Read ChannelMessage.Read.All offline_access",
"expires_in": 3600
}
"""
from __future__ import annotations
import logging
import re
from datetime import datetime, timedelta, timezone
from typing import Any
import httpx
from app.config.settings import settings
from app.integrations import ChatMessage, EmailMessage
logger = logging.getLogger(__name__)
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
# Max items fetched per run.
_MAX_EMAILS = 200
_MAX_MESSAGES = 200
# Max characters of body forwarded to the LLM.
_BODY_TRUNCATE = 8_000
def _strip_html(raw: str) -> str:
"""Strip HTML tags and collapse whitespace."""
no_tags = re.sub(r"<[^>]+>", " ", raw)
import html as _html
decoded = _html.unescape(no_tags)
return re.sub(r"\s+", " ", decoded).strip()
def _odata_datetime(dt: datetime) -> str:
"""Format a datetime as an OData datetime literal (UTC, ISO 8601)."""
utc = dt.astimezone(timezone.utc)
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
def _build_email_filter(
filter_config: dict[str, Any] | None,
since: datetime | None,
) -> str:
"""Build an OData ``$filter`` expression for the ``/me/messages`` endpoint.
Supported ``filter_config`` keys:
senders (list[str]): Sender email addresses.
date_range (dict): ``{from: "<ISO-8601>", to: "<ISO-8601>"}``
folders (list[str]): Folder display names (not directly filterable
via OData, so ignored here — callers iterate
folder IDs separately if needed; listed for
completeness).
A hard ``since`` date always overrides ``date_range.from`` when it is
earlier.
"""
clauses: list[str] = []
cfg = filter_config or {}
# Senders.
senders: list[str] = cfg.get("senders", [])
if senders:
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
clauses.append("(" + " or ".join(sender_clauses) + ")")
# Date range.
date_range: dict = cfg.get("date_range", {})
from_str: str | None = date_range.get("from")
effective_since: datetime | None = since
if from_str:
try:
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
if cfg_since.tzinfo is None:
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
if effective_since is None or cfg_since > effective_since:
effective_since = cfg_since
except ValueError:
logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
if effective_since:
clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")
to_str: str | None = date_range.get("to")
if to_str:
try:
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
if to_dt.tzinfo is None:
to_dt = to_dt.replace(tzinfo=timezone.utc)
clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")
except ValueError:
logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
return " and ".join(clauses)
class MSGraphClient:
    """Fetch Outlook emails and Teams chat messages via the Microsoft Graph REST API.

    Parameters
    ----------
    credentials_info:
        Decrypted MSAL credential dict.  It is updated in place when the
        access token is refreshed.
    """

    # Scopes MSAL adds on its own and rejects (ValueError) when passed explicitly.
    _RESERVED_SCOPES = frozenset({"openid", "profile", "offline_access"})

    def __init__(self, credentials_info: dict[str, Any]) -> None:
        self._credentials_info = credentials_info
        self._access_token: str = credentials_info.get("access_token", "")
        # Remembered so refreshed_credentials can tell whether a refresh happened.
        self._original_access_token: str = self._access_token
        self._refresh_token: str | None = credentials_info.get("refresh_token")

    # ── Token management ───────────────────────────────────────────────────

    def _auth_headers(self) -> dict[str, str]:
        """Return the ``Authorization`` header for the current access token."""
        return {"Authorization": f"Bearer {self._access_token}"}

    async def _refresh_access_token(self) -> None:
        """Use MSAL to exchange the refresh token for a fresh access token.

        Updates ``self._access_token`` and ``self._credentials_info`` in-place.
        The MSAL calls are synchronous (blocking HTTP), so they run in a worker
        thread to keep the event loop responsive.

        Raises:
            RuntimeError: No refresh token is stored, or MSAL reports an auth error.
        """
        import asyncio

        import msal

        refresh_token = self._refresh_token
        if not refresh_token:
            raise RuntimeError("MS Graph token refresh failed: no refresh token available")

        # The stored scope string usually contains "offline_access"; MSAL raises
        # ValueError when a reserved scope is passed, so those are dropped.
        stored_scopes = (self._credentials_info.get("scope") or "").split()
        scopes: list[str] = [
            scope for scope in stored_scopes if scope.lower() not in self._RESERVED_SCOPES
        ]
        if not scopes:
            scopes = ["https://graph.microsoft.com/.default"]

        def _exchange() -> dict[str, Any]:
            app = msal.ConfidentialClientApplication(
                client_id=settings.MS_CLIENT_ID,
                client_credential=settings.MS_CLIENT_SECRET,
                authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}",
            )
            return app.acquire_token_by_refresh_token(refresh_token, scopes=scopes)

        result = await asyncio.to_thread(_exchange)
        if "access_token" not in result:
            error = result.get("error_description", result.get("error", "unknown"))
            raise RuntimeError(f"MS Graph token refresh failed: {error}")

        self._access_token = result["access_token"]
        # MSAL may issue a new refresh token.
        if "refresh_token" in result:
            self._refresh_token = result["refresh_token"]
            self._credentials_info["refresh_token"] = result["refresh_token"]
        self._credentials_info["access_token"] = self._access_token

    @property
    def refreshed_credentials(self) -> dict[str, Any] | None:
        """Return updated credential dict if the access token was refreshed.

        Returns ``None`` if the access token is still the one passed to
        ``__init__``.
        """
        if self._access_token != self._original_access_token:
            return {**self._credentials_info, "access_token": self._access_token}
        return None

    # ── HTTP helpers ───────────────────────────────────────────────────────

    async def _get(
        self,
        client: httpx.AsyncClient,
        url: str,
        params: dict[str, Any] | None = None,
        *,
        retry_on_401: bool = True,
    ) -> dict[str, Any]:
        """GET *url* with auth; refresh token on 401 and retry once.

        Raises:
            RuntimeError: Graph answered 429, or the token refresh failed.
            httpx.HTTPStatusError: Any other non-success status.
        """
        resp = await client.get(url, params=params, headers=self._auth_headers())
        if resp.status_code == 401 and retry_on_401 and self._refresh_token:
            logger.debug("ms_graph: 401 on %s — refreshing token", url)
            await self._refresh_access_token()
            resp = await client.get(url, params=params, headers=self._auth_headers())
        if resp.status_code == 429:
            raise RuntimeError("MS Graph rate limit hit (429). Try again later.")
        resp.raise_for_status()
        return resp.json()

    # ── Public API ─────────────────────────────────────────────────────────

    async def fetch_emails(
        self,
        filter_config: dict[str, Any] | None = None,
        since: datetime | None = None,
    ) -> list[EmailMessage]:
        """Return up to ``_MAX_EMAILS`` Outlook messages matching *filter_config*.

        Parameters
        ----------
        filter_config:
            Optional dict with ``senders``, ``date_range``, ``folders`` keys.
        since:
            Hard lower-bound on email date (from last agent run).
        """
        odata_filter = _build_email_filter(filter_config, since)
        first_page_params: dict[str, Any] = {
            "$top": 50,
            "$select": "id,subject,from,receivedDateTime,body,bodyPreview",
            "$orderby": "receivedDateTime desc",
        }
        if odata_filter:
            first_page_params["$filter"] = odata_filter

        emails: list[EmailMessage] = []
        url: str | None = f"{_GRAPH_BASE}/me/messages"
        page_params: dict[str, Any] | None = first_page_params
        async with httpx.AsyncClient(timeout=30.0) as client:
            while url and len(emails) < _MAX_EMAILS:
                data = await self._get(client, url, page_params)
                for item in data.get("value", []):
                    emails.append(self._parse_email(item))
                    if len(emails) >= _MAX_EMAILS:
                        break
                url = data.get("@odata.nextLink")
                # nextLink already carries every query option (including the
                # skip token); it must be requested verbatim, so no params —
                # not even an empty dict — are sent with it.
                page_params = None
        logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
        return emails

    async def fetch_messages(
        self,
        filter_config: dict[str, Any] | None = None,
        since: datetime | None = None,
    ) -> list[ChatMessage]:
        """Return up to ``_MAX_MESSAGES`` Teams messages matching *filter_config*.

        Fetches from ``/me/chats/getAllMessages``.  ``filter_config.channels``
        is applied after the fetch as a case-insensitive substring match
        against the message's ``channelIdentity.channelId``; messages that
        carry no channel identity are always kept.  A 403/404 from the
        endpoint is logged and ends the fetch with whatever was collected.
        """
        cfg = filter_config or {}
        channel_filter: list[str] = [c.lower() for c in cfg.get("channels") or []]
        first_page_params: dict[str, Any] = {"$top": 50}
        if since:
            first_page_params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}"

        messages: list[ChatMessage] = []
        url: str | None = f"{_GRAPH_BASE}/me/chats/getAllMessages"
        page_params: dict[str, Any] | None = first_page_params
        async with httpx.AsyncClient(timeout=30.0) as client:
            while url and len(messages) < _MAX_MESSAGES:
                try:
                    data = await self._get(client, url, page_params)
                except httpx.HTTPStatusError as exc:
                    # getAllMessages requires specific licensing; degrade gracefully.
                    if exc.response.status_code in (403, 404):
                        logger.warning(
                            "ms_graph: /me/chats/getAllMessages not available (%d) — "
                            "check Teams license or permissions",
                            exc.response.status_code,
                        )
                        break
                    raise
                for item in data.get("value", []):
                    msg = self._parse_teams_message(item)
                    if channel_filter and msg.channel:
                        if not any(c in msg.channel.lower() for c in channel_filter):
                            continue
                    messages.append(msg)
                    if len(messages) >= _MAX_MESSAGES:
                        break
                url = data.get("@odata.nextLink")
                # See fetch_emails: nextLink is self-contained.
                page_params = None
        logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
        return messages

    # ── Parsers ────────────────────────────────────────────────────────────

    @staticmethod
    def _parse_timestamp(value: Any) -> datetime:
        """Parse a Graph ISO-8601 timestamp; fall back to "now" (UTC) if unusable.

        The result is always timezone-aware; a parsed value without an offset
        is taken to be UTC.
        """
        if isinstance(value, str) and value:
            try:
                parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
            except ValueError:
                pass
            else:
                if parsed.tzinfo is None:
                    parsed = parsed.replace(tzinfo=timezone.utc)
                return parsed
        return datetime.now(timezone.utc)

    @staticmethod
    def _body_text(body_block: dict[str, Any], fallback: str = "") -> str:
        """Return the plain-text body (HTML stripped), truncated to ``_BODY_TRUNCATE``.

        *fallback* is used when the body is missing or empty.
        """
        raw = body_block.get("content") or ""
        content_type = (body_block.get("contentType") or "text").lower()
        text = _strip_html(raw) if content_type == "html" else raw
        return (text or fallback)[:_BODY_TRUNCATE]

    @staticmethod
    def _parse_email(item: dict[str, Any]) -> EmailMessage:
        """Convert one ``/me/messages`` item into an ``EmailMessage``."""
        subject: str = item.get("subject") or "(no subject)"
        sender_block = item.get("from") or {}
        sender_addr: str = (sender_block.get("emailAddress") or {}).get("address") or "unknown"
        return EmailMessage(
            id=item.get("id", ""),
            subject=subject,
            sender=sender_addr,
            body_text=MSGraphClient._body_text(
                item.get("body") or {},
                fallback=item.get("bodyPreview") or "",
            ),
            date=MSGraphClient._parse_timestamp(item.get("receivedDateTime")),
        )

    @staticmethod
    def _parse_teams_message(item: dict[str, Any]) -> ChatMessage:
        """Convert one ``getAllMessages`` item into a ``ChatMessage``."""
        sender_block = (item.get("from") or {}).get("user") or {}
        sender: str = sender_block.get("displayName") or "unknown"
        channel: str | None = (item.get("channelIdentity") or {}).get("channelId")
        return ChatMessage(
            id=item.get("id", ""),
            content=MSGraphClient._body_text(item.get("body") or {}),
            sender=sender,
            channel=channel,
            date=MSGraphClient._parse_timestamp(item.get("createdDateTime")),
        )

View File

@@ -8,6 +8,9 @@ services:
required: false required: false
environment: environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
volumes:
- copilot_tokens:/root/.config/litellm/github_copilot
depends_on: depends_on:
db: db:
condition: service_healthy condition: service_healthy
@@ -66,3 +69,4 @@ volumes:
postgres_data: postgres_data:
minio_data: minio_data:
qdrant_data: qdrant_data:
copilot_tokens:

View File

@@ -25,4 +25,10 @@ moto[s3]>=5.0.0
pinecone>=5.0.0 pinecone>=5.0.0
qdrant-client>=1.7.0 qdrant-client>=1.7.0
croniter>=3.0.0 croniter>=3.0.0
google-api-python-client>=2.130.0
google-auth>=2.29.0
google-auth-oauthlib>=1.2.0
google-auth-httplib2>=0.2.0
msal>=1.28.0
cryptography>=42.0.0
ruff>=0.8.0 ruff>=0.8.0

View File

@@ -455,21 +455,232 @@ async def test_run_local_agent_llm_extraction_error():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_cloud_agent_stub_returns_error(): async def test_run_cloud_agent_device_offline():
"""Cloud agent stub immediately marks run as error with informative message.""" """Cloud agent aborts immediately when no device is connected."""
config = _make_cloud_config() config = _make_cloud_config()
run_log = _make_run_log(config.id, agent_type="cloud") run_log = _make_run_log(config.id, agent_type="cloud")
mgr = DeviceConnectionManager() # empty — no devices registered
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
mock_finalize.assert_called_once()
_, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error"
assert any("device" in e.lower() or "connected" in e.lower() for e in kwargs["errors"])
@pytest.mark.asyncio
async def test_run_cloud_agent_no_oauth_token():
"""Cloud agent errors when no OAuth token is stored."""
config = _make_cloud_config()
config.oauth_token_encrypted = None
run_log = _make_run_log(config.id, agent_type="cloud")
mgr = _make_manager() mgr = _make_manager()
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize: with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
await run_cloud_agent(_FREE_UID, config, run_log, mgr) await run_cloud_agent(_FREE_UID, config, run_log, mgr)
mock_finalize.assert_called_once() _, kwargs = mock_finalize.call_args
_args, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error" assert kwargs["status"] == "error"
assert len(kwargs["errors"]) == 1 assert any("oauth" in e.lower() or "token" in e.lower() for e in kwargs["errors"])
assert "gmail" in kwargs["errors"][0].lower()
assert "3.6" in kwargs["errors"][0]
@pytest.mark.asyncio
async def test_run_cloud_agent_token_decrypt_failure():
    """An undecryptable stored token ends the run with a 'decrypt' error, not a crash."""
    cloud_cfg = _make_cloud_config()
    cloud_cfg.oauth_token_encrypted = "this-is-not-valid-fernet-ciphertext"
    log_row = _make_run_log(cloud_cfg.id, agent_type="cloud")
    manager = _make_manager()

    from cryptography.fernet import Fernet as _Fernet

    # A well-formed key, so the failure comes from the ciphertext and not the key.
    usable_key = _Fernet.generate_key().decode()

    finalize_patch = patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock)
    settings_patch = patch("app.integrations.settings")
    with finalize_patch as finalize_mock, settings_patch as settings_mock:
        settings_mock.OAUTH_ENCRYPTION_KEY = usable_key
        await run_cloud_agent(_FREE_UID, cloud_cfg, log_row, manager)

    finalize_kwargs = finalize_mock.call_args[1]
    assert finalize_kwargs["status"] == "error"
    lowered_errors = [err.lower() for err in finalize_kwargs["errors"]]
    assert any("decrypt" in err for err in lowered_errors)
@pytest.mark.asyncio
async def test_run_cloud_agent_happy_path_gmail():
    """Happy path: one Gmail message is fetched, extracted, inserted, and the run succeeds."""
    from app.integrations import EmailMessage, encrypt_token
    from cryptography.fernet import Fernet as _Fernet

    key = _Fernet.generate_key().decode()
    google_creds = {
        "token": "access_abc",
        "refresh_token": "refresh_xyz",
        "token_uri": "https://oauth2.googleapis.com/token",
        "client_id": "cid",
        "client_secret": "csec",
    }

    cloud_cfg = _make_cloud_config()
    cloud_cfg.provider = "gmail"
    cloud_cfg.prompt_template = "Extract tasks from this email."
    cloud_cfg.data_types = ["tasks"]
    # Encrypt under a temporary key so the config carries real ciphertext.
    with patch("app.integrations.settings") as encrypt_settings:
        encrypt_settings.OAUTH_ENCRYPTION_KEY = key
        cloud_cfg.oauth_token_encrypted = encrypt_token(google_creds)

    log_row = _make_run_log(cloud_cfg.id, agent_type="cloud")
    manager = _make_manager()

    incoming = EmailMessage(
        id="msg001",
        subject="Action required",
        sender="boss@company.com",
        body_text="Please fix the bug by Friday.",
        date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
    )
    llm_output = [{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}]

    # Stand-in provider returning exactly one message and reporting no token refresh.
    gmail_stub = AsyncMock()
    gmail_stub.fetch_messages = AsyncMock(return_value=[incoming])
    gmail_stub.refreshed_credentials = None

    settings_patch = patch("app.integrations.settings")
    finalize_patch = patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock)
    extract_patch = patch(
        "app.core.agent_runner._extract_items_from_content",
        new_callable=AsyncMock,
        return_value=llm_output,
    )
    insert_patch = patch(
        "app.core.agent_runner._send_insert_to_client",
        new_callable=AsyncMock,
        return_value={"ok": True},
    )
    session_patch = patch("app.core.agent_runner.async_session")
    decrypt_patch = patch("app.integrations.decrypt_token", return_value=google_creds)
    provider_patch = patch("app.integrations.get_provider", return_value=gmail_stub)

    with settings_patch as run_settings, \
            finalize_patch as finalize_mock, \
            extract_patch as extract_mock, \
            insert_patch as insert_mock, \
            session_patch:
        run_settings.OAUTH_ENCRYPTION_KEY = key
        with decrypt_patch, provider_patch:
            await run_cloud_agent(_FREE_UID, cloud_cfg, log_row, manager)

    extract_mock.assert_called_once()
    insert_mock.assert_called_once()
    finalize_kwargs = finalize_mock.call_args[1]
    assert finalize_kwargs["status"] == "success"
    assert finalize_kwargs["items_processed"] == 1
    assert finalize_kwargs["items_created"] == 1
    assert finalize_kwargs["config_type"] == "cloud"
@pytest.mark.asyncio
async def test_run_cloud_agent_provider_fetch_error():
    """A RuntimeError raised by the provider's fetch is reported as an 'error' run."""
    decrypted = {"token": "abc"}

    cloud_cfg = _make_cloud_config()
    # Any non-empty value works: decrypt_token itself is patched below.
    cloud_cfg.oauth_token_encrypted = "some_encrypted_value"
    cloud_cfg.prompt_template = "Extract tasks."
    cloud_cfg.data_types = ["tasks"]
    log_row = _make_run_log(cloud_cfg.id, agent_type="cloud")
    manager = _make_manager()

    failing_provider = AsyncMock()
    failing_provider.fetch_messages = AsyncMock(side_effect=RuntimeError("API quota exceeded"))
    failing_provider.refreshed_credentials = None

    finalize_patch = patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock)
    decrypt_patch = patch("app.integrations.decrypt_token", return_value=decrypted)
    provider_patch = patch("app.integrations.get_provider", return_value=failing_provider)
    session_patch = patch("app.core.agent_runner.async_session")
    with finalize_patch as finalize_mock, decrypt_patch, provider_patch, session_patch:
        await run_cloud_agent(_FREE_UID, cloud_cfg, log_row, manager)

    finalize_kwargs = finalize_mock.call_args[1]
    assert finalize_kwargs["status"] == "error"
    lowered_errors = [err.lower() for err in finalize_kwargs["errors"]]
    assert any("quota" in err or "fetch" in err for err in lowered_errors)
@pytest.mark.asyncio
async def test_run_cloud_agent_refreshed_token_persisted():
    """Credentials refreshed by the provider are re-encrypted and stored on the config row."""
    from app.integrations import EmailMessage, encrypt_token
    from cryptography.fernet import Fernet as _Fernet

    key = _Fernet.generate_key().decode()
    stale_creds = {"token": "old_token", "refresh_token": "rt_old"}
    renewed_creds = {"token": "new_token", "refresh_token": "rt_new"}

    cloud_cfg = _make_cloud_config()
    cloud_cfg.prompt_template = "Extract tasks."
    cloud_cfg.data_types = ["tasks"]
    with patch("app.integrations.settings") as encrypt_settings:
        encrypt_settings.OAUTH_ENCRYPTION_KEY = key
        cloud_cfg.oauth_token_encrypted = encrypt_token(stale_creds)

    log_row = _make_run_log(cloud_cfg.id, agent_type="cloud")
    manager = _make_manager()

    # Provider that fetched nothing but reports that its token was refreshed.
    provider_stub = AsyncMock()
    provider_stub.fetch_messages = AsyncMock(return_value=[])
    provider_stub.refreshed_credentials = renewed_creds

    # Fake DB session whose query result yields a config row we can inspect afterwards.
    stored_row = MagicMock()
    stored_row.oauth_token_encrypted = None
    session = AsyncMock()
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock(return_value=False)
    session.scalar_one_or_none = AsyncMock(return_value=stored_row)
    query_result = MagicMock()
    query_result.scalar_one_or_none.return_value = stored_row
    session.execute = AsyncMock(return_value=query_result)
    session.commit = AsyncMock()

    finalize_patch = patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock)
    decrypt_patch = patch("app.integrations.decrypt_token", return_value=stale_creds)
    provider_patch = patch("app.integrations.get_provider", return_value=provider_stub)
    encrypt_patch = patch("app.integrations.encrypt_token", return_value="new_encrypted")
    session_patch = patch("app.core.agent_runner.async_session", return_value=session)
    settings_patch = patch("app.integrations.settings")
    with finalize_patch, \
            decrypt_patch, \
            provider_patch, \
            encrypt_patch as encrypt_mock, \
            session_patch, \
            settings_patch as run_settings:
        run_settings.OAUTH_ENCRYPTION_KEY = key
        await run_cloud_agent(_FREE_UID, cloud_cfg, log_row, manager)

    # The new encrypted token should have been written to the config row.
    encrypt_mock.assert_called_once_with(renewed_creds)
    assert stored_row.oauth_token_encrypted == "new_encrypted"
@pytest.mark.asyncio
async def test_finalize_run_updates_cloud_config_last_run_at():
    """With config_type='cloud', _finalize_run stamps last_run_at on the fetched config row."""
    from app.core.agent_runner import _finalize_run

    log_row = _make_run_log(str(uuid.uuid4()), agent_type="cloud")
    log_row.id = str(uuid.uuid4())

    # Config row as the patched session's execute() will return it.
    cloud_row = MagicMock()
    cloud_row.last_run_at = None
    query_result = MagicMock()
    query_result.scalar_one_or_none.return_value = cloud_row

    session = AsyncMock()
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock(return_value=False)
    session.merge = AsyncMock(return_value=log_row)
    session.execute = AsyncMock(return_value=query_result)
    session.commit = AsyncMock()

    target_config_id = str(uuid.uuid4())
    finalize_kwargs = {
        "status": "success",
        "update_config_last_run": True,
        "config_id": target_config_id,
        "config_type": "cloud",
    }
    with patch("app.core.agent_runner.async_session", return_value=session):
        await _finalize_run(log_row, **finalize_kwargs)

    # CloudAgentConfig.last_run_at should have been set.
    assert cloud_row.last_run_at is not None
    session.commit.assert_called()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

729
tests/test_integrations.py Normal file
View File

@@ -0,0 +1,729 @@
"""Tests for Step 3.6: cloud provider integration clients.
Coverage:
Unit — app/integrations/__init__.py:
- encrypt_token / decrypt_token round-trip
- decrypt_token raises ValueError on invalid ciphertext
- encrypt_token raises ValueError on empty/non-dict input
- _get_fernet raises RuntimeError when OAUTH_ENCRYPTION_KEY not set
- get_provider returns GmailClient for 'gmail'
- get_provider returns MSGraphClient for 'outlook' and 'teams'
- get_provider raises ValueError for unknown provider
Unit — app/integrations/gmail.py:
- _build_gmail_query with no filter returns empty string
- _build_gmail_query with labels builds label: expr
- _build_gmail_query with senders builds from: expr
- _build_gmail_query with date_range builds after:/before: exprs
- _build_gmail_query since overrides date_range.from when more recent
- _build_gmail_query date_range.from overrides since when more recent
- _parse_body extracts text/plain part
- _parse_body extracts text/html part (stripped)
- _parse_body recurses into multipart, prefers text/plain
- GmailClient.fetch_messages: happy path with mocked service
- GmailClient.fetch_messages: no messages returns empty list
- GmailClient.fetch_messages: HTTP error on messages.list raises RuntimeError
- GmailClient.refreshed_credentials: None when token unchanged
- GmailClient.refreshed_credentials: returns dict when token changes
Unit — app/integrations/ms_graph.py:
- _build_email_filter with no filter returns empty string
- _build_email_filter with senders builds OData from clause
- _build_email_filter with since builds receivedDateTime ge clause
- MSGraphClient.fetch_emails: happy path with mocked httpx
- MSGraphClient.fetch_emails: 401 triggers token refresh and retries
- MSGraphClient.fetch_messages: happy path with mocked httpx
- MSGraphClient.fetch_messages: 403 from getAllMessages degrades gracefully
- MSGraphClient.refreshed_credentials: None when token unchanged
- MSGraphClient._refresh_access_token: MSAL error raises RuntimeError
"""
from __future__ import annotations
import asyncio
import json
import uuid
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import pytest
from app.integrations import (
ChatMessage,
EmailMessage,
decrypt_token,
encrypt_token,
get_provider,
)
# ───────────────────────────────────────────────────────────────────────────
# Helpers
# ───────────────────────────────────────────────────────────────────────────
# Placeholder only: this base64 text decodes to 35 bytes, not the 32 bytes a
# Fernet key requires, so it cannot be used as a key. A real key is generated
# below as _VALID_KEY.
_FERNET_KEY = "eW91LXNob3VsZC1ub3QtdXNlLXRoaXMta2V5LWluLXByb2Q="
# ^ URL-safe base64 (written for tests only; not a real Fernet key length,
# so we generate a proper one below)
from cryptography.fernet import Fernet as _Fernet # noqa: E402
# Fresh, valid Fernet key generated once per test session.
_VALID_KEY = _Fernet.generate_key().decode("utf-8")
# Sample Google-style credential dict used for encrypt/decrypt and Gmail client tests.
_TOKEN_DICT = {
    "token": "access_abc",
    "refresh_token": "refresh_xyz",
    "token_uri": "https://oauth2.googleapis.com/token",
    "client_id": "client_id_123",
    "client_secret": "client_secret_456",
    "scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
}
# Sample Microsoft-style credential dict used for the MS Graph client tests.
_MS_TOKEN_DICT = {
    "access_token": "ms_access_abc",
    "refresh_token": "ms_refresh_xyz",
    "token_type": "Bearer",
    "scope": "Mail.Read offline_access",
}
# ───────────────────────────────────────────────────────────────────────────
# encrypt_token / decrypt_token
# ───────────────────────────────────────────────────────────────────────────
class TestTokenEncryption:
    """Round-trip and failure-mode tests for encrypt_token / decrypt_token.

    Also holds the ``as_text`` rendering checks for the two message dataclasses.
    """

    @staticmethod
    def _settings_with_key(key):
        """Return a patcher replacing ``app.integrations.settings`` with a mock holding *key*."""
        return patch("app.integrations.settings", OAUTH_ENCRYPTION_KEY=key)

    def test_round_trip(self):
        with self._settings_with_key(_VALID_KEY):
            ciphertext = encrypt_token(_TOKEN_DICT)
            assert isinstance(ciphertext, str)
            # must be ciphertext, not plaintext
            assert ciphertext != json.dumps(_TOKEN_DICT)
            restored = decrypt_token(ciphertext)
        assert restored == _TOKEN_DICT

    def test_decrypt_invalid_ciphertext_raises_value_error(self):
        with self._settings_with_key(_VALID_KEY):
            with pytest.raises(ValueError, match="Failed to decrypt"):
                decrypt_token("this-is-not-valid-fernet-ciphertext")

    def test_decrypt_wrong_key_raises_value_error(self):
        """Ciphertext produced under one key must not decrypt under another."""
        unrelated_key = _Fernet.generate_key().decode("utf-8")
        with self._settings_with_key(_VALID_KEY):
            ciphertext = encrypt_token(_TOKEN_DICT)
        with self._settings_with_key(unrelated_key):
            with pytest.raises(ValueError, match="Failed to decrypt"):
                decrypt_token(ciphertext)

    def test_encrypt_empty_dict_raises_value_error(self):
        with self._settings_with_key(_VALID_KEY):
            with pytest.raises(ValueError, match="non-empty dict"):
                encrypt_token({})

    def test_encrypt_non_dict_raises_value_error(self):
        with self._settings_with_key(_VALID_KEY):
            with pytest.raises(ValueError, match="non-empty dict"):
                encrypt_token("not-a-dict")  # type: ignore[arg-type]

    def test_missing_key_raises_runtime_error(self):
        with self._settings_with_key(""):
            with pytest.raises(RuntimeError, match="OAUTH_ENCRYPTION_KEY"):
                encrypt_token(_TOKEN_DICT)

    def test_email_message_as_text(self):
        email = EmailMessage(
            id="m1",
            subject="Hello",
            sender="alice@example.com",
            body_text="Test body",
            date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
        )
        rendered = email.as_text
        for fragment in ("From: alice@example.com", "Subject: Hello", "Test body"):
            assert fragment in rendered

    def test_chat_message_as_text(self):
        chat = ChatMessage(
            id="c1",
            content="Buy milk",
            sender="bob",
            channel="general",
            date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
        )
        rendered = chat.as_text
        for fragment in ("From: bob", "channel: general", "Buy milk"):
            assert fragment in rendered
# ───────────────────────────────────────────────────────────────────────────
# get_provider factory
# ───────────────────────────────────────────────────────────────────────────
class TestGetProvider:
    """get_provider maps a provider name to the matching client class."""

    def test_gmail_returns_gmail_client(self):
        from app.integrations.gmail import GmailClient

        assert isinstance(get_provider("gmail", _TOKEN_DICT), GmailClient)

    def test_outlook_returns_ms_graph_client(self):
        from app.integrations.ms_graph import MSGraphClient

        assert isinstance(get_provider("outlook", _MS_TOKEN_DICT), MSGraphClient)

    def test_teams_returns_ms_graph_client(self):
        from app.integrations.ms_graph import MSGraphClient

        assert isinstance(get_provider("teams", _MS_TOKEN_DICT), MSGraphClient)

    def test_unknown_provider_raises_value_error(self):
        # "slack" is not a supported provider name.
        with pytest.raises(ValueError, match="Unknown cloud provider"):
            get_provider("slack", {})
# ───────────────────────────────────────────────────────────────────────────
# Gmail client — query builder
# ───────────────────────────────────────────────────────────────────────────
class TestBuildGmailQuery:
    """Checks on the Gmail search string produced by gmail._build_gmail_query."""

    def setup_method(self):
        # Imported per test so a broken gmail module fails these tests, not collection.
        from app.integrations.gmail import _build_gmail_query

        self._fn = _build_gmail_query

    def test_empty_returns_empty_string(self):
        query = self._fn(None, None)
        assert query == ""

    def test_single_label(self):
        query = self._fn({"labels": ["INBOX"]}, None)
        assert "label:INBOX" in query

    def test_multiple_labels_joined_with_or(self):
        query = self._fn({"labels": ["INBOX", "work"]}, None)
        assert "label:INBOX OR label:work" in query

    def test_senders(self):
        query = self._fn({"senders": ["alice@example.com"]}, None)
        assert "from:alice@example.com" in query

    def test_date_range_from(self):
        config = {"date_range": {"from": "2025-01-15"}}
        assert "after:2025/01/15" in self._fn(config, None)

    def test_date_range_to(self):
        config = {"date_range": {"to": "2025-03-01"}}
        assert "before:2025/03/01" in self._fn(config, None)

    def test_since_overrides_earlier_date_range_from(self):
        """With since=Feb and date_range.from=Jan, the later bound (Feb) is the one used."""
        last_run = datetime(2025, 2, 1, tzinfo=timezone.utc)
        query = self._fn({"date_range": {"from": "2025-01-01"}}, last_run)
        assert "after:2025/02/01" in query
        assert "after:2025/01/01" not in query

    def test_date_range_from_overrides_earlier_since(self):
        """With date_range.from=Feb and since=Jan, the later bound (Feb) is the one used."""
        last_run = datetime(2025, 1, 1, tzinfo=timezone.utc)
        query = self._fn({"date_range": {"from": "2025-02-01"}}, last_run)
        assert "after:2025/02/01" in query

    def test_invalid_date_ignored(self):
        """An unparseable date_range.from is skipped rather than raising."""
        query = self._fn({"date_range": {"from": "not-a-date"}}, None)
        assert "after:" not in query
# ───────────────────────────────────────────────────────────────────────────
# Gmail client — body parsing
# ───────────────────────────────────────────────────────────────────────────
class TestParseBody:
    """Exercise ``gmail._parse_body`` with hand-built payload dicts."""

    def setup_method(self):
        # Imported inside the hook (not at module level) and stored on the
        # instance so each test calls it as ``self._fn(payload)``.
        from app.integrations.gmail import _parse_body

        self._fn = _parse_body

    def _encode(self, text: str) -> str:
        """Return *text* as a URL-safe base64 string, as placed in ``body.data`` by these tests."""
        import base64

        raw_bytes = text.encode()
        return base64.urlsafe_b64encode(raw_bytes).decode()

    def test_text_plain_extracted(self):
        """A single-part ``text/plain`` payload decodes back to its original text."""
        body = {"data": self._encode("Hello world")}
        payload = {"mimeType": "text/plain", "body": body}
        assert self._fn(payload) == "Hello world"

    def test_text_html_stripped(self):
        """A ``text/html`` payload keeps its text content but loses the ``<p>`` tag."""
        body = {"data": self._encode("<p>Hello <b>world</b></p>")}
        text = self._fn({"mimeType": "text/html", "body": body})
        assert "Hello" in text
        assert "<p>" not in text

    def test_multipart_prefers_plain_over_html(self):
        """In ``multipart/alternative`` the plain part wins even when the HTML part is listed first."""
        plain_part = {
            "mimeType": "text/plain",
            "body": {"data": self._encode("Plain text")},
        }
        html_part = {
            "mimeType": "text/html",
            "body": {"data": self._encode("<p>HTML text</p>")},
        }
        payload = {
            "mimeType": "multipart/alternative",
            "body": {},
            # HTML deliberately precedes plain text in the parts list.
            "parts": [html_part, plain_part],
        }
        assert self._fn(payload) == "Plain text"

    def test_empty_payload_returns_empty_string(self):
        """An empty payload dict yields an empty string."""
        parsed = self._fn({})
        assert parsed == ""
# ───────────────────────────────────────────────────────────────────────────
# GmailClient.fetch_messages
# ───────────────────────────────────────────────────────────────────────────
def _make_gmail_message(
    msg_id: str = "msg001",
    subject: str = "Test email",
    sender: str = "alice@example.com",
    body_text: str = "Hello world",
    date: str = "Wed, 01 Jan 2025 10:00:00 +0000",
) -> dict:
    """Build a minimal Gmail API message response dict.

    Args:
        msg_id: Value for the top-level ``id`` key.
        subject: Value of the ``Subject`` header.
        sender: Value of the ``From`` header.
        body_text: Plain text stored (URL-safe base64 encoded) in
            ``payload.body.data``.
        date: Value of the ``Date`` header. The default names the correct
            weekday for 1 January 2025 (a Wednesday) so the fixture is a
            self-consistent date string.

    Returns:
        A dict with ``id``, ``labelIds`` (always ``["INBOX"]``) and a
        single-part ``text/plain`` ``payload`` carrying the three headers
        and the encoded body.
    """
    import base64

    encoded_body = base64.urlsafe_b64encode(body_text.encode()).decode()
    return {
        "id": msg_id,
        "labelIds": ["INBOX"],
        "payload": {
            "mimeType": "text/plain",
            "headers": [
                {"name": "Subject", "value": subject},
                {"name": "From", "value": sender},
                {"name": "Date", "value": date},
            ],
            "body": {"data": encoded_body},
        },
    }
class TestGmailClientFetchMessages:
    """Drive ``GmailClient.fetch_messages`` against a mocked Google API service."""

    def _make_client(self) -> "GmailClient":
        """Build a client from the module-level ``_TOKEN_DICT`` fixture."""
        from app.integrations.gmail import GmailClient

        return GmailClient(_TOKEN_DICT)

    @staticmethod
    def _mock_service():
        """Return ``(service, messages_api)``; the latter is what ``service.users().messages()`` yields."""
        service = MagicMock()
        messages_api = service.users.return_value.messages.return_value
        return service, messages_api

    async def _fetch(self, client, service):
        """Await ``client.fetch_messages()`` with its external dependencies patched.

        ``asyncio.to_thread`` (as looked up via ``app.integrations.gmail``) is
        replaced by a coroutine that calls the target inline, the discovery
        ``build`` call returns *service*, the transport ``Request`` class is
        mocked, and the credentials' ``expired`` property reports ``False``.
        """
        async def run_inline(fn, *args, **kwargs):
            # Run the would-be worker-thread callable on the spot and hand back its result.
            return fn(*args, **kwargs)

        credentials_cls = type(client._credentials)
        with patch("app.integrations.gmail.asyncio.to_thread") as to_thread_mock:
            to_thread_mock.side_effect = run_inline
            with patch("googleapiclient.discovery.build", return_value=service), \
                 patch("google.auth.transport.requests.Request"), \
                 patch.object(credentials_cls, "expired", new_callable=PropertyMock, return_value=False):
                return await client.fetch_messages()

    @pytest.mark.asyncio
    async def test_happy_path_returns_email_messages(self):
        """One listed message id is fetched and mapped to subject, sender and body text."""
        client = self._make_client()
        gmail_message = _make_gmail_message()
        service, messages_api = self._mock_service()
        messages_api.list.return_value.execute.return_value = {
            "messages": [{"id": "msg001"}]
        }
        messages_api.get.return_value.execute.return_value = gmail_message

        fetched = await self._fetch(client, service)

        assert len(fetched) == 1
        assert fetched[0].subject == "Test email"
        assert fetched[0].sender == "alice@example.com"
        assert fetched[0].body_text == "Hello world"

    @pytest.mark.asyncio
    async def test_no_messages_returns_empty_list(self):
        """An empty ``messages`` list from the API produces an empty result."""
        client = self._make_client()
        service, messages_api = self._mock_service()
        messages_api.list.return_value.execute.return_value = {"messages": []}

        fetched = await self._fetch(client, service)

        assert fetched == []

    @pytest.mark.asyncio
    async def test_list_http_error_raises_runtime_error(self):
        """An ``HttpError`` from ``messages.list`` is expected to surface as ``RuntimeError``."""
        import googleapiclient.errors

        client = self._make_client()
        service, messages_api = self._mock_service()
        forbidden = MagicMock()
        forbidden.status = 403
        forbidden.reason = "Forbidden"
        messages_api.list.return_value.execute.side_effect = (
            googleapiclient.errors.HttpError(forbidden, b"Forbidden")
        )

        with pytest.raises(RuntimeError, match="Gmail messages.list failed"):
            await self._fetch(client, service)

    def test_refreshed_credentials_none_when_unchanged(self):
        """With the token untouched, ``refreshed_credentials`` is ``None``."""
        untouched_client = self._make_client()
        assert untouched_client.refreshed_credentials is None

    def test_refreshed_credentials_returns_dict_when_token_changes(self):
        """After the access token changes, ``refreshed_credentials`` exposes the new token."""
        client = self._make_client()
        # Overwrite the access token on the credentials object to imitate a refresh.
        client._credentials.token = "new_access_token_xyz"
        snapshot = client.refreshed_credentials
        assert snapshot is not None
        assert snapshot["token"] == "new_access_token_xyz"
# ───────────────────────────────────────────────────────────────────────────
# MS Graph client — email filter builder
# ───────────────────────────────────────────────────────────────────────────
class TestBuildEmailFilter:
    """Exercise ``ms_graph._build_email_filter`` with various filter configs and ``since`` values."""

    def setup_method(self):
        # Imported inside the hook and kept on the instance; tests call it
        # as ``self._fn(filter_config, since)``.
        from app.integrations.ms_graph import _build_email_filter

        self._fn = _build_email_filter

    def test_empty_returns_empty_string(self):
        """No config and no ``since`` produce an empty filter string."""
        built = self._fn(None, None)
        assert built == ""

    def test_single_sender(self):
        """A single sender becomes an ``eq`` comparison on the from address."""
        config = {"senders": ["alice@example.com"]}
        built = self._fn(config, None)
        assert "from/emailAddress/address eq 'alice@example.com'" in built

    def test_multiple_senders_joined_with_or(self):
        """Several senders all appear in the filter, combined with ``or``."""
        addresses = ["a@x.com", "b@x.com"]
        built = self._fn({"senders": addresses}, None)
        assert " or " in built
        assert "a@x.com" in built
        assert "b@x.com" in built

    def test_since_adds_received_date_ge_clause(self):
        """``since`` alone yields a ``receivedDateTime ge`` clause with a ``Z``-suffixed timestamp."""
        cutoff = datetime(2025, 3, 1, tzinfo=timezone.utc)
        built = self._fn(None, cutoff)
        assert "receivedDateTime ge 2025-03-01T00:00:00Z" in built

    def test_date_range_to_adds_received_date_le_clause(self):
        """``date_range.to`` yields a ``receivedDateTime le`` clause."""
        config = {"date_range": {"to": "2025-06-30"}}
        built = self._fn(config, None)
        assert "receivedDateTime le" in built

    def test_since_overrides_earlier_date_range_from(self):
        """With ``since`` (Feb) later than ``date_range.from`` (Jan), only the Feb bound appears."""
        cutoff = datetime(2025, 2, 1, tzinfo=timezone.utc)
        config = {"date_range": {"from": "2025-01-01"}}
        built = self._fn(config, cutoff)
        assert "2025-02-01T00:00:00Z" in built
        assert "2025-01-01" not in built

    def test_invalid_date_ignored(self):
        """An unparseable ``date_range.from`` adds no ``receivedDateTime`` clause."""
        config = {"date_range": {"from": "bad-date"}}
        built = self._fn(config, None)
        assert "receivedDateTime" not in built
# ───────────────────────────────────────────────────────────────────────────
# MSGraphClient.fetch_emails
# ───────────────────────────────────────────────────────────────────────────
def _make_graph_email(
    msg_id: str = "email001",
    subject: str = "Meeting tomorrow",
    sender_address: str = "boss@company.com",
    body_content: str = "Please prepare the report.",
    received: str = "2025-06-01T10:00:00Z",
) -> dict:
    """Return a small dict shaped like one MS Graph mail message item.

    ``bodyPreview`` is the first 100 characters of *body_content*; the full
    text goes into ``body.content`` with ``contentType`` fixed to ``"text"``.
    """
    sender_block = {"emailAddress": {"address": sender_address}}
    body_block = {"contentType": "text", "content": body_content}
    preview = body_content[:100]
    return {
        "id": msg_id,
        "subject": subject,
        "from": sender_block,
        "receivedDateTime": received,
        "body": body_block,
        "bodyPreview": preview,
    }
def _make_graph_teams_message(
    msg_id: str = "teams001",
    content: str = "Stand-up at 9am",
    sender: str = "alice",
    channel_id: str = "chan001",
    created: str = "2025-06-01T08:00:00Z",
) -> dict:
    """Return a small dict shaped like one MS Graph Teams chat message item.

    *sender* is stored as the user's ``displayName`` and *channel_id* under
    ``channelIdentity.channelId``; the body ``contentType`` is fixed to ``"text"``.
    """
    body_block = {"contentType": "text", "content": content}
    author_block = {"user": {"displayName": sender}}
    channel_block = {"channelId": channel_id}
    return {
        "id": msg_id,
        "body": body_block,
        "from": author_block,
        "channelIdentity": channel_block,
        "createdDateTime": created,
    }
class TestMSGraphClientFetchEmails:
    """Drive ``MSGraphClient.fetch_emails`` with ``httpx.AsyncClient`` replaced by mocks."""

    def _make_client(self) -> "MSGraphClient":
        """Build a client from the module-level ``_MS_TOKEN_DICT`` fixture."""
        from app.integrations.ms_graph import MSGraphClient

        return MSGraphClient(_MS_TOKEN_DICT)

    @staticmethod
    def _ok_response(items):
        """Return a mock HTTP 200 response whose JSON body is ``{"value": items}``."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = {"value": items}
        response.raise_for_status = MagicMock()
        return response

    async def _fetch_emails_via(self, client, get):
        """Await ``client.fetch_emails()`` while ``httpx.AsyncClient`` is patched.

        The patched class's return value is set up as an async context manager
        that yields a mock HTTP client whose ``get`` attribute is *get*.
        """
        with patch("app.integrations.ms_graph.httpx.AsyncClient") as client_cls:
            http = AsyncMock()
            http.get = get
            context = client_cls.return_value
            context.__aenter__ = AsyncMock(return_value=http)
            context.__aexit__ = AsyncMock(return_value=False)
            return await client.fetch_emails()

    @pytest.mark.asyncio
    async def test_happy_path_returns_email_messages(self):
        """A single Graph item is mapped to subject, sender and body text."""
        client = self._make_client()
        response = self._ok_response([_make_graph_email()])

        fetched = await self._fetch_emails_via(client, AsyncMock(return_value=response))

        assert len(fetched) == 1
        assert fetched[0].subject == "Meeting tomorrow"
        assert fetched[0].sender == "boss@company.com"
        assert fetched[0].body_text == "Please prepare the report."

    @pytest.mark.asyncio
    async def test_pagination_stops_at_max_emails(self):
        """The only page has no ``@odata.nextLink``, so exactly its three items come back."""
        client = self._make_client()
        batch = [_make_graph_email(msg_id=str(index)) for index in range(3)]
        # The response carries only "value"; there is no next-page link to follow.
        response = self._ok_response(batch)

        fetched = await self._fetch_emails_via(client, AsyncMock(return_value=response))

        assert len(fetched) == 3

    @pytest.mark.asyncio
    async def test_401_triggers_token_refresh_and_retries(self):
        """A first 401 response leads to one token refresh, after which the 200 response is used."""
        client = self._make_client()
        unauthorized = MagicMock()
        unauthorized.status_code = 401
        success = self._ok_response([_make_graph_email()])

        # The first call pops the 401; every later call falls through to the 200.
        first_call_only = [unauthorized]

        async def fake_get(url, params=None, headers=None):
            if first_call_only:
                return first_call_only.pop()
            return success

        with patch.object(client, "_refresh_access_token", new_callable=AsyncMock) as refresh_mock:
            fetched = await self._fetch_emails_via(client, fake_get)

        refresh_mock.assert_called_once()
        assert len(fetched) == 1

    def test_refreshed_credentials_none_when_token_unchanged(self):
        """With the token untouched, ``refreshed_credentials`` is ``None``."""
        untouched_client = self._make_client()
        assert untouched_client.refreshed_credentials is None

    def test_refreshed_credentials_returns_dict_when_token_changes(self):
        """After ``_access_token`` changes, ``refreshed_credentials`` exposes the new value."""
        client = self._make_client()
        client._access_token = "new_token_abc"
        assert client.refreshed_credentials is not None
        assert client.refreshed_credentials["access_token"] == "new_token_abc"
class TestMSGraphClientFetchMessages:
    """Drive ``MSGraphClient.fetch_messages`` (Teams) with ``httpx.AsyncClient`` mocked."""

    def _make_client(self) -> "MSGraphClient":
        """Build a client from the module-level ``_MS_TOKEN_DICT`` fixture."""
        from app.integrations.ms_graph import MSGraphClient

        return MSGraphClient(_MS_TOKEN_DICT)

    @staticmethod
    def _ok_response(items):
        """Return a mock HTTP 200 response whose JSON body is ``{"value": items}``."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = {"value": items}
        response.raise_for_status = MagicMock()
        return response

    async def _fetch_messages_via(self, client, get, **fetch_kwargs):
        """Await ``client.fetch_messages(**fetch_kwargs)`` while ``httpx.AsyncClient`` is patched.

        The patched class's return value is set up as an async context manager
        that yields a mock HTTP client whose ``get`` attribute is *get*.
        """
        with patch("app.integrations.ms_graph.httpx.AsyncClient") as client_cls:
            http = AsyncMock()
            http.get = get
            context = client_cls.return_value
            context.__aenter__ = AsyncMock(return_value=http)
            context.__aexit__ = AsyncMock(return_value=False)
            return await client.fetch_messages(**fetch_kwargs)

    @pytest.mark.asyncio
    async def test_happy_path_returns_chat_messages(self):
        """A single Teams item is mapped to its content and sender display name."""
        client = self._make_client()
        response = self._ok_response([_make_graph_teams_message()])

        fetched = await self._fetch_messages_via(client, AsyncMock(return_value=response))

        assert len(fetched) == 1
        assert fetched[0].content == "Stand-up at 9am"
        assert fetched[0].sender == "alice"

    @pytest.mark.asyncio
    async def test_403_degrades_gracefully(self):
        """getAllMessages returning 403 (license issue) returns empty list, no exception."""
        import httpx as _httpx

        client = self._make_client()
        forbidden_response = MagicMock()
        forbidden_response.status_code = 403
        forbidden_error = _httpx.HTTPStatusError(
            "Forbidden", request=MagicMock(), response=forbidden_response
        )

        # Here the GET call itself raises the status error.
        fetched = await self._fetch_messages_via(
            client, AsyncMock(side_effect=forbidden_error)
        )

        assert fetched == []

    @pytest.mark.asyncio
    async def test_channel_filter_applied(self):
        """Only the message whose channel id is listed in ``channels`` is returned."""
        client = self._make_client()
        wanted = _make_graph_teams_message(channel_id="dev-channel", content="Deploy today")
        unwanted = _make_graph_teams_message(msg_id="t2", channel_id="random", content="Lunch?")
        response = self._ok_response([wanted, unwanted])

        fetched = await self._fetch_messages_via(
            client,
            AsyncMock(return_value=response),
            filter_config={"channels": ["dev-channel"]},
        )

        assert len(fetched) == 1
        assert fetched[0].content == "Deploy today"
class TestMSGraphClientRefreshToken:
    """Exercise ``MSGraphClient._refresh_access_token`` with MSAL and settings mocked."""

    @staticmethod
    def _client_with_refresh_token(refresh_token):
        """Build a client from ``_MS_TOKEN_DICT`` with ``refresh_token`` overridden."""
        from app.integrations.ms_graph import MSGraphClient

        return MSGraphClient({**_MS_TOKEN_DICT, "refresh_token": refresh_token})

    async def _refresh_with_msal_result(self, client, msal_result):
        """Await ``client._refresh_access_token()`` with MSAL returning *msal_result*.

        ``msal.ConfidentialClientApplication`` is patched to return an app mock
        whose ``acquire_token_by_refresh_token`` yields *msal_result*, and the
        ``settings`` object seen by ``app.integrations.ms_graph`` is replaced by
        a mock carrying fixed client id, secret and tenant values.
        """
        msal_app = MagicMock()
        msal_app.acquire_token_by_refresh_token.return_value = msal_result
        with patch("msal.ConfidentialClientApplication", return_value=msal_app), \
             patch("app.integrations.ms_graph.settings") as fake_settings:
            fake_settings.MS_CLIENT_ID = "client_id"
            fake_settings.MS_CLIENT_SECRET = "secret"
            fake_settings.MS_TENANT_ID = "common"
            await client._refresh_access_token()

    @pytest.mark.asyncio
    async def test_msal_error_raises_runtime_error(self):
        """An MSAL result containing ``error`` is expected to raise ``RuntimeError``."""
        client = self._client_with_refresh_token("rt_test")
        failure = {
            "error": "invalid_grant",
            "error_description": "Refresh token expired",
        }

        with pytest.raises(RuntimeError, match="MS Graph token refresh failed"):
            await self._refresh_with_msal_result(client, failure)

    @pytest.mark.asyncio
    async def test_successful_refresh_updates_access_token(self):
        """A successful MSAL result replaces both the access and the refresh token on the client."""
        client = self._client_with_refresh_token("rt_old")
        success = {
            "access_token": "new_access_token",
            "refresh_token": "new_refresh_token",
        }

        await self._refresh_with_msal_result(client, success)

        assert client._access_token == "new_access_token"
        assert client._refresh_token == "new_refresh_token"