feat(step-3.6): cloud provider integrations (Gmail, Outlook, Teams)
- Add app/integrations/__init__.py: Fernet token encryption helpers, EmailMessage/ChatMessage dataclasses, get_provider() factory - Add app/integrations/gmail.py: GmailClient with async fetch_messages(), token refresh, configurable label/sender/date filters - Add app/integrations/ms_graph.py: MSGraphClient with fetch_emails() (Outlook) and fetch_messages() (Teams), MSAL token refresh, OData filters - Update app/core/agent_runner.py: replace run_cloud_agent() stub with full 8-step implementation; extend _finalize_run() for cloud config type - Update app/config/settings.py: add OAuth + Fernet encryption settings - Update requirements.txt: google-api-python-client, google-auth-*, msal, cryptography - Add tests/test_integrations.py: 47 tests covering all integration code - Update tests/test_agent_runner.py: replace stub test with 7 real tests All 76 new/updated tests pass.
This commit is contained in:
@@ -455,21 +455,232 @@ async def test_run_local_agent_llm_extraction_error():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_stub_returns_error():
|
||||
"""Cloud agent stub immediately marks run as error with informative message."""
|
||||
async def test_run_cloud_agent_device_offline():
|
||||
"""Cloud agent aborts immediately when no device is connected."""
|
||||
config = _make_cloud_config()
|
||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||
mgr = DeviceConnectionManager() # empty — no devices registered
|
||||
|
||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
mock_finalize.assert_called_once()
|
||||
_, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error"
|
||||
assert any("device" in e.lower() or "connected" in e.lower() for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_no_oauth_token():
|
||||
"""Cloud agent errors when no OAuth token is stored."""
|
||||
config = _make_cloud_config()
|
||||
config.oauth_token_encrypted = None
|
||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||
mgr = _make_manager()
|
||||
|
||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
mock_finalize.assert_called_once()
|
||||
_args, kwargs = mock_finalize.call_args
|
||||
_, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error"
|
||||
assert len(kwargs["errors"]) == 1
|
||||
assert "gmail" in kwargs["errors"][0].lower()
|
||||
assert "3.6" in kwargs["errors"][0]
|
||||
assert any("oauth" in e.lower() or "token" in e.lower() for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_token_decrypt_failure():
|
||||
"""Cloud agent errors gracefully when the stored token cannot be decrypted."""
|
||||
config = _make_cloud_config()
|
||||
config.oauth_token_encrypted = "this-is-not-valid-fernet-ciphertext"
|
||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||
mgr = _make_manager()
|
||||
|
||||
from cryptography.fernet import Fernet as _Fernet
|
||||
valid_key = _Fernet.generate_key().decode()
|
||||
|
||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||
patch("app.integrations.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENCRYPTION_KEY = valid_key
|
||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
_, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error"
|
||||
assert any("decrypt" in e.lower() for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_happy_path_gmail():
|
||||
"""Cloud agent happy path: Gmail fetch → LLM extraction → inserts → success."""
|
||||
from app.integrations import EmailMessage, encrypt_token
|
||||
from cryptography.fernet import Fernet as _Fernet
|
||||
|
||||
fernet_key = _Fernet.generate_key().decode()
|
||||
credentials = {
|
||||
"token": "access_abc",
|
||||
"refresh_token": "refresh_xyz",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"client_id": "cid",
|
||||
"client_secret": "csec",
|
||||
}
|
||||
|
||||
config = _make_cloud_config()
|
||||
config.provider = "gmail"
|
||||
config.prompt_template = "Extract tasks from this email."
|
||||
config.data_types = ["tasks"]
|
||||
|
||||
with patch("app.integrations.settings") as ms:
|
||||
ms.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||
config.oauth_token_encrypted = encrypt_token(credentials)
|
||||
|
||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||
mgr = _make_manager()
|
||||
|
||||
sample_email = EmailMessage(
|
||||
id="msg001",
|
||||
subject="Action required",
|
||||
sender="boss@company.com",
|
||||
body_text="Please fix the bug by Friday.",
|
||||
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
extracted_items = [{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}]
|
||||
|
||||
with patch("app.integrations.settings") as mock_int_settings, \
|
||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||
patch("app.core.agent_runner._extract_items_from_content", new_callable=AsyncMock, return_value=extracted_items) as mock_extract, \
|
||||
patch("app.core.agent_runner._send_insert_to_client", new_callable=AsyncMock, return_value={"ok": True}) as mock_insert, \
|
||||
patch("app.core.agent_runner.async_session"):
|
||||
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||
|
||||
mock_gmail = AsyncMock()
|
||||
mock_gmail.fetch_messages = AsyncMock(return_value=[sample_email])
|
||||
mock_gmail.refreshed_credentials = None
|
||||
|
||||
with patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||
patch("app.integrations.get_provider", return_value=mock_gmail):
|
||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
mock_extract.assert_called_once()
|
||||
mock_insert.assert_called_once()
|
||||
_, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "success"
|
||||
assert kwargs["items_processed"] == 1
|
||||
assert kwargs["items_created"] == 1
|
||||
assert kwargs["config_type"] == "cloud"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_provider_fetch_error():
|
||||
"""Cloud agent records error status when provider fetch raises RuntimeError."""
|
||||
credentials = {"token": "abc"}
|
||||
config = _make_cloud_config()
|
||||
config.oauth_token_encrypted = "some_encrypted_value" # non-empty so decrypt step is reached
|
||||
config.prompt_template = "Extract tasks."
|
||||
config.data_types = ["tasks"]
|
||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||
mgr = _make_manager()
|
||||
|
||||
mock_provider = AsyncMock()
|
||||
mock_provider.fetch_messages = AsyncMock(side_effect=RuntimeError("API quota exceeded"))
|
||||
mock_provider.refreshed_credentials = None
|
||||
|
||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||
patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||
patch("app.integrations.get_provider", return_value=mock_provider), \
|
||||
patch("app.core.agent_runner.async_session"):
|
||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
_, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error"
|
||||
assert any("quota" in e.lower() or "fetch" in e.lower() for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_refreshed_token_persisted():
|
||||
"""When the provider refreshes its token, the new ciphertext is written to DB."""
|
||||
from app.integrations import EmailMessage, encrypt_token
|
||||
from cryptography.fernet import Fernet as _Fernet
|
||||
|
||||
fernet_key = _Fernet.generate_key().decode()
|
||||
credentials = {"token": "old_token", "refresh_token": "rt_old"}
|
||||
fresh_credentials = {"token": "new_token", "refresh_token": "rt_new"}
|
||||
|
||||
config = _make_cloud_config()
|
||||
config.prompt_template = "Extract tasks."
|
||||
config.data_types = ["tasks"]
|
||||
|
||||
with patch("app.integrations.settings") as ms:
|
||||
ms.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||
config.oauth_token_encrypted = encrypt_token(credentials)
|
||||
|
||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||
mgr = _make_manager()
|
||||
|
||||
mock_provider = AsyncMock()
|
||||
mock_provider.fetch_messages = AsyncMock(return_value=[])
|
||||
mock_provider.refreshed_credentials = fresh_credentials # token was refreshed
|
||||
|
||||
# Track DB writes via mock async_session.
|
||||
mock_cfg_row = MagicMock()
|
||||
mock_cfg_row.oauth_token_encrypted = None
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
|
||||
mock_db.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_db.scalar_one_or_none = AsyncMock(return_value=mock_cfg_row)
|
||||
cfg_result = MagicMock()
|
||||
cfg_result.scalar_one_or_none.return_value = mock_cfg_row
|
||||
mock_db.execute = AsyncMock(return_value=cfg_result)
|
||||
mock_db.commit = AsyncMock()
|
||||
|
||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock), \
|
||||
patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||
patch("app.integrations.get_provider", return_value=mock_provider), \
|
||||
patch("app.integrations.encrypt_token", return_value="new_encrypted") as mock_encrypt, \
|
||||
patch("app.core.agent_runner.async_session", return_value=mock_db), \
|
||||
patch("app.integrations.settings") as mock_int_settings:
|
||||
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
# The new encrypted token should have been written to the config row.
|
||||
mock_encrypt.assert_called_once_with(fresh_credentials)
|
||||
assert mock_cfg_row.oauth_token_encrypted == "new_encrypted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finalize_run_updates_cloud_config_last_run_at():
|
||||
"""_finalize_run with config_type='cloud' updates CloudAgentConfig.last_run_at."""
|
||||
from app.core.agent_runner import _finalize_run
|
||||
|
||||
run_log = _make_run_log(str(uuid.uuid4()), agent_type="cloud")
|
||||
run_log.id = str(uuid.uuid4())
|
||||
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.last_run_at = None
|
||||
|
||||
cfg_result = MagicMock()
|
||||
cfg_result.scalar_one_or_none.return_value = mock_cfg
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
|
||||
mock_db.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_db.merge = AsyncMock(return_value=run_log)
|
||||
mock_db.execute = AsyncMock(return_value=cfg_result)
|
||||
mock_db.commit = AsyncMock()
|
||||
|
||||
config_id = str(uuid.uuid4())
|
||||
|
||||
with patch("app.core.agent_runner.async_session", return_value=mock_db):
|
||||
await _finalize_run(
|
||||
run_log,
|
||||
status="success",
|
||||
update_config_last_run=True,
|
||||
config_id=config_id,
|
||||
config_type="cloud",
|
||||
)
|
||||
|
||||
# CloudAgentConfig.last_run_at should have been set.
|
||||
assert mock_cfg.last_run_at is not None
|
||||
mock_db.commit.assert_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
729
tests/test_integrations.py
Normal file
729
tests/test_integrations.py
Normal file
@@ -0,0 +1,729 @@
|
||||
"""Tests for Step 3.6: cloud provider integration clients.
|
||||
|
||||
Coverage:
|
||||
Unit \u2014 app/integrations/__init__.py:
|
||||
- encrypt_token / decrypt_token round-trip
|
||||
- decrypt_token raises ValueError on invalid ciphertext
|
||||
- encrypt_token raises ValueError on empty/non-dict input
|
||||
- _get_fernet raises RuntimeError when OAUTH_ENCRYPTION_KEY not set
|
||||
- get_provider returns GmailClient for 'gmail'
|
||||
- get_provider returns MSGraphClient for 'outlook' and 'teams'
|
||||
- get_provider raises ValueError for unknown provider
|
||||
|
||||
Unit \u2014 app/integrations/gmail.py:
|
||||
- _build_gmail_query with no filter returns empty string
|
||||
- _build_gmail_query with labels builds label: expr
|
||||
- _build_gmail_query with senders builds from: expr
|
||||
- _build_gmail_query with date_range builds after:/before: exprs
|
||||
- _build_gmail_query since overrides date_range.from when more recent
|
||||
- _build_gmail_query date_range.from overrides since when more recent
|
||||
- _parse_body extracts text/plain part
|
||||
- _parse_body extracts text/html part (stripped)
|
||||
- _parse_body recurses into multipart, prefers text/plain
|
||||
- GmailClient.fetch_messages: happy path with mocked service
|
||||
- GmailClient.fetch_messages: no messages returns empty list
|
||||
- GmailClient.fetch_messages: HTTP error on messages.list raises RuntimeError
|
||||
- GmailClient.refreshed_credentials: None when token unchanged
|
||||
- GmailClient.refreshed_credentials: returns dict when token changes
|
||||
|
||||
Unit \u2014 app/integrations/ms_graph.py:
|
||||
- _build_email_filter with no filter returns empty string
|
||||
- _build_email_filter with senders builds OData from clause
|
||||
- _build_email_filter with since builds receivedDateTime ge clause
|
||||
- MSGraphClient.fetch_emails: happy path with mocked httpx
|
||||
- MSGraphClient.fetch_emails: 401 triggers token refresh and retries
|
||||
- MSGraphClient.fetch_messages: happy path with mocked httpx
|
||||
- MSGraphClient.fetch_messages: 403 from getAllMessages degrades gracefully
|
||||
- MSGraphClient.refreshed_credentials: None when token unchanged
|
||||
- MSGraphClient._refresh_access_token: MSAL error raises RuntimeError
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.integrations import (
|
||||
ChatMessage,
|
||||
EmailMessage,
|
||||
decrypt_token,
|
||||
encrypt_token,
|
||||
get_provider,
|
||||
)
|
||||
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
# Helpers
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
|
||||
_FERNET_KEY = "eW91LXNob3VsZC1ub3QtdXNlLXRoaXMta2V5LWluLXByb2Q="
|
||||
# ^ 32-char URL-safe base64 (generated for tests only; not a real Fernet key length,
|
||||
# so we generate a proper one below)
|
||||
|
||||
from cryptography.fernet import Fernet as _Fernet # noqa: E402
|
||||
|
||||
_VALID_KEY = _Fernet.generate_key().decode("utf-8")
|
||||
|
||||
_TOKEN_DICT = {
|
||||
"token": "access_abc",
|
||||
"refresh_token": "refresh_xyz",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"client_id": "client_id_123",
|
||||
"client_secret": "client_secret_456",
|
||||
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
}
|
||||
|
||||
_MS_TOKEN_DICT = {
|
||||
"access_token": "ms_access_abc",
|
||||
"refresh_token": "ms_refresh_xyz",
|
||||
"token_type": "Bearer",
|
||||
"scope": "Mail.Read offline_access",
|
||||
}
|
||||
|
||||
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
# encrypt_token / decrypt_token
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
|
||||
|
||||
class TestTokenEncryption:
|
||||
"""encrypt_token / decrypt_token round-trip tests."""
|
||||
|
||||
def test_round_trip(self):
|
||||
with patch("app.integrations.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
|
||||
encrypted = encrypt_token(_TOKEN_DICT)
|
||||
assert isinstance(encrypted, str)
|
||||
assert encrypted != json.dumps(_TOKEN_DICT) # must be ciphertext, not plaintext
|
||||
recovered = decrypt_token(encrypted)
|
||||
assert recovered == _TOKEN_DICT
|
||||
|
||||
def test_decrypt_invalid_ciphertext_raises_value_error(self):
|
||||
with patch("app.integrations.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
|
||||
with pytest.raises(ValueError, match="Failed to decrypt"):
|
||||
decrypt_token("this-is-not-valid-fernet-ciphertext")
|
||||
|
||||
def test_decrypt_wrong_key_raises_value_error(self):
|
||||
"""Decrypting with a different key must fail with ValueError."""
|
||||
other_key = _Fernet.generate_key().decode("utf-8")
|
||||
with patch("app.integrations.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
|
||||
encrypted = encrypt_token(_TOKEN_DICT)
|
||||
with patch("app.integrations.settings") as mock_settings2:
|
||||
mock_settings2.OAUTH_ENCRYPTION_KEY = other_key
|
||||
with pytest.raises(ValueError, match="Failed to decrypt"):
|
||||
decrypt_token(encrypted)
|
||||
|
||||
def test_encrypt_empty_dict_raises_value_error(self):
|
||||
with patch("app.integrations.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
|
||||
with pytest.raises(ValueError, match="non-empty dict"):
|
||||
encrypt_token({})
|
||||
|
||||
def test_encrypt_non_dict_raises_value_error(self):
|
||||
with patch("app.integrations.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
|
||||
with pytest.raises(ValueError, match="non-empty dict"):
|
||||
encrypt_token("not-a-dict") # type: ignore[arg-type]
|
||||
|
||||
def test_missing_key_raises_runtime_error(self):
|
||||
with patch("app.integrations.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENCRYPTION_KEY = ""
|
||||
with pytest.raises(RuntimeError, match="OAUTH_ENCRYPTION_KEY"):
|
||||
encrypt_token(_TOKEN_DICT)
|
||||
|
||||
def test_email_message_as_text(self):
|
||||
msg = EmailMessage(
|
||||
id="m1",
|
||||
subject="Hello",
|
||||
sender="alice@example.com",
|
||||
body_text="Test body",
|
||||
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
text = msg.as_text
|
||||
assert "From: alice@example.com" in text
|
||||
assert "Subject: Hello" in text
|
||||
assert "Test body" in text
|
||||
|
||||
def test_chat_message_as_text(self):
|
||||
msg = ChatMessage(
|
||||
id="c1",
|
||||
content="Buy milk",
|
||||
sender="bob",
|
||||
channel="general",
|
||||
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
text = msg.as_text
|
||||
assert "From: bob" in text
|
||||
assert "channel: general" in text
|
||||
assert "Buy milk" in text
|
||||
|
||||
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
# get_provider factory
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
|
||||
|
||||
class TestGetProvider:
|
||||
def test_gmail_returns_gmail_client(self):
|
||||
from app.integrations.gmail import GmailClient
|
||||
|
||||
client = get_provider("gmail", _TOKEN_DICT)
|
||||
assert isinstance(client, GmailClient)
|
||||
|
||||
def test_outlook_returns_ms_graph_client(self):
|
||||
from app.integrations.ms_graph import MSGraphClient
|
||||
|
||||
client = get_provider("outlook", _MS_TOKEN_DICT)
|
||||
assert isinstance(client, MSGraphClient)
|
||||
|
||||
def test_teams_returns_ms_graph_client(self):
|
||||
from app.integrations.ms_graph import MSGraphClient
|
||||
|
||||
client = get_provider("teams", _MS_TOKEN_DICT)
|
||||
assert isinstance(client, MSGraphClient)
|
||||
|
||||
def test_unknown_provider_raises_value_error(self):
|
||||
with pytest.raises(ValueError, match="Unknown cloud provider"):
|
||||
get_provider("slack", {})
|
||||
|
||||
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
# Gmail client \u2014 query builder
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
|
||||
|
||||
class TestBuildGmailQuery:
|
||||
"""Unit tests for gmail._build_gmail_query."""
|
||||
|
||||
def setup_method(self):
|
||||
from app.integrations.gmail import _build_gmail_query
|
||||
self._fn = _build_gmail_query
|
||||
|
||||
def test_empty_returns_empty_string(self):
|
||||
assert self._fn(None, None) == ""
|
||||
|
||||
def test_single_label(self):
|
||||
q = self._fn({"labels": ["INBOX"]}, None)
|
||||
assert "label:INBOX" in q
|
||||
|
||||
def test_multiple_labels_joined_with_or(self):
|
||||
q = self._fn({"labels": ["INBOX", "work"]}, None)
|
||||
assert "label:INBOX OR label:work" in q
|
||||
|
||||
def test_senders(self):
|
||||
q = self._fn({"senders": ["alice@example.com"]}, None)
|
||||
assert "from:alice@example.com" in q
|
||||
|
||||
def test_date_range_from(self):
|
||||
q = self._fn({"date_range": {"from": "2025-01-15"}}, None)
|
||||
assert "after:2025/01/15" in q
|
||||
|
||||
def test_date_range_to(self):
|
||||
q = self._fn({"date_range": {"to": "2025-03-01"}}, None)
|
||||
assert "before:2025/03/01" in q
|
||||
|
||||
def test_since_overrides_earlier_date_range_from(self):
|
||||
"""since=Feb is more recent than date_range.from=Jan, so after: should be Feb."""
|
||||
since = datetime(2025, 2, 1, tzinfo=timezone.utc)
|
||||
q = self._fn({"date_range": {"from": "2025-01-01"}}, since)
|
||||
assert "after:2025/02/01" in q
|
||||
assert "after:2025/01/01" not in q
|
||||
|
||||
def test_date_range_from_overrides_earlier_since(self):
|
||||
"""date_range.from=Feb is more recent than since=Jan, so after: should be Feb."""
|
||||
since = datetime(2025, 1, 1, tzinfo=timezone.utc)
|
||||
q = self._fn({"date_range": {"from": "2025-02-01"}}, since)
|
||||
assert "after:2025/02/01" in q
|
||||
|
||||
def test_invalid_date_ignored(self):
|
||||
"""An invalid date string in filter_config must not raise, just be skipped."""
|
||||
q = self._fn({"date_range": {"from": "not-a-date"}}, None)
|
||||
assert "after:" not in q
|
||||
|
||||
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
# Gmail client \u2014 body parsing
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
|
||||
|
||||
class TestParseBody:
|
||||
"""Unit tests for gmail._parse_body."""
|
||||
|
||||
def setup_method(self):
|
||||
from app.integrations.gmail import _parse_body
|
||||
self._fn = _parse_body
|
||||
|
||||
def _encode(self, text: str) -> str:
|
||||
import base64
|
||||
return base64.urlsafe_b64encode(text.encode()).decode()
|
||||
|
||||
def test_text_plain_extracted(self):
|
||||
payload = {
|
||||
"mimeType": "text/plain",
|
||||
"body": {"data": self._encode("Hello world")},
|
||||
}
|
||||
assert self._fn(payload) == "Hello world"
|
||||
|
||||
def test_text_html_stripped(self):
|
||||
payload = {
|
||||
"mimeType": "text/html",
|
||||
"body": {"data": self._encode("<p>Hello <b>world</b></p>")},
|
||||
}
|
||||
result = self._fn(payload)
|
||||
assert "Hello" in result
|
||||
assert "<p>" not in result
|
||||
|
||||
def test_multipart_prefers_plain_over_html(self):
|
||||
plain_data = self._encode("Plain text")
|
||||
html_data = self._encode("<p>HTML text</p>")
|
||||
payload = {
|
||||
"mimeType": "multipart/alternative",
|
||||
"body": {},
|
||||
"parts": [
|
||||
{"mimeType": "text/html", "body": {"data": html_data}},
|
||||
{"mimeType": "text/plain", "body": {"data": plain_data}},
|
||||
],
|
||||
}
|
||||
result = self._fn(payload)
|
||||
assert result == "Plain text"
|
||||
|
||||
def test_empty_payload_returns_empty_string(self):
|
||||
assert self._fn({}) == ""
|
||||
|
||||
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
# GmailClient.fetch_messages
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
|
||||
|
||||
def _make_gmail_message(
|
||||
msg_id: str = "msg001",
|
||||
subject: str = "Test email",
|
||||
sender: str = "alice@example.com",
|
||||
body_text: str = "Hello world",
|
||||
date: str = "Mon, 01 Jan 2025 10:00:00 +0000",
|
||||
) -> dict:
|
||||
"""Build a minimal Gmail API message response dict."""
|
||||
import base64
|
||||
body_data = base64.urlsafe_b64encode(body_text.encode()).decode()
|
||||
return {
|
||||
"id": msg_id,
|
||||
"labelIds": ["INBOX"],
|
||||
"payload": {
|
||||
"mimeType": "text/plain",
|
||||
"headers": [
|
||||
{"name": "Subject", "value": subject},
|
||||
{"name": "From", "value": sender},
|
||||
{"name": "Date", "value": date},
|
||||
],
|
||||
"body": {"data": body_data},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestGmailClientFetchMessages:
|
||||
"""GmailClient.fetch_messages tests with mocked Google API."""
|
||||
|
||||
def _make_client(self) -> "GmailClient":
|
||||
from app.integrations.gmail import GmailClient
|
||||
return GmailClient(_TOKEN_DICT)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_happy_path_returns_email_messages(self):
|
||||
client = self._make_client()
|
||||
msg = _make_gmail_message()
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_users = mock_service.users.return_value
|
||||
mock_messages = mock_users.messages.return_value
|
||||
mock_messages.list.return_value.execute.return_value = {
|
||||
"messages": [{"id": "msg001"}]
|
||||
}
|
||||
mock_messages.get.return_value.execute.return_value = msg
|
||||
|
||||
with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread:
|
||||
# Simulate to_thread running the sync function and returning results.
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
mock_thread.side_effect = fake_to_thread
|
||||
|
||||
with patch("googleapiclient.discovery.build", return_value=mock_service), \
|
||||
patch("google.auth.transport.requests.Request"), \
|
||||
patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
|
||||
results = await client.fetch_messages()
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].subject == "Test email"
|
||||
assert results[0].sender == "alice@example.com"
|
||||
assert results[0].body_text == "Hello world"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_messages_returns_empty_list(self):
|
||||
client = self._make_client()
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_users = mock_service.users.return_value
|
||||
mock_messages = mock_users.messages.return_value
|
||||
mock_messages.list.return_value.execute.return_value = {"messages": []}
|
||||
|
||||
with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread:
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
mock_thread.side_effect = fake_to_thread
|
||||
|
||||
with patch("googleapiclient.discovery.build", return_value=mock_service), \
|
||||
patch("google.auth.transport.requests.Request"), \
|
||||
patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
|
||||
results = await client.fetch_messages()
|
||||
|
||||
assert results == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_http_error_raises_runtime_error(self):
|
||||
import googleapiclient.errors
|
||||
client = self._make_client()
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_users = mock_service.users.return_value
|
||||
mock_messages = mock_users.messages.return_value
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status = 403
|
||||
mock_resp.reason = "Forbidden"
|
||||
mock_messages.list.return_value.execute.side_effect = (
|
||||
googleapiclient.errors.HttpError(mock_resp, b"Forbidden")
|
||||
)
|
||||
|
||||
with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread:
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
mock_thread.side_effect = fake_to_thread
|
||||
|
||||
with patch("googleapiclient.discovery.build", return_value=mock_service), \
|
||||
patch("google.auth.transport.requests.Request"), \
|
||||
patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
|
||||
with pytest.raises(RuntimeError, match="Gmail messages.list failed"):
|
||||
await client.fetch_messages()
|
||||
|
||||
def test_refreshed_credentials_none_when_unchanged(self):
|
||||
client = self._make_client()
|
||||
# Token unchanged — should return None.
|
||||
assert client.refreshed_credentials is None
|
||||
|
||||
def test_refreshed_credentials_returns_dict_when_token_changes(self):
|
||||
client = self._make_client()
|
||||
# Simulate a token refresh by changing the access token on the credentials object.
|
||||
client._credentials.token = "new_access_token_xyz"
|
||||
refreshed = client.refreshed_credentials
|
||||
assert refreshed is not None
|
||||
assert refreshed["token"] == "new_access_token_xyz"
|
||||
|
||||
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
# MS Graph client \u2014 email filter builder
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
|
||||
|
||||
class TestBuildEmailFilter:
|
||||
"""Unit tests for ms_graph._build_email_filter."""
|
||||
|
||||
def setup_method(self):
|
||||
from app.integrations.ms_graph import _build_email_filter
|
||||
self._fn = _build_email_filter
|
||||
|
||||
def test_empty_returns_empty_string(self):
|
||||
assert self._fn(None, None) == ""
|
||||
|
||||
def test_single_sender(self):
|
||||
result = self._fn({"senders": ["alice@example.com"]}, None)
|
||||
assert "from/emailAddress/address eq 'alice@example.com'" in result
|
||||
|
||||
def test_multiple_senders_joined_with_or(self):
|
||||
result = self._fn({"senders": ["a@x.com", "b@x.com"]}, None)
|
||||
assert " or " in result
|
||||
assert "a@x.com" in result
|
||||
assert "b@x.com" in result
|
||||
|
||||
def test_since_adds_received_date_ge_clause(self):
|
||||
since = datetime(2025, 3, 1, tzinfo=timezone.utc)
|
||||
result = self._fn(None, since)
|
||||
assert "receivedDateTime ge 2025-03-01T00:00:00Z" in result
|
||||
|
||||
def test_date_range_to_adds_received_date_le_clause(self):
|
||||
result = self._fn({"date_range": {"to": "2025-06-30"}}, None)
|
||||
assert "receivedDateTime le" in result
|
||||
|
||||
def test_since_overrides_earlier_date_range_from(self):
|
||||
since = datetime(2025, 2, 1, tzinfo=timezone.utc)
|
||||
result = self._fn({"date_range": {"from": "2025-01-01"}}, since)
|
||||
assert "2025-02-01T00:00:00Z" in result
|
||||
assert "2025-01-01" not in result
|
||||
|
||||
def test_invalid_date_ignored(self):
|
||||
result = self._fn({"date_range": {"from": "bad-date"}}, None)
|
||||
assert "receivedDateTime" not in result
|
||||
|
||||
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
# MSGraphClient.fetch_emails
|
||||
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||
|
||||
|
||||
def _make_graph_email(
|
||||
msg_id: str = "email001",
|
||||
subject: str = "Meeting tomorrow",
|
||||
sender_address: str = "boss@company.com",
|
||||
body_content: str = "Please prepare the report.",
|
||||
received: str = "2025-06-01T10:00:00Z",
|
||||
) -> dict:
|
||||
"""Build a minimal MS Graph message item dict."""
|
||||
return {
|
||||
"id": msg_id,
|
||||
"subject": subject,
|
||||
"from": {"emailAddress": {"address": sender_address}},
|
||||
"receivedDateTime": received,
|
||||
"body": {"contentType": "text", "content": body_content},
|
||||
"bodyPreview": body_content[:100],
|
||||
}
|
||||
|
||||
|
||||
def _make_graph_teams_message(
|
||||
msg_id: str = "teams001",
|
||||
content: str = "Stand-up at 9am",
|
||||
sender: str = "alice",
|
||||
channel_id: str = "chan001",
|
||||
created: str = "2025-06-01T08:00:00Z",
|
||||
) -> dict:
|
||||
return {
|
||||
"id": msg_id,
|
||||
"body": {"contentType": "text", "content": content},
|
||||
"from": {"user": {"displayName": sender}},
|
||||
"channelIdentity": {"channelId": channel_id},
|
||||
"createdDateTime": created,
|
||||
}
|
||||
|
||||
|
||||
class TestMSGraphClientFetchEmails:
|
||||
"""MSGraphClient.fetch_emails tests with mocked httpx."""
|
||||
|
||||
def _make_client(self) -> "MSGraphClient":
|
||||
from app.integrations.ms_graph import MSGraphClient
|
||||
return MSGraphClient(_MS_TOKEN_DICT)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_happy_path_returns_email_messages(self):
|
||||
client = self._make_client()
|
||||
graph_email = _make_graph_email()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"value": [graph_email]}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_http = AsyncMock()
|
||||
mock_http.get = AsyncMock(return_value=mock_response)
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
results = await client.fetch_emails()
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].subject == "Meeting tomorrow"
|
||||
assert results[0].sender == "boss@company.com"
|
||||
assert results[0].body_text == "Please prepare the report."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pagination_stops_at_max_emails(self):
|
||||
"""No nextLink in first page \u2014 only one batch returned."""
|
||||
client = self._make_client()
|
||||
emails_batch = [_make_graph_email(msg_id=str(i)) for i in range(3)]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"value": emails_batch} # no @odata.nextLink
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_http = AsyncMock()
|
||||
mock_http.get = AsyncMock(return_value=mock_response)
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
results = await client.fetch_emails()
|
||||
|
||||
assert len(results) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_401_triggers_token_refresh_and_retries(self):
|
||||
"""On first 401, token refresh is attempted and the request retried."""
|
||||
from app.integrations.ms_graph import MSGraphClient
|
||||
client = MSGraphClient(_MS_TOKEN_DICT)
|
||||
|
||||
graph_email = _make_graph_email()
|
||||
|
||||
response_401 = MagicMock()
|
||||
response_401.status_code = 401
|
||||
|
||||
response_200 = MagicMock()
|
||||
response_200.status_code = 200
|
||||
response_200.json.return_value = {"value": [graph_email]}
|
||||
response_200.raise_for_status = MagicMock()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def fake_get(url, params=None, headers=None):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return response_401
|
||||
return response_200
|
||||
|
||||
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls, \
|
||||
patch.object(client, "_refresh_access_token", new_callable=AsyncMock) as mock_refresh:
|
||||
mock_http = AsyncMock()
|
||||
mock_http.get = fake_get
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
results = await client.fetch_emails()
|
||||
|
||||
mock_refresh.assert_called_once()
|
||||
assert len(results) == 1
|
||||
|
||||
def test_refreshed_credentials_none_when_token_unchanged(self):
|
||||
client = self._make_client()
|
||||
assert client.refreshed_credentials is None
|
||||
|
||||
def test_refreshed_credentials_returns_dict_when_token_changes(self):
|
||||
client = self._make_client()
|
||||
client._access_token = "new_token_abc"
|
||||
assert client.refreshed_credentials is not None
|
||||
assert client.refreshed_credentials["access_token"] == "new_token_abc"
|
||||
|
||||
|
||||
class TestMSGraphClientFetchMessages:
|
||||
"""MSGraphClient.fetch_messages (Teams) tests."""
|
||||
|
||||
def _make_client(self) -> "MSGraphClient":
|
||||
from app.integrations.ms_graph import MSGraphClient
|
||||
return MSGraphClient(_MS_TOKEN_DICT)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_happy_path_returns_chat_messages(self):
|
||||
client = self._make_client()
|
||||
teams_msg = _make_graph_teams_message()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"value": [teams_msg]}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_http = AsyncMock()
|
||||
mock_http.get = AsyncMock(return_value=mock_response)
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
results = await client.fetch_messages()
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].content == "Stand-up at 9am"
|
||||
assert results[0].sender == "alice"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_403_degrades_gracefully(self):
|
||||
"""getAllMessages returning 403 (license issue) returns empty list, no exception."""
|
||||
import httpx as _httpx
|
||||
|
||||
client = self._make_client()
|
||||
|
||||
error_response = MagicMock()
|
||||
error_response.status_code = 403
|
||||
http_error = _httpx.HTTPStatusError(
|
||||
"Forbidden", request=MagicMock(), response=error_response
|
||||
)
|
||||
|
||||
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_http = AsyncMock()
|
||||
mock_http.get = AsyncMock(side_effect=http_error)
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
results = await client.fetch_messages()
|
||||
|
||||
assert results == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_filter_applied(self):
|
||||
"""Messages from non-matching channels are filtered out."""
|
||||
client = self._make_client()
|
||||
matching = _make_graph_teams_message(channel_id="dev-channel", content="Deploy today")
|
||||
non_matching = _make_graph_teams_message(msg_id="t2", channel_id="random", content="Lunch?")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"value": [matching, non_matching]}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_http = AsyncMock()
|
||||
mock_http.get = AsyncMock(return_value=mock_response)
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
results = await client.fetch_messages(
|
||||
filter_config={"channels": ["dev-channel"]}
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].content == "Deploy today"
|
||||
|
||||
|
||||
class TestMSGraphClientRefreshToken:
|
||||
"""MSGraphClient._refresh_access_token with mocked MSAL."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_msal_error_raises_runtime_error(self):
|
||||
from app.integrations.ms_graph import MSGraphClient
|
||||
client = MSGraphClient({**_MS_TOKEN_DICT, "refresh_token": "rt_test"})
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.acquire_token_by_refresh_token.return_value = {
|
||||
"error": "invalid_grant",
|
||||
"error_description": "Refresh token expired",
|
||||
}
|
||||
|
||||
with patch("msal.ConfidentialClientApplication", return_value=mock_app), \
|
||||
patch("app.integrations.ms_graph.settings") as mock_settings:
|
||||
mock_settings.MS_CLIENT_ID = "client_id"
|
||||
mock_settings.MS_CLIENT_SECRET = "secret"
|
||||
mock_settings.MS_TENANT_ID = "common"
|
||||
with pytest.raises(RuntimeError, match="MS Graph token refresh failed"):
|
||||
await client._refresh_access_token()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_refresh_updates_access_token(self):
|
||||
from app.integrations.ms_graph import MSGraphClient
|
||||
client = MSGraphClient({**_MS_TOKEN_DICT, "refresh_token": "rt_old"})
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.acquire_token_by_refresh_token.return_value = {
|
||||
"access_token": "new_access_token",
|
||||
"refresh_token": "new_refresh_token",
|
||||
}
|
||||
|
||||
with patch("msal.ConfidentialClientApplication", return_value=mock_app), \
|
||||
patch("app.integrations.ms_graph.settings") as mock_settings:
|
||||
mock_settings.MS_CLIENT_ID = "client_id"
|
||||
mock_settings.MS_CLIENT_SECRET = "secret"
|
||||
mock_settings.MS_TENANT_ID = "common"
|
||||
await client._refresh_access_token()
|
||||
|
||||
assert client._access_token == "new_access_token"
|
||||
assert client._refresh_token == "new_refresh_token"
|
||||
Reference in New Issue
Block a user