728 lines
34 KiB
Python
728 lines
34 KiB
Python
"""Tests for Step 3.6: cloud provider integration clients.

Coverage:

Unit — app/integrations/__init__.py:
- encrypt_token / decrypt_token round-trip
- decrypt_token raises ValueError on invalid ciphertext or when decrypting with a different key
- encrypt_token raises ValueError on empty/non-dict input
- _get_fernet raises RuntimeError when OAUTH_ENCRYPTION_KEY not set
- EmailMessage.as_text / ChatMessage.as_text include sender, subject/channel and body
- get_provider returns GmailClient for 'gmail'
- get_provider returns MSGraphClient for 'outlook' and 'teams'
- get_provider raises ValueError for unknown provider

Unit — app/integrations/gmail.py:
- _build_gmail_query with no filter returns empty string
- _build_gmail_query with labels builds label: expr (multiple labels joined with OR)
- _build_gmail_query with senders builds from: expr
- _build_gmail_query with date_range builds after:/before: exprs
- _build_gmail_query since overrides date_range.from when more recent
- _build_gmail_query date_range.from overrides since when more recent
- _build_gmail_query skips an unparseable date instead of raising
- _parse_body extracts text/plain part
- _parse_body extracts text/html part (stripped)
- _parse_body recurses into multipart, prefers text/plain
- _parse_body returns empty string for an empty payload
- GmailClient.fetch_messages: happy path with mocked service
- GmailClient.fetch_messages: no messages returns empty list
- GmailClient.fetch_messages: HTTP error on messages.list raises RuntimeError
- GmailClient.refreshed_credentials: None when token unchanged
- GmailClient.refreshed_credentials: returns dict when token changes

Unit — app/integrations/ms_graph.py:
- _build_email_filter with no filter returns empty string
- _build_email_filter with senders builds OData from clause (multiple joined with or)
- _build_email_filter with since builds receivedDateTime ge clause
- _build_email_filter with date_range.to builds receivedDateTime le clause
- _build_email_filter since overrides an earlier date_range.from
- _build_email_filter skips an unparseable date
- MSGraphClient.fetch_emails: happy path with mocked httpx
- MSGraphClient.fetch_emails: a page without @odata.nextLink ends pagination
- MSGraphClient.fetch_emails: 401 triggers token refresh and retries
- MSGraphClient.fetch_messages: happy path with mocked httpx
- MSGraphClient.fetch_messages: 403 from getAllMessages degrades gracefully
- MSGraphClient.fetch_messages: channel filter drops non-matching channels
- MSGraphClient.refreshed_credentials: None when token unchanged, dict when it changes
- MSGraphClient._refresh_access_token: MSAL error raises RuntimeError
- MSGraphClient._refresh_access_token: success updates access and refresh tokens
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.integrations import (
|
|
ChatMessage,
|
|
EmailMessage,
|
|
decrypt_token,
|
|
encrypt_token,
|
|
get_provider,
|
|
)
|
|
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
# Helpers
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
|
|
from cryptography.fernet import Fernet as _Fernet  # noqa: E402

# A well-formed Fernet key generated at import time, so no key material is
# ever committed to the repository and every test run uses a fresh key.
_VALID_KEY = _Fernet.generate_key().decode("utf-8")

# Google OAuth credential payload; GmailClient is constructed from this dict
# in the tests below, and it is the plaintext for the encryption round-trip.
_TOKEN_DICT = {
    "token": "access_abc",
    "refresh_token": "refresh_xyz",
    "token_uri": "https://oauth2.googleapis.com/token",
    "client_id": "client_id_123",
    "client_secret": "client_secret_456",
    "scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
}

# Microsoft OAuth token payload; MSGraphClient is constructed from this dict
# (or a copy with a different refresh_token) in the tests below.
_MS_TOKEN_DICT = {
    "access_token": "ms_access_abc",
    "refresh_token": "ms_refresh_xyz",
    "token_type": "Bearer",
    "scope": "Mail.Read offline_access",
}
|
|
|
|
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
# encrypt_token / decrypt_token
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
|
|
|
|
class TestTokenEncryption:
    """encrypt_token / decrypt_token round-trip and error-path tests."""

    @staticmethod
    def _patched_key(key: str):
        """Return a patcher that makes ``settings.OAUTH_ENCRYPTION_KEY`` equal *key*.

        Extra keyword arguments to ``patch`` configure the replacement
        MagicMock, so the attribute is set as soon as the patch is entered.
        """
        return patch("app.integrations.settings", OAUTH_ENCRYPTION_KEY=key)

    def test_round_trip(self):
        with self._patched_key(_VALID_KEY):
            encrypted = encrypt_token(_TOKEN_DICT)
            assert isinstance(encrypted, str)
            # Must be ciphertext: neither the JSON form nor any secret value
            # may appear verbatim in the stored string.
            assert encrypted != json.dumps(_TOKEN_DICT)
            assert "access_abc" not in encrypted
            assert "refresh_xyz" not in encrypted
            assert decrypt_token(encrypted) == _TOKEN_DICT

    def test_decrypt_invalid_ciphertext_raises_value_error(self):
        with self._patched_key(_VALID_KEY):
            with pytest.raises(ValueError, match="Failed to decrypt"):
                decrypt_token("this-is-not-valid-fernet-ciphertext")

    def test_decrypt_wrong_key_raises_value_error(self):
        """Decrypting with a different key must fail with ValueError."""
        other_key = _Fernet.generate_key().decode("utf-8")
        with self._patched_key(_VALID_KEY):
            encrypted = encrypt_token(_TOKEN_DICT)
        with self._patched_key(other_key):
            with pytest.raises(ValueError, match="Failed to decrypt"):
                decrypt_token(encrypted)

    def test_encrypt_empty_dict_raises_value_error(self):
        with self._patched_key(_VALID_KEY):
            with pytest.raises(ValueError, match="non-empty dict"):
                encrypt_token({})

    def test_encrypt_non_dict_raises_value_error(self):
        with self._patched_key(_VALID_KEY):
            with pytest.raises(ValueError, match="non-empty dict"):
                encrypt_token("not-a-dict")  # type: ignore[arg-type]

    def test_missing_key_raises_runtime_error(self):
        with self._patched_key(""):
            with pytest.raises(RuntimeError, match="OAUTH_ENCRYPTION_KEY"):
                encrypt_token(_TOKEN_DICT)


class TestMessageAsText:
    """EmailMessage.as_text / ChatMessage.as_text rendering tests."""

    def test_email_message_as_text(self):
        msg = EmailMessage(
            id="m1",
            subject="Hello",
            sender="alice@example.com",
            body_text="Test body",
            date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
        )
        text = msg.as_text
        assert "From: alice@example.com" in text
        assert "Subject: Hello" in text
        assert "Test body" in text

    def test_chat_message_as_text(self):
        msg = ChatMessage(
            id="c1",
            content="Buy milk",
            sender="bob",
            channel="general",
            date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
        )
        text = msg.as_text
        assert "From: bob" in text
        assert "channel: general" in text
        assert "Buy milk" in text
|
|
|
|
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
# get_provider factory
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
|
|
|
|
class TestGetProvider:
    """get_provider must dispatch each provider name to the right client class."""

    def test_gmail_returns_gmail_client(self):
        from app.integrations import gmail

        provider = get_provider("gmail", _TOKEN_DICT)
        assert isinstance(provider, gmail.GmailClient)

    def test_outlook_returns_ms_graph_client(self):
        from app.integrations import ms_graph

        provider = get_provider("outlook", _MS_TOKEN_DICT)
        assert isinstance(provider, ms_graph.MSGraphClient)

    def test_teams_returns_ms_graph_client(self):
        # Teams is served by the same Graph client as Outlook.
        from app.integrations import ms_graph

        provider = get_provider("teams", _MS_TOKEN_DICT)
        assert isinstance(provider, ms_graph.MSGraphClient)

    def test_unknown_provider_raises_value_error(self):
        unsupported = "slack"
        with pytest.raises(ValueError, match="Unknown cloud provider"):
            get_provider(unsupported, {})
|
|
|
|
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
# Gmail client — query builder
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
|
|
|
|
class TestBuildGmailQuery:
    """Unit tests for gmail._build_gmail_query(filter_config, since)."""

    def setup_method(self):
        from app.integrations.gmail import _build_gmail_query

        self._fn = _build_gmail_query

    def test_empty_returns_empty_string(self):
        assert self._fn(None, None) == ""

    def test_single_label(self):
        q = self._fn({"labels": ["INBOX"]}, None)
        assert "label:INBOX" in q

    def test_multiple_labels_joined_with_or(self):
        q = self._fn({"labels": ["INBOX", "work"]}, None)
        assert "label:INBOX OR label:work" in q

    def test_senders(self):
        q = self._fn({"senders": ["alice@example.com"]}, None)
        assert "from:alice@example.com" in q

    def test_date_range_from(self):
        q = self._fn({"date_range": {"from": "2025-01-15"}}, None)
        assert "after:2025/01/15" in q

    def test_date_range_to(self):
        q = self._fn({"date_range": {"to": "2025-03-01"}}, None)
        assert "before:2025/03/01" in q

    def test_since_overrides_earlier_date_range_from(self):
        """since=Feb is more recent than date_range.from=Jan, so after: should be Feb."""
        since = datetime(2025, 2, 1, tzinfo=timezone.utc)
        q = self._fn({"date_range": {"from": "2025-01-01"}}, since)
        assert "after:2025/02/01" in q
        assert "after:2025/01/01" not in q

    def test_date_range_from_overrides_earlier_since(self):
        """date_range.from=Feb is more recent than since=Jan, so after: should be Feb."""
        since = datetime(2025, 1, 1, tzinfo=timezone.utc)
        q = self._fn({"date_range": {"from": "2025-02-01"}}, since)
        assert "after:2025/02/01" in q
        # Mirror of the test above: the older bound (from `since`) must not
        # leak into the query alongside the newer one.
        assert "after:2025/01/01" not in q

    def test_invalid_date_ignored(self):
        """An invalid date string in filter_config must not raise, just be skipped."""
        q = self._fn({"date_range": {"from": "not-a-date"}}, None)
        assert "after:" not in q
|
|
|
|
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
# Gmail client — body parsing
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
|
|
|
|
class TestParseBody:
    """Unit tests for gmail._parse_body, which turns a Gmail payload into text."""

    def setup_method(self):
        from app.integrations.gmail import _parse_body

        self._parse = _parse_body

    @staticmethod
    def _b64(text: str) -> str:
        """Encode *text* the way the Gmail API encodes body data (URL-safe base64)."""
        import base64

        raw = text.encode()
        return base64.urlsafe_b64encode(raw).decode()

    def test_text_plain_extracted(self):
        body = {"data": self._b64("Hello world")}
        assert self._parse({"mimeType": "text/plain", "body": body}) == "Hello world"

    def test_text_html_stripped(self):
        body = {"data": self._b64("<p>Hello <b>world</b></p>")}
        text = self._parse({"mimeType": "text/html", "body": body})
        assert "Hello" in text
        assert "<p>" not in text

    def test_multipart_prefers_plain_over_html(self):
        # The HTML alternative is listed first on purpose: ordering must not
        # decide which part wins.
        html_part = {"mimeType": "text/html", "body": {"data": self._b64("<p>HTML text</p>")}}
        plain_part = {"mimeType": "text/plain", "body": {"data": self._b64("Plain text")}}
        container = {
            "mimeType": "multipart/alternative",
            "body": {},
            "parts": [html_part, plain_part],
        }
        assert self._parse(container) == "Plain text"

    def test_empty_payload_returns_empty_string(self):
        assert self._parse({}) == ""
|
|
|
|
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
# GmailClient.fetch_messages
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
|
|
|
|
def _make_gmail_message(
|
|
msg_id: str = "msg001",
|
|
subject: str = "Test email",
|
|
sender: str = "alice@example.com",
|
|
body_text: str = "Hello world",
|
|
date: str = "Mon, 01 Jan 2025 10:00:00 +0000",
|
|
) -> dict:
|
|
"""Build a minimal Gmail API message response dict."""
|
|
import base64
|
|
body_data = base64.urlsafe_b64encode(body_text.encode()).decode()
|
|
return {
|
|
"id": msg_id,
|
|
"labelIds": ["INBOX"],
|
|
"payload": {
|
|
"mimeType": "text/plain",
|
|
"headers": [
|
|
{"name": "Subject", "value": subject},
|
|
{"name": "From", "value": sender},
|
|
{"name": "Date", "value": date},
|
|
],
|
|
"body": {"data": body_data},
|
|
},
|
|
}
|
|
|
|
|
|
class TestGmailClientFetchMessages:
    """GmailClient.fetch_messages tests with mocked Google API."""

    def _make_client(self):
        from app.integrations.gmail import GmailClient

        return GmailClient(_TOKEN_DICT)

    @staticmethod
    def _make_service(list_result=None, get_result=None, list_error=None):
        """Build a MagicMock standing in for the Gmail discovery service.

        ``users().messages().list(...).execute()`` returns *list_result*, or
        raises *list_error* when given. ``users().messages().get(...).execute()``
        returns *get_result* when given; otherwise it is left as the default
        MagicMock.
        """
        service = MagicMock()
        messages = service.users.return_value.messages.return_value
        if list_error is not None:
            messages.list.return_value.execute.side_effect = list_error
        else:
            messages.list.return_value.execute.return_value = list_result
        if get_result is not None:
            messages.get.return_value.execute.return_value = get_result
        return service

    @staticmethod
    async def _fetch(client, service):
        """Await ``client.fetch_messages()`` with Google I/O replaced by *service*.

        ``asyncio.to_thread`` is replaced by a coroutine that runs the (mocked)
        blocking call inline. Note the patch target resolves to the shared
        ``asyncio`` module, so the replacement is global while active.
        ``Credentials.expired`` is forced to False, presumably so the client
        skips its refresh branch.
        """
        async def run_inline(fn, *args, **kwargs):
            return fn(*args, **kwargs)

        with patch("app.integrations.gmail.asyncio.to_thread", new=run_inline), \
                patch("googleapiclient.discovery.build", return_value=service), \
                patch("google.auth.transport.requests.Request"), \
                patch.object(
                    type(client._credentials),
                    "expired",
                    new_callable=PropertyMock,
                    return_value=False,
                ):
            return await client.fetch_messages()

    @pytest.mark.asyncio
    async def test_happy_path_returns_email_messages(self):
        client = self._make_client()
        service = self._make_service(
            list_result={"messages": [{"id": "msg001"}]},
            get_result=_make_gmail_message(),
        )

        results = await self._fetch(client, service)

        assert len(results) == 1
        assert results[0].subject == "Test email"
        assert results[0].sender == "alice@example.com"
        assert results[0].body_text == "Hello world"

    @pytest.mark.asyncio
    async def test_no_messages_returns_empty_list(self):
        client = self._make_client()
        service = self._make_service(list_result={"messages": []})

        results = await self._fetch(client, service)

        assert results == []

    @pytest.mark.asyncio
    async def test_list_http_error_raises_runtime_error(self):
        import googleapiclient.errors

        client = self._make_client()
        mock_resp = MagicMock()
        mock_resp.status = 403
        mock_resp.reason = "Forbidden"
        service = self._make_service(
            list_error=googleapiclient.errors.HttpError(mock_resp, b"Forbidden"),
        )

        with pytest.raises(RuntimeError, match="Gmail messages.list failed"):
            await self._fetch(client, service)

    def test_refreshed_credentials_none_when_unchanged(self):
        client = self._make_client()
        # Token unchanged — should return None.
        assert client.refreshed_credentials is None

    def test_refreshed_credentials_returns_dict_when_token_changes(self):
        client = self._make_client()
        # Simulate a token refresh by changing the access token on the credentials object.
        client._credentials.token = "new_access_token_xyz"
        refreshed = client.refreshed_credentials
        assert refreshed is not None
        assert refreshed["token"] == "new_access_token_xyz"
|
|
|
|
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
# MS Graph client — email filter builder
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
|
|
|
|
class TestBuildEmailFilter:
    """Unit tests for ms_graph._build_email_filter (OData ``$filter`` builder)."""

    def setup_method(self):
        from app.integrations.ms_graph import _build_email_filter

        self._build = _build_email_filter

    def test_empty_returns_empty_string(self):
        assert self._build(None, None) == ""

    def test_single_sender(self):
        expr = self._build({"senders": ["alice@example.com"]}, None)
        assert "from/emailAddress/address eq 'alice@example.com'" in expr

    def test_multiple_senders_joined_with_or(self):
        expr = self._build({"senders": ["a@x.com", "b@x.com"]}, None)
        for fragment in (" or ", "a@x.com", "b@x.com"):
            assert fragment in expr

    def test_since_adds_received_date_ge_clause(self):
        expr = self._build(None, datetime(2025, 3, 1, tzinfo=timezone.utc))
        assert "receivedDateTime ge 2025-03-01T00:00:00Z" in expr

    def test_date_range_to_adds_received_date_le_clause(self):
        expr = self._build({"date_range": {"to": "2025-06-30"}}, None)
        assert "receivedDateTime le" in expr

    def test_since_overrides_earlier_date_range_from(self):
        config = {"date_range": {"from": "2025-01-01"}}
        expr = self._build(config, datetime(2025, 2, 1, tzinfo=timezone.utc))
        assert "2025-02-01T00:00:00Z" in expr
        assert "2025-01-01" not in expr

    def test_invalid_date_ignored(self):
        expr = self._build({"date_range": {"from": "bad-date"}}, None)
        assert "receivedDateTime" not in expr
|
|
|
|
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
# MSGraphClient.fetch_emails
|
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
|
|
|
|
|
def _make_graph_email(
|
|
msg_id: str = "email001",
|
|
subject: str = "Meeting tomorrow",
|
|
sender_address: str = "boss@company.com",
|
|
body_content: str = "Please prepare the report.",
|
|
received: str = "2025-06-01T10:00:00Z",
|
|
) -> dict:
|
|
"""Build a minimal MS Graph message item dict."""
|
|
return {
|
|
"id": msg_id,
|
|
"subject": subject,
|
|
"from": {"emailAddress": {"address": sender_address}},
|
|
"receivedDateTime": received,
|
|
"body": {"contentType": "text", "content": body_content},
|
|
"bodyPreview": body_content[:100],
|
|
}
|
|
|
|
|
|
def _make_graph_teams_message(
|
|
msg_id: str = "teams001",
|
|
content: str = "Stand-up at 9am",
|
|
sender: str = "alice",
|
|
channel_id: str = "chan001",
|
|
created: str = "2025-06-01T08:00:00Z",
|
|
) -> dict:
|
|
return {
|
|
"id": msg_id,
|
|
"body": {"contentType": "text", "content": content},
|
|
"from": {"user": {"displayName": sender}},
|
|
"channelIdentity": {"channelId": channel_id},
|
|
"createdDateTime": created,
|
|
}
|
|
|
|
|
|
class TestMSGraphClientFetchEmails:
    """MSGraphClient.fetch_emails tests with mocked httpx."""

    # Dotted path of httpx.AsyncClient as looked up from the ms_graph module.
    _ASYNC_CLIENT = "app.integrations.ms_graph.httpx.AsyncClient"

    def _make_client(self):
        from app.integrations.ms_graph import MSGraphClient

        return MSGraphClient(_MS_TOKEN_DICT)

    @staticmethod
    def _page(items):
        """Return a mocked 200 response whose JSON body is ``{"value": items}``."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = {"value": items}
        response.raise_for_status = MagicMock()
        return response

    @staticmethod
    def _client_cls(mock_http):
        """Return a stand-in for ``httpx.AsyncClient`` whose ``async with`` yields *mock_http*."""
        client_cls = MagicMock()
        client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
        client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
        return client_cls

    @pytest.mark.asyncio
    async def test_happy_path_returns_email_messages(self):
        client = self._make_client()
        mock_http = AsyncMock()
        mock_http.get = AsyncMock(return_value=self._page([_make_graph_email()]))

        with patch(self._ASYNC_CLIENT, new=self._client_cls(mock_http)):
            results = await client.fetch_emails()

        assert len(results) == 1
        assert results[0].subject == "Meeting tomorrow"
        assert results[0].sender == "boss@company.com"
        assert results[0].body_text == "Please prepare the report."

    @pytest.mark.asyncio
    async def test_pagination_stops_at_max_emails(self):
        """A first page without ``@odata.nextLink`` ends pagination after one request."""
        client = self._make_client()
        emails_batch = [_make_graph_email(msg_id=str(i)) for i in range(3)]
        mock_http = AsyncMock()
        mock_http.get = AsyncMock(return_value=self._page(emails_batch))  # no @odata.nextLink

        with patch(self._ASYNC_CLIENT, new=self._client_cls(mock_http)):
            results = await client.fetch_emails()

        assert len(results) == 3
        # Without a nextLink there is nothing further to fetch.
        assert mock_http.get.await_count == 1

    @pytest.mark.asyncio
    async def test_401_triggers_token_refresh_and_retries(self):
        """On first 401, token refresh is attempted and the request retried."""
        client = self._make_client()

        response_401 = MagicMock()
        response_401.status_code = 401
        response_200 = self._page([_make_graph_email()])

        call_count = 0

        # Accept any extra positional/keyword arguments so the fake does not
        # depend on exactly which of params/headers/timeout/... the client passes.
        async def fake_get(url, *args, **kwargs):
            nonlocal call_count
            call_count += 1
            return response_401 if call_count == 1 else response_200

        mock_http = AsyncMock()
        mock_http.get = fake_get

        with patch(self._ASYNC_CLIENT, new=self._client_cls(mock_http)), \
                patch.object(client, "_refresh_access_token", new_callable=AsyncMock) as mock_refresh:
            results = await client.fetch_emails()

        mock_refresh.assert_awaited_once()
        assert call_count == 2  # the original request plus exactly one retry
        assert len(results) == 1

    def test_refreshed_credentials_none_when_token_unchanged(self):
        client = self._make_client()
        assert client.refreshed_credentials is None

    def test_refreshed_credentials_returns_dict_when_token_changes(self):
        client = self._make_client()
        client._access_token = "new_token_abc"
        refreshed = client.refreshed_credentials
        assert refreshed is not None
        assert refreshed["access_token"] == "new_token_abc"
|
|
|
|
|
|
class TestMSGraphClientFetchMessages:
    """Teams retrieval through MSGraphClient.fetch_messages, with httpx mocked out."""

    def _make_client(self):
        from app.integrations.ms_graph import MSGraphClient

        return MSGraphClient(_MS_TOKEN_DICT)

    @staticmethod
    def _page_of(*items):
        """Mocked 200 response whose JSON body is ``{"value": [items...]}``."""
        ok = MagicMock()
        ok.status_code = 200
        ok.json.return_value = {"value": list(items)}
        ok.raise_for_status = MagicMock()
        return ok

    @staticmethod
    async def _fetch_with(client, get, **fetch_kwargs):
        """Await ``client.fetch_messages(**fetch_kwargs)`` with the HTTP client's ``get`` set to *get*."""
        http = AsyncMock()
        http.get = get
        with patch("app.integrations.ms_graph.httpx.AsyncClient") as client_cls:
            client_cls.return_value.__aenter__ = AsyncMock(return_value=http)
            client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
            return await client.fetch_messages(**fetch_kwargs)

    @pytest.mark.asyncio
    async def test_happy_path_returns_chat_messages(self):
        client = self._make_client()
        get = AsyncMock(return_value=self._page_of(_make_graph_teams_message()))

        messages = await self._fetch_with(client, get)

        assert len(messages) == 1
        assert messages[0].content == "Stand-up at 9am"
        assert messages[0].sender == "alice"

    @pytest.mark.asyncio
    async def test_403_degrades_gracefully(self):
        """getAllMessages returning 403 (license issue) returns empty list, no exception."""
        import httpx as _httpx

        client = self._make_client()
        forbidden = MagicMock()
        forbidden.status_code = 403
        failure = _httpx.HTTPStatusError("Forbidden", request=MagicMock(), response=forbidden)

        messages = await self._fetch_with(client, AsyncMock(side_effect=failure))

        assert messages == []

    @pytest.mark.asyncio
    async def test_channel_filter_applied(self):
        """Messages from non-matching channels are filtered out."""
        client = self._make_client()
        wanted = _make_graph_teams_message(channel_id="dev-channel", content="Deploy today")
        unwanted = _make_graph_teams_message(msg_id="t2", channel_id="random", content="Lunch?")
        get = AsyncMock(return_value=self._page_of(wanted, unwanted))

        messages = await self._fetch_with(
            client, get, filter_config={"channels": ["dev-channel"]}
        )

        assert len(messages) == 1
        assert messages[0].content == "Deploy today"
|
|
|
|
|
|
class TestMSGraphClientRefreshToken:
    """MSGraphClient._refresh_access_token behaviour against a mocked MSAL app."""

    @staticmethod
    def _client_with_refresh_token(refresh_token):
        from app.integrations.ms_graph import MSGraphClient

        return MSGraphClient({**_MS_TOKEN_DICT, "refresh_token": refresh_token})

    @staticmethod
    async def _refresh_with(client, msal_app):
        """Await ``client._refresh_access_token()`` with MSAL and settings patched.

        ``msal.ConfidentialClientApplication`` returns *msal_app*, and the
        ms_graph settings carry dummy client id / secret / tenant values.
        """
        with patch("msal.ConfidentialClientApplication", return_value=msal_app), \
                patch("app.integrations.ms_graph.settings") as fake_settings:
            fake_settings.MS_CLIENT_ID = "client_id"
            fake_settings.MS_CLIENT_SECRET = "secret"
            fake_settings.MS_TENANT_ID = "common"
            await client._refresh_access_token()

    @pytest.mark.asyncio
    async def test_msal_error_raises_runtime_error(self):
        client = self._client_with_refresh_token("rt_test")
        msal_app = MagicMock()
        msal_app.acquire_token_by_refresh_token.return_value = {
            "error": "invalid_grant",
            "error_description": "Refresh token expired",
        }

        with pytest.raises(RuntimeError, match="MS Graph token refresh failed"):
            await self._refresh_with(client, msal_app)

    @pytest.mark.asyncio
    async def test_successful_refresh_updates_access_token(self):
        client = self._client_with_refresh_token("rt_old")
        msal_app = MagicMock()
        msal_app.acquire_token_by_refresh_token.return_value = {
            "access_token": "new_access_token",
            "refresh_token": "new_refresh_token",
        }

        await self._refresh_with(client, msal_app)

        assert client._access_token == "new_access_token"
        assert client._refresh_token == "new_refresh_token"
|