feat(step-3.6): cloud provider integrations (Gmail, Outlook, Teams)
- Add app/integrations/__init__.py: Fernet token encryption helpers, EmailMessage/ChatMessage dataclasses, get_provider() factory - Add app/integrations/gmail.py: GmailClient with async fetch_messages(), token refresh, configurable label/sender/date filters - Add app/integrations/ms_graph.py: MSGraphClient with fetch_emails() (Outlook) and fetch_messages() (Teams), MSAL token refresh, OData filters - Update app/core/agent_runner.py: replace run_cloud_agent() stub with full 8-step implementation; extend _finalize_run() for cloud config type - Update app/config/settings.py: add OAuth + Fernet encryption settings - Update requirements.txt: google-api-python-client, google-auth-*, msal, cryptography - Add tests/test_integrations.py: 47 tests covering all integration code - Update tests/test_agent_runner.py: replace stub test with 7 real tests All 76 new/updated tests pass.
This commit is contained in:
@@ -437,21 +437,21 @@ Cloud Agent:
|
|||||||
- **Outcome:** Users configure AI prompts through guided conversation. Journey can refine an existing config when `agent_id` is provided. ✅
|
- **Outcome:** Users configure AI prompts through guided conversation. Journey can refine an existing config when `agent_id` is provided. ✅
|
||||||
|
|
||||||
### Step 3.6 — Cloud provider integrations
|
### Step 3.6 — Cloud provider integrations
|
||||||
- [ ] Create `app/integrations/gmail.py`:
|
- [x] Create `app/integrations/gmail.py`:
|
||||||
- `GmailClient`:
|
- `GmailClient`:
|
||||||
- `__init__(oauth_token)` — initializes Google API client
|
- `__init__(oauth_token)` — initializes Google API client
|
||||||
- `async fetch_messages(filter_config, since: datetime) -> list[EmailMessage]`
|
- `async fetch_messages(filter_config, since: datetime) -> list[EmailMessage]`
|
||||||
- `EmailMessage`: `{ id, subject, sender, body_text, date, labels }`
|
- `EmailMessage`: `{ id, subject, sender, body_text, date, labels }`
|
||||||
- Handles token refresh via Google OAuth2 refresh flow
|
- Handles token refresh via Google OAuth2 refresh flow
|
||||||
- Respects `filter_config.labels`, `filter_config.date_range`, `filter_config.senders`
|
- Respects `filter_config.labels`, `filter_config.date_range`, `filter_config.senders`
|
||||||
- [ ] Create `app/integrations/ms_graph.py`:
|
- [x] Create `app/integrations/ms_graph.py`:
|
||||||
- `MSGraphClient`:
|
- `MSGraphClient`:
|
||||||
- `__init__(oauth_token)` — initializes MS Graph client
|
- `__init__(oauth_token)` — initializes MS Graph client
|
||||||
- `async fetch_emails(filter_config, since: datetime) -> list[EmailMessage]` (Outlook)
|
- `async fetch_emails(filter_config, since: datetime) -> list[EmailMessage]` (Outlook)
|
||||||
- `async fetch_messages(filter_config, since: datetime) -> list[ChatMessage]` (Teams)
|
- `async fetch_messages(filter_config, since: datetime) -> list[ChatMessage]` (Teams)
|
||||||
- `ChatMessage`: `{ id, content, sender, channel, date }`
|
- `ChatMessage`: `{ id, content, sender, channel, date }`
|
||||||
- Handles token refresh via MSAL
|
- Handles token refresh via MSAL
|
||||||
- [ ] Create `app/integrations/__init__.py` — factory: `get_provider(provider_name) -> GmailClient | MSGraphClient`
|
- [x] Create `app/integrations/__init__.py` — factory: `get_provider(provider_name) -> GmailClient | MSGraphClient`
|
||||||
- **Dependencies:** `google-api-python-client`, `google-auth-oauthlib`, `msgraph-sdk`, `msal`
|
- **Dependencies:** `google-api-python-client`, `google-auth-oauthlib`, `msgraph-sdk`, `msal`
|
||||||
- **Files:** `app/integrations/gmail.py`, `app/integrations/ms_graph.py`, `app/integrations/__init__.py`
|
- **Files:** `app/integrations/gmail.py`, `app/integrations/ms_graph.py`, `app/integrations/__init__.py`
|
||||||
- **Outcome:** Backend can fetch emails/messages from Gmail, Outlook, and Teams.
|
- **Outcome:** Backend can fetch emails/messages from Gmail, Outlook, and Teams.
|
||||||
|
|||||||
@@ -29,6 +29,25 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
LLM_MODEL: str = "gpt-4o"
|
LLM_MODEL: str = "gpt-4o"
|
||||||
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
|
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
|
||||||
|
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||||
|
|
||||||
|
# GitHub Copilot OAuth token storage directory.
|
||||||
|
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||||
|
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
||||||
|
GITHUB_COPILOT_TOKEN_DIR: str = ""
|
||||||
|
|
||||||
|
# OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows.
|
||||||
|
GMAIL_CLIENT_ID: str = ""
|
||||||
|
GMAIL_CLIENT_SECRET: str = ""
|
||||||
|
MS_CLIENT_ID: str = ""
|
||||||
|
MS_CLIENT_SECRET: str = ""
|
||||||
|
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
||||||
|
MS_TENANT_ID: str = "common"
|
||||||
|
|
||||||
|
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
||||||
|
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
||||||
|
# Generate with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||||
|
OAUTH_ENCRYPTION_KEY: str = ""
|
||||||
|
|
||||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from croniter import croniter
|
from croniter import croniter
|
||||||
@@ -383,7 +383,10 @@ async def run_local_agent(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud agent runner (stub) ───────────────────────────────────────────────
|
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Default lookback window when an agent has never run before.
|
||||||
|
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
|
||||||
|
|
||||||
|
|
||||||
async def run_cloud_agent(
|
async def run_cloud_agent(
|
||||||
@@ -392,26 +395,199 @@ async def run_cloud_agent(
|
|||||||
run_log: AgentRunLog,
|
run_log: AgentRunLog,
|
||||||
device_mgr: DeviceConnectionManager,
|
device_mgr: DeviceConnectionManager,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute a cloud connector agent run.
|
"""Execute a cloud connector agent run end-to-end.
|
||||||
|
|
||||||
.. note::
|
Steps:
|
||||||
This is a **stub** — provider integrations (Gmail, Teams, Outlook)
|
|
||||||
are implemented in Step 3.6. The run is immediately marked as an
|
1. Verify the user's device is online — results are pushed to Electron
|
||||||
error with an informative message.
|
via WS tool-call frames. If no device is connected, abort.
|
||||||
|
2. Decrypt the stored OAuth token from ``config.oauth_token_encrypted``.
|
||||||
|
3. Instantiate the provider client (Gmail or MS Graph).
|
||||||
|
4. Fetch messages/emails since ``config.last_run_at`` (or 7 days ago for
|
||||||
|
the first run) applying ``config.filter_config`` filters.
|
||||||
|
5. For each message/email call ``_extract_items_from_content`` with
|
||||||
|
``config.prompt_template`` to get structured ``{table, data}`` items.
|
||||||
|
6. Push each item to Electron as an ``insert`` tool-call.
|
||||||
|
7. If the provider refreshed its access token, re-encrypt and write it
|
||||||
|
back to ``config.oauth_token_encrypted``.
|
||||||
|
8. Persist the run outcome via ``_finalize_run``.
|
||||||
"""
|
"""
|
||||||
|
run_id = run_log.id
|
||||||
|
|
||||||
|
# ── 1. Device online check ─────────────────────────────────────────
|
||||||
|
if not device_mgr.is_online(user_id):
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: skip cloud run=%s — no device online for user=%s",
|
||||||
|
run_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=["No connected device — cloud agent results cannot be delivered"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 2. Decrypt OAuth token ─────────────────────────────────────────
|
||||||
|
from app.integrations import decrypt_token, encrypt_token, get_provider
|
||||||
|
|
||||||
|
if not config.oauth_token_encrypted:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
credentials_info = decrypt_token(config.oauth_token_encrypted)
|
||||||
|
except ValueError as exc:
|
||||||
|
logger.error("agent_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Failed to decrypt OAuth token: {exc}"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 3. Instantiate provider client ────────────────────────────────
|
||||||
|
try:
|
||||||
|
provider = get_provider(config.provider, credentials_info)
|
||||||
|
except ValueError as exc:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[str(exc)],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 4. Fetch messages ─────────────────────────────────────────────
|
||||||
|
since: datetime | None = config.last_run_at
|
||||||
|
if since is None:
|
||||||
|
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
|
||||||
|
if since.tzinfo is None:
|
||||||
|
since = since.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
errors: list[str] = []
|
||||||
|
items_processed = 0
|
||||||
|
items_created = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if config.provider == "gmail":
|
||||||
|
raw_messages = await provider.fetch_messages( # type: ignore[union-attr]
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "outlook":
|
||||||
|
raw_messages = await provider.fetch_emails( # type: ignore[union-attr]
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "teams":
|
||||||
|
raw_messages = await provider.fetch_messages( # type: ignore[union-attr]
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raw_messages = []
|
||||||
|
except RuntimeError as exc:
|
||||||
|
logger.error(
|
||||||
|
"agent_runner: provider fetch failed for cloud agent %s: %s",
|
||||||
|
config.id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Provider fetch failed: {exc}"],
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: cloud agent %s (provider=%s) for user=%s — pending Step 3.6",
|
"agent_runner: cloud agent %s fetched %d item(s) from %s for user=%s",
|
||||||
config.id,
|
config.id,
|
||||||
|
len(raw_messages),
|
||||||
config.provider,
|
config.provider,
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ── 5–6. Extract + insert ─────────────────────────────────────────
|
||||||
|
for msg in raw_messages:
|
||||||
|
content_text = msg.as_text
|
||||||
|
if not content_text:
|
||||||
|
continue
|
||||||
|
items_processed += 1
|
||||||
|
try:
|
||||||
|
extracted = await _extract_items_from_content(
|
||||||
|
config.prompt_template, content_text, config.data_types
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"LLM extraction error for message {msg.id!r}: {exc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for item in extracted:
|
||||||
|
try:
|
||||||
|
result = await _send_insert_to_client(
|
||||||
|
user_id, item["table"], item["data"], device_mgr
|
||||||
|
)
|
||||||
|
if result.get("error"):
|
||||||
|
errors.append(
|
||||||
|
f"Insert failed ({item['table']}, msg={msg.id!r}): {result['error']}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
items_created += 1
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
errors.append(
|
||||||
|
f"Timed out awaiting insert ack ({item['table']}, msg={msg.id!r})"
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
errors.append(f"Insert error ({item['table']}, msg={msg.id!r}): {exc}")
|
||||||
|
|
||||||
|
# ── 7. Persist refreshed token (if any) ───────────────────────────
|
||||||
|
refreshed = getattr(provider, "refreshed_credentials", None)
|
||||||
|
if refreshed:
|
||||||
|
try:
|
||||||
|
new_encrypted = encrypt_token(refreshed)
|
||||||
|
async with async_session() as db:
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
|
||||||
|
)
|
||||||
|
cfg_row = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg_row:
|
||||||
|
cfg_row.oauth_token_encrypted = new_encrypted
|
||||||
|
await db.commit()
|
||||||
|
logger.debug("agent_runner: refreshed OAuth token persisted for agent %s", config.id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to persist refreshed token for agent %s: %s", config.id, exc)
|
||||||
|
|
||||||
|
# ── 8. Finalise ────────────────────────────────────────────────────
|
||||||
|
if errors and items_created == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
await _finalize_run(
|
await _finalize_run(
|
||||||
run_log,
|
run_log,
|
||||||
status="error",
|
status=final_status,
|
||||||
errors=[
|
items_processed=items_processed,
|
||||||
f"Cloud provider integrations for '{config.provider}' are not yet "
|
items_created=items_created,
|
||||||
"implemented. This feature arrives in Step 3.6."
|
errors=errors,
|
||||||
],
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d",
|
||||||
|
run_id,
|
||||||
|
final_status,
|
||||||
|
items_processed,
|
||||||
|
items_created,
|
||||||
|
len(errors),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -519,13 +695,21 @@ async def _finalize_run(
|
|||||||
managed.errors = errors or []
|
managed.errors = errors or []
|
||||||
managed.completed_at = now
|
managed.completed_at = now
|
||||||
|
|
||||||
if update_config_last_run and config_id and config_type == "local":
|
if update_config_last_run and config_id:
|
||||||
cfg_result = await db.execute(
|
if config_type == "local":
|
||||||
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
cfg_result = await db.execute(
|
||||||
)
|
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
||||||
cfg = cfg_result.scalar_one_or_none()
|
)
|
||||||
if cfg:
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
cfg.last_run_at = now
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
elif config_type == "cloud":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|||||||
@@ -17,7 +17,10 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
import litellm
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from litellm import get_supported_openai_params # noqa: F401 – validates install
|
from litellm import get_supported_openai_params # noqa: F401 – validates install
|
||||||
@@ -31,6 +34,10 @@ def _api_key_for_model(model: str) -> str | None:
|
|||||||
return settings.ANTHROPIC_API_KEY or None
|
return settings.ANTHROPIC_API_KEY or None
|
||||||
if model.startswith("gemini/") or model.startswith("google/"):
|
if model.startswith("gemini/") or model.startswith("google/"):
|
||||||
return settings.GOOGLE_API_KEY or None
|
return settings.GOOGLE_API_KEY or None
|
||||||
|
if model.startswith("github_copilot/"):
|
||||||
|
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
|
||||||
|
# No API key is required; returning None lets LiteLLM handle auth.
|
||||||
|
return None
|
||||||
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
|
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
|
||||||
return settings.OPENAI_API_KEY or None
|
return settings.OPENAI_API_KEY or None
|
||||||
|
|
||||||
@@ -55,6 +62,11 @@ def get_llm(
|
|||||||
Sampling temperature. ``0`` = deterministic.
|
Sampling temperature. ``0`` = deterministic.
|
||||||
"""
|
"""
|
||||||
model = model or settings.LLM_MODEL
|
model = model or settings.LLM_MODEL
|
||||||
|
|
||||||
|
# Point LiteLLM to the custom token directory when configured.
|
||||||
|
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||||
|
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||||
|
|
||||||
return ChatOpenAI(
|
return ChatOpenAI(
|
||||||
model=model,
|
model=model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -71,10 +83,22 @@ def get_router_llm(
|
|||||||
|
|
||||||
|
|
||||||
async def embed(text: str) -> list[float]:
|
async def embed(text: str) -> list[float]:
|
||||||
"""Return a 1536-dim embedding vector for *text* using text-embedding-3-small."""
|
"""Return an embedding vector for *text*.
|
||||||
|
|
||||||
|
Uses ``settings.LLM_EMBED_MODEL`` so the same provider switch in ``.env``
|
||||||
|
(e.g. ``github_copilot/text-embedding-3-small``) applies here without any
|
||||||
|
code changes. Falls back to the raw AsyncOpenAI client for plain OpenAI
|
||||||
|
model names to preserve existing behaviour.
|
||||||
|
"""
|
||||||
|
model = settings.LLM_EMBED_MODEL
|
||||||
|
|
||||||
|
if model.startswith("github_copilot/") or "/" in model:
|
||||||
|
# Use LiteLLM for all provider-prefixed models (Copilot, Bedrock, etc.)
|
||||||
|
# so the provider's auth mechanism is applied correctly.
|
||||||
|
response = await litellm.aembedding(model=model, input=[text])
|
||||||
|
return response.data[0]["embedding"]
|
||||||
|
|
||||||
|
# Plain OpenAI model name — use the raw AsyncOpenAI client (existing path).
|
||||||
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||||
response = await client.embeddings.create(
|
response = await client.embeddings.create(model=model, input=text)
|
||||||
model="text-embedding-3-small",
|
|
||||||
input=text,
|
|
||||||
)
|
|
||||||
return response.data[0].embedding
|
return response.data[0].embedding
|
||||||
|
|||||||
164
app/integrations/__init__.py
Normal file
164
app/integrations/__init__.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""Cloud provider integration utilities.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
* Shared message dataclasses (``EmailMessage``, ``ChatMessage``) used by
|
||||||
|
both the Gmail and MS Graph clients and consumed by ``agent_runner``.
|
||||||
|
* ``get_provider()`` — factory that returns the correct client given a
|
||||||
|
provider name and decrypted OAuth credentials dict.
|
||||||
|
* ``encrypt_token()`` / ``decrypt_token()`` — Fernet-based at-rest
|
||||||
|
encryption for OAuth tokens stored in ``cloud_agent_configs``.
|
||||||
|
|
||||||
|
Encryption rationale
|
||||||
|
--------------------
|
||||||
|
Unlike user content (which is E2E-encrypted client-side and **never**
|
||||||
|
decrypted server-side), OAuth tokens *must* be decrypted server-side
|
||||||
|
because the backend makes provider API calls on behalf of the user.
|
||||||
|
The Fernet key lives solely in ``OAUTH_ENCRYPTION_KEY`` env var — it
|
||||||
|
is never returned to clients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Shared message types ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class EmailMessage:
    """One email message fetched from Gmail or Outlook."""

    # Field order is part of the positional constructor signature.
    id: str
    subject: str
    sender: str
    body_text: str
    date: datetime
    labels: list[str] = field(default_factory=list)

    @property
    def as_text(self) -> str:
        """Render the message as header lines followed by the body.

        The result is the text handed to LLM extraction: ``From``, ``Date``
        (with the labels appended in brackets when any are set) and
        ``Subject`` lines, a blank line, then ``body_text``.
        """
        date_line = f"Date: {self.date:%Y-%m-%d %H:%M}"
        if self.labels:
            date_line += f" [{', '.join(self.labels)}]"
        header = "\n".join(
            (f"From: {self.sender}", date_line, f"Subject: {self.subject}")
        )
        return f"{header}\n\n{self.body_text}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ChatMessage:
    """One Teams chat or channel message fetched from MS Graph."""

    # Field order is part of the positional constructor signature.
    id: str
    content: str
    sender: str
    channel: str | None
    date: datetime

    @property
    def as_text(self) -> str:
        """Render the message as header lines followed by the content.

        The result is the text handed to LLM extraction: ``From`` and
        ``Date`` lines (the channel is appended in brackets when set), a
        blank line, then ``content``.
        """
        date_line = f"Date: {self.date:%Y-%m-%d %H:%M}"
        if self.channel:
            date_line += f" [channel: {self.channel}]"
        return "\n".join((f"From: {self.sender}", date_line, "", self.content))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fernet helpers ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _get_fernet() -> Fernet:
    """Build a ``Fernet`` cipher from ``settings.OAUTH_ENCRYPTION_KEY``.

    Raises:
        RuntimeError: ``OAUTH_ENCRYPTION_KEY`` is empty or unset. The message
            includes a command that generates a suitable key.
    """
    configured_key = settings.OAUTH_ENCRYPTION_KEY
    if configured_key:
        # Fernet is given bytes; a str setting is encoded first.
        if isinstance(configured_key, str):
            configured_key = configured_key.encode()
        return Fernet(configured_key)

    raise RuntimeError(
        "OAUTH_ENCRYPTION_KEY is not set. "
        'Generate one with: python -c "from cryptography.fernet import Fernet; '
        'print(Fernet.generate_key().decode())"'
    )
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_token(token_info: dict) -> str:
    """Serialise an OAuth credential dict to JSON and Fernet-encrypt it.

    The whole dict is stored, e.g. ``{access_token, refresh_token, token_uri,
    client_id, client_secret, scopes, expiry}`` or the equivalent MSAL shape.

    Returns:
        The Fernet token decoded to a ``str``.

    Raises:
        RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
        ValueError: ``token_info`` is not a non-empty dict.
    """
    is_usable = isinstance(token_info, dict) and bool(token_info)
    if not is_usable:
        raise ValueError("token_info must be a non-empty dict")

    serialised = json.dumps(token_info)
    cipher = _get_fernet()
    ciphertext = cipher.encrypt(serialised.encode("utf-8"))
    return ciphertext.decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_token(encrypted: str) -> dict:
    """Decrypt a Fernet-encrypted token string and return the credential dict.

    Args:
        encrypted: The string produced by ``encrypt_token``.

    Returns:
        The decoded credential dict.

    Raises:
        RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
        ValueError: The encrypted string is invalid, was encrypted with a
            different key, or does not decode to a JSON object.
    """
    try:
        plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
        token_info = json.loads(plaintext)
    except (InvalidToken, json.JSONDecodeError, UnicodeDecodeError) as exc:
        raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc

    # The contract is "returns a dict"; a JSON list/string/number would only
    # fail later, far from the cause. The payload itself is never put in the
    # message because it is credential material.
    if not isinstance(token_info, dict):
        raise ValueError(
            "Failed to decrypt OAuth token: expected a JSON object, "
            f"got {type(token_info).__name__}"
        )
    return token_info
|
||||||
|
|
||||||
|
|
||||||
|
# ── Provider factory ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider(
    provider: str,
    credentials_info: dict,
) -> "GmailClient | MSGraphClient":
    """Construct the client that serves *provider*.

    Parameters
    ----------
    provider:
        One of ``"gmail"``, ``"outlook"``, ``"teams"``.
    credentials_info:
        Decrypted OAuth credential dict (Google or Microsoft shape), passed
        unchanged to the client constructor.

    Returns
    -------
    A ``GmailClient`` for ``"gmail"``; an ``MSGraphClient`` for ``"outlook"``
    and ``"teams"``.

    Raises:
        ValueError: Unknown provider name.
    """
    if provider not in {"gmail", "outlook", "teams"}:
        raise ValueError(
            f"Unknown cloud provider {provider!r}. "
            "Supported: 'gmail', 'outlook', 'teams'."
        )

    # Client modules are imported only when selected, so the provider SDK
    # that is not needed is never loaded.
    if provider == "gmail":
        from app.integrations.gmail import GmailClient

        return GmailClient(credentials_info)

    from app.integrations.ms_graph import MSGraphClient

    return MSGraphClient(credentials_info)
|
||||||
335
app/integrations/gmail.py
Normal file
335
app/integrations/gmail.py
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
"""Gmail API client for cloud agent integration.
|
||||||
|
|
||||||
|
Wraps the Google Gmail REST API to fetch email messages matching a
|
||||||
|
``filter_config`` dict. Uses the official ``google-api-python-client``
|
||||||
|
library (synchronous) wrapped in ``asyncio.to_thread()`` to avoid
|
||||||
|
blocking the event loop.
|
||||||
|
|
||||||
|
Token refresh is handled transparently: when the stored access token has
|
||||||
|
expired, ``google.auth.transport.requests.Request`` will use the refresh
|
||||||
|
token to obtain a fresh one. The caller is responsible for persisting
|
||||||
|
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted``
|
||||||
|
(see ``agent_runner.run_cloud_agent``).
|
||||||
|
|
||||||
|
Credential dict shape (Google OAuth2):
|
||||||
|
{
|
||||||
|
"token": "<access_token>",
|
||||||
|
"refresh_token": "<refresh_token>",
|
||||||
|
"token_uri": "https://oauth2.googleapis.com/token",
|
||||||
|
"client_id": "<client_id>",
|
||||||
|
"client_secret": "<client_secret>",
|
||||||
|
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
|
||||||
|
"expiry": "2025-01-01T00:00:00Z" # optional ISO-8601
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import email
|
||||||
|
import html
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.integrations import EmailMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Gmail search date format — e.g. "after:2025/01/01"
|
||||||
|
_GMAIL_DATE_FMT = "%Y/%m/%d"
|
||||||
|
|
||||||
|
# Maximum characters of body text forwarded to the LLM.
|
||||||
|
_BODY_TRUNCATE = 8_000
|
||||||
|
|
||||||
|
# Maximum messages retrieved per run (prevents runaway quota usage).
|
||||||
|
_MAX_MESSAGES = 200
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gmail_query(
|
||||||
|
filter_config: dict[str, Any] | None,
|
||||||
|
since: datetime | None,
|
||||||
|
) -> str:
|
||||||
|
"""Build a Gmail search query string from *filter_config* and *since*.
|
||||||
|
|
||||||
|
Supported ``filter_config`` keys:
|
||||||
|
labels (list[str]): Gmail label names, e.g. ``["INBOX", "work"]``
|
||||||
|
senders (list[str]): Sender addresses or domains to include
|
||||||
|
date_range (dict): ``{from: "<YYYY-MM-DD>", to: "<YYYY-MM-DD>"}``
|
||||||
|
|
||||||
|
A hard ``since`` date (from last run) always overrides ``date_range.from``
|
||||||
|
when it is earlier.
|
||||||
|
"""
|
||||||
|
parts: list[str] = []
|
||||||
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
# Labels — joined with OR when multiple given.
|
||||||
|
labels: list[str] = cfg.get("labels", [])
|
||||||
|
if labels:
|
||||||
|
if len(labels) == 1:
|
||||||
|
parts.append(f"label:{labels[0]}")
|
||||||
|
else:
|
||||||
|
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
|
||||||
|
parts.append(f"({label_expr})")
|
||||||
|
|
||||||
|
# Senders — each prefixed with "from:".
|
||||||
|
senders: list[str] = cfg.get("senders", [])
|
||||||
|
for sender in senders:
|
||||||
|
parts.append(f"from:{sender}")
|
||||||
|
|
||||||
|
# Date range.
|
||||||
|
date_range: dict = cfg.get("date_range", {})
|
||||||
|
from_str: str | None = date_range.get("from")
|
||||||
|
to_str: str | None = date_range.get("to")
|
||||||
|
|
||||||
|
# Determine effective "from" date: most recent of filter_config.date_range.from and since.
|
||||||
|
effective_since: datetime | None = since
|
||||||
|
if from_str:
|
||||||
|
try:
|
||||||
|
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||||
|
if cfg_since.tzinfo is None:
|
||||||
|
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||||
|
if effective_since is None or cfg_since > effective_since:
|
||||||
|
effective_since = cfg_since
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
|
||||||
|
|
||||||
|
if effective_since:
|
||||||
|
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
|
||||||
|
|
||||||
|
if to_str:
|
||||||
|
try:
|
||||||
|
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||||
|
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
|
||||||
|
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(raw_html: str) -> str:
|
||||||
|
"""Remove HTML tags and decode entities to get plain text."""
|
||||||
|
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
|
||||||
|
decoded = html.unescape(no_tags)
|
||||||
|
return re.sub(r"\s+", " ", decoded).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_body(payload: dict[str, Any]) -> str:
|
||||||
|
"""Recursively extract the plain-text body from a Gmail message payload.
|
||||||
|
|
||||||
|
Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
|
||||||
|
Returns an empty string if no body can be extracted.
|
||||||
|
"""
|
||||||
|
mime_type: str = payload.get("mimeType", "")
|
||||||
|
body: dict = payload.get("body", {})
|
||||||
|
parts: list[dict] = payload.get("parts", [])
|
||||||
|
|
||||||
|
if mime_type == "text/plain":
|
||||||
|
data = body.get("data", "")
|
||||||
|
if data:
|
||||||
|
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if mime_type == "text/html":
|
||||||
|
data = body.get("data", "")
|
||||||
|
if data:
|
||||||
|
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return _strip_html(raw)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Multipart — prefer text/plain part, fall back to text/html.
|
||||||
|
plain_fallback = ""
|
||||||
|
for part in parts:
|
||||||
|
part_mime = part.get("mimeType", "")
|
||||||
|
if part_mime == "text/plain":
|
||||||
|
return _parse_body(part)
|
||||||
|
if part_mime == "text/html" and not plain_fallback:
|
||||||
|
plain_fallback = _parse_body(part)
|
||||||
|
if part_mime.startswith("multipart/"):
|
||||||
|
nested = _parse_body(part)
|
||||||
|
if nested:
|
||||||
|
return nested
|
||||||
|
return plain_fallback
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_date(raw: str) -> datetime:
|
||||||
|
"""Parse an RFC 2822 email date header into a UTC ``datetime``."""
|
||||||
|
try:
|
||||||
|
parsed = email.utils.parsedate_to_datetime(raw)
|
||||||
|
if parsed.tzinfo is None:
|
||||||
|
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||||
|
return parsed.astimezone(timezone.utc)
|
||||||
|
except Exception:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
class GmailClient:
    """Fetch email messages from a Gmail account via the Gmail REST API.

    Parameters
    ----------
    credentials_info:
        Decrypted OAuth2 credential dict.  Must contain at minimum
        ``token`` (access token) or ``refresh_token`` + ``token_uri`` +
        ``client_id`` + ``client_secret``.  An optional ``expiry`` entry
        holds the access-token expiry as an ISO-8601 string.
    """

    def __init__(self, credentials_info: dict[str, Any]) -> None:
        from google.oauth2.credentials import Credentials

        # Kept so ``refreshed_credentials`` can detect a changed access token.
        self._credentials_info = credentials_info
        self._credentials = Credentials(
            token=credentials_info.get("token"),
            refresh_token=credentials_info.get("refresh_token"),
            token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
            client_id=credentials_info.get("client_id"),
            client_secret=credentials_info.get("client_secret"),
            scopes=credentials_info.get("scopes"),
            expiry=self._parse_expiry(credentials_info.get("expiry")),
        )

    @staticmethod
    def _parse_expiry(expiry_str: str | None) -> datetime | None:
        """Parse a stored ISO-8601 expiry into a *naive UTC* ``datetime``.

        google-auth compares ``Credentials.expiry`` with a naive UTC "now",
        so a timezone-aware value would raise ``TypeError`` the first time
        ``expired``/``valid`` is evaluated.  Values carrying an offset are
        converted to UTC before the zone is dropped; values without one are
        taken to be UTC already (this is what ``refreshed_credentials``
        writes back).

        Returns ``None`` when *expiry_str* is empty or not valid ISO-8601.
        """
        if not expiry_str:
            return None
        try:
            parsed = datetime.fromisoformat(expiry_str.replace("Z", "+00:00"))
        except ValueError:
            logger.warning("gmail: invalid credential expiry %r — ignoring", expiry_str)
            return None
        if parsed.tzinfo is not None:
            parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
        return parsed

    # ── Public API ─────────────────────────────────────────────────────────

    async def fetch_messages(
        self,
        filter_config: dict[str, Any] | None = None,
        since: datetime | None = None,
    ) -> list[EmailMessage]:
        """Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.

        Runs the synchronous Google API calls inside ``asyncio.to_thread()``
        to avoid blocking the async event loop.

        Token refresh is performed automatically when the access token has
        expired.  After the call, ``self.refreshed_credentials`` may be
        consulted to detect whether new credentials should be persisted.

        Raises:
            RuntimeError: token refresh or the ``messages.list`` call failed.
        """
        query = _build_gmail_query(filter_config, since)
        logger.debug("gmail: executing search query %r", query)
        return await asyncio.to_thread(self._fetch_sync, query)

    @property
    def refreshed_credentials(self) -> dict[str, Any] | None:
        """Return updated credential dict if the access token was refreshed.

        If the credentials were refreshed during ``fetch_messages()``, returns
        a new dict that should be re-encrypted and written back to the DB.
        Returns ``None`` if no refresh occurred.
        """
        creds = self._credentials
        if not creds.valid and creds.expired:
            # Still expired: there is nothing newer worth persisting.
            return None
        # Check whether the token changed from what was stored.
        if creds.token == self._credentials_info.get("token"):
            return None

        result: dict[str, Any] = {
            "token": creds.token,
            "refresh_token": creds.refresh_token,
            "token_uri": creds.token_uri,
            "client_id": creds.client_id,
            "client_secret": creds.client_secret,
            "scopes": list(creds.scopes or []),
        }
        if creds.expiry:
            result["expiry"] = creds.expiry.isoformat()
        return result

    # ── Internal sync worker ───────────────────────────────────────────────

    def _fetch_sync(self, query: str) -> list[EmailMessage]:
        """Synchronous worker — called inside ``asyncio.to_thread()``.

        Lists matching message IDs page by page (capped at
        ``_MAX_MESSAGES``), then fetches each message in ``full`` format.
        A message whose download or parsing fails is logged and skipped.

        Raises:
            RuntimeError: token refresh or ``messages.list`` failed.
        """
        import googleapiclient.discovery
        import googleapiclient.errors
        from google.auth.transport.requests import Request

        # Refresh token if needed before building the service.
        if self._credentials.expired and self._credentials.refresh_token:
            try:
                self._credentials.refresh(Request())
            except Exception as exc:
                raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc

        service = googleapiclient.discovery.build(
            "gmail", "v1", credentials=self._credentials, cache_discovery=False
        )
        user_api = service.users()  # type: ignore[attr-defined]

        # ── List matching message IDs ──────────────────────────────────────
        ids: list[str] = []
        page_token: str | None = None
        while len(ids) < _MAX_MESSAGES:
            kwargs: dict[str, Any] = {
                "userId": "me",
                # Never request more than is still needed to reach the cap.
                "maxResults": min(100, _MAX_MESSAGES - len(ids)),
            }
            if query:
                kwargs["q"] = query
            if page_token:
                kwargs["pageToken"] = page_token

            try:
                resp = user_api.messages().list(**kwargs).execute()
            except googleapiclient.errors.HttpError as exc:
                raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc

            ids.extend(msg["id"] for msg in resp.get("messages", []))

            page_token = resp.get("nextPageToken")
            if not page_token:
                break

        if not ids:
            logger.debug("gmail: no messages matched query %r", query)
            return []

        logger.info("gmail: fetching %d message(s)", len(ids))

        # ── Fetch individual message details ──────────────────────────────
        messages: list[EmailMessage] = []
        for msg_id in ids:
            try:
                msg = user_api.messages().get(
                    userId="me", id=msg_id, format="full"
                ).execute()
                payload = msg.get("payload", {})

                # Header names are case-insensitive; normalise for lookup.
                headers: dict[str, str] = {
                    h["name"].lower(): h["value"]
                    for h in payload.get("headers", [])
                }
                date_raw = headers.get("date", "")

                messages.append(EmailMessage(
                    id=msg_id,
                    subject=headers.get("subject", "(no subject)"),
                    sender=headers.get("from", "unknown"),
                    body_text=_parse_body(payload)[:_BODY_TRUNCATE],
                    date=_parse_date(date_raw) if date_raw else datetime.now(timezone.utc),
                    labels=msg.get("labelIds", []),
                ))
            except googleapiclient.errors.HttpError as exc:
                logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
            except Exception as exc:
                # One malformed message must not abort the whole run.
                logger.warning("gmail: skipping message %s — unexpected error: %s", msg_id, exc)

        logger.info("gmail: returned %d message(s)", len(messages))
        return messages
|
||||||
352
app/integrations/ms_graph.py
Normal file
352
app/integrations/ms_graph.py
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
"""Microsoft Graph API client for Outlook and Teams cloud agent integration.
|
||||||
|
|
||||||
|
Handles two data sources:
|
||||||
|
|
||||||
|
* **Outlook email** (``provider="outlook"``) — ``fetch_emails()`` calls
|
||||||
|
``/me/messages`` with an OData ``$filter`` built from ``filter_config``.
|
||||||
|
* **Teams messages** (``provider="teams"``) — ``fetch_messages()`` calls
|
||||||
|
``/me/chats/getAllMessages`` filtered by date.
|
||||||
|
|
||||||
|
Authentication uses MSAL ``ConfidentialClientApplication`` to acquire a token
|
||||||
|
from a stored refresh token. The ``httpx.AsyncClient`` (already a project
|
||||||
|
dependency) is used for all API calls.
|
||||||
|
|
||||||
|
Credential dict shape (Microsoft OAuth2 / MSAL):
|
||||||
|
{
|
||||||
|
"access_token": "<access_token>",
|
||||||
|
"refresh_token": "<refresh_token>",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"scope": "Mail.Read ChannelMessage.Read.All offline_access",
|
||||||
|
"expires_in": 3600
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.integrations import ChatMessage, EmailMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Root URL of the Microsoft Graph v1.0 REST API; request paths are appended to it.
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"

# Max items fetched per run.
# _MAX_EMAILS caps Outlook results (fetch_emails); _MAX_MESSAGES caps Teams
# results (fetch_messages).
_MAX_EMAILS = 200
_MAX_MESSAGES = 200

# Max characters of body forwarded to the LLM.
_BODY_TRUNCATE = 8_000
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(raw: str) -> str:
|
||||||
|
"""Strip HTML tags and collapse whitespace."""
|
||||||
|
no_tags = re.sub(r"<[^>]+>", " ", raw)
|
||||||
|
import html as _html
|
||||||
|
decoded = _html.unescape(no_tags)
|
||||||
|
return re.sub(r"\s+", " ", decoded).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _odata_datetime(dt: datetime) -> str:
|
||||||
|
"""Format a datetime as an OData datetime literal (UTC, ISO 8601)."""
|
||||||
|
utc = dt.astimezone(timezone.utc)
|
||||||
|
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_email_filter(
|
||||||
|
filter_config: dict[str, Any] | None,
|
||||||
|
since: datetime | None,
|
||||||
|
) -> str:
|
||||||
|
"""Build an OData ``$filter`` expression for the ``/me/messages`` endpoint.
|
||||||
|
|
||||||
|
Supported ``filter_config`` keys:
|
||||||
|
senders (list[str]): Sender email addresses.
|
||||||
|
date_range (dict): ``{from: "<ISO-8601>", to: "<ISO-8601>"}``
|
||||||
|
folders (list[str]): Folder display names (not directly filterable
|
||||||
|
via OData, so ignored here — callers iterate
|
||||||
|
folder IDs separately if needed; listed for
|
||||||
|
completeness).
|
||||||
|
|
||||||
|
A hard ``since`` date always overrides ``date_range.from`` when it is
|
||||||
|
earlier.
|
||||||
|
"""
|
||||||
|
clauses: list[str] = []
|
||||||
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
# Senders.
|
||||||
|
senders: list[str] = cfg.get("senders", [])
|
||||||
|
if senders:
|
||||||
|
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
|
||||||
|
clauses.append("(" + " or ".join(sender_clauses) + ")")
|
||||||
|
|
||||||
|
# Date range.
|
||||||
|
date_range: dict = cfg.get("date_range", {})
|
||||||
|
from_str: str | None = date_range.get("from")
|
||||||
|
|
||||||
|
effective_since: datetime | None = since
|
||||||
|
if from_str:
|
||||||
|
try:
|
||||||
|
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||||
|
if cfg_since.tzinfo is None:
|
||||||
|
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||||
|
if effective_since is None or cfg_since > effective_since:
|
||||||
|
effective_since = cfg_since
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
|
||||||
|
|
||||||
|
if effective_since:
|
||||||
|
clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")
|
||||||
|
|
||||||
|
to_str: str | None = date_range.get("to")
|
||||||
|
if to_str:
|
||||||
|
try:
|
||||||
|
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||||
|
if to_dt.tzinfo is None:
|
||||||
|
to_dt = to_dt.replace(tzinfo=timezone.utc)
|
||||||
|
clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
|
||||||
|
|
||||||
|
return " and ".join(clauses)
|
||||||
|
|
||||||
|
|
||||||
|
class MSGraphClient:
    """Fetch emails and Teams messages via the Microsoft Graph REST API.

    Parameters
    ----------
    credentials_info:
        Decrypted MSAL credential dict.  ``access_token`` is used as the
        bearer token; ``refresh_token`` and ``scope`` are used when the
        access token has to be renewed.
    """

    # Scope values MSAL manages itself and refuses to accept from callers.
    _RESERVED_SCOPES = frozenset({"openid", "profile", "offline_access"})

    def __init__(self, credentials_info: dict[str, Any]) -> None:
        self._credentials_info = credentials_info
        self._access_token: str = credentials_info.get("access_token", "")
        # Remembered so ``refreshed_credentials`` can tell whether a refresh happened.
        self._original_access_token: str = self._access_token
        self._refresh_token: str | None = credentials_info.get("refresh_token")

    # ── Token management ───────────────────────────────────────────────────

    def _auth_headers(self) -> dict[str, str]:
        """Return the ``Authorization`` header for the current access token."""
        return {"Authorization": f"Bearer {self._access_token}"}

    @classmethod
    def _msal_scopes(cls, scope: str | None) -> list[str]:
        """Turn a stored space-separated scope string into scopes MSAL accepts.

        Reserved values such as ``offline_access`` (present in the stored
        ``scope`` string) are dropped because MSAL rejects them.  Falls back
        to the Graph ``.default`` scope when nothing usable remains.
        """
        scopes = [s for s in (scope or "").split() if s not in cls._RESERVED_SCOPES]
        return scopes or ["https://graph.microsoft.com/.default"]

    async def _refresh_access_token(self) -> None:
        """Use MSAL to exchange the refresh token for a fresh access token.

        Updates ``self._access_token`` and ``self._credentials_info`` in-place.
        The blocking MSAL call runs in a worker thread so the event loop is
        not stalled during the network round-trip.

        Raises:
            RuntimeError: no refresh token is stored, or MSAL reports an
                auth error.
        """
        import asyncio

        import msal

        if not self._refresh_token:
            raise RuntimeError("MS Graph token refresh failed: no refresh token stored")

        scopes = self._msal_scopes(self._credentials_info.get("scope", ""))
        refresh_token = self._refresh_token

        def _acquire() -> dict[str, Any]:
            app = msal.ConfidentialClientApplication(
                client_id=settings.MS_CLIENT_ID,
                client_credential=settings.MS_CLIENT_SECRET,
                authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}",
            )
            return app.acquire_token_by_refresh_token(refresh_token, scopes=scopes)

        result = await asyncio.to_thread(_acquire)
        if "access_token" not in result:
            error = result.get("error_description", result.get("error", "unknown"))
            raise RuntimeError(f"MS Graph token refresh failed: {error}")

        self._access_token = result["access_token"]
        # MSAL may issue a new refresh token.
        if "refresh_token" in result:
            self._refresh_token = result["refresh_token"]
            self._credentials_info["refresh_token"] = result["refresh_token"]
        self._credentials_info["access_token"] = self._access_token

    @property
    def refreshed_credentials(self) -> dict[str, Any] | None:
        """Return updated credential dict if the access token was refreshed.

        Returns ``None`` if no change was made.
        """
        if self._access_token != self._original_access_token:
            return {**self._credentials_info, "access_token": self._access_token}
        return None

    # ── HTTP helpers ───────────────────────────────────────────────────────

    async def _get(
        self,
        client: httpx.AsyncClient,
        url: str,
        params: dict[str, Any] | None = None,
        *,
        retry_on_401: bool = True,
    ) -> dict[str, Any]:
        """GET *url* with auth; refresh token on 401 and retry once.

        Raises:
            RuntimeError: Graph answered 429, or the token refresh failed.
            httpx.HTTPStatusError: any other non-success status.
        """
        resp = await client.get(url, params=params, headers=self._auth_headers())
        if resp.status_code == 401 and retry_on_401 and self._refresh_token:
            logger.debug("ms_graph: 401 on %s — refreshing token", url)
            await self._refresh_access_token()
            resp = await client.get(url, params=params, headers=self._auth_headers())
        if resp.status_code == 429:
            raise RuntimeError("MS Graph rate limit hit (429). Try again later.")
        resp.raise_for_status()
        return resp.json()

    # ── Public API ─────────────────────────────────────────────────────────

    async def fetch_emails(
        self,
        filter_config: dict[str, Any] | None = None,
        since: datetime | None = None,
    ) -> list[EmailMessage]:
        """Return up to ``_MAX_EMAILS`` Outlook messages matching *filter_config*.

        Parameters
        ----------
        filter_config:
            Optional dict with ``senders``, ``date_range``, ``folders`` keys.
        since:
            Hard lower-bound on email date (from last agent run).
        """
        odata_filter = _build_email_filter(filter_config, since)
        params: dict[str, Any] | None = {
            "$top": 50,
            "$select": "id,subject,from,receivedDateTime,body,bodyPreview",
            "$orderby": "receivedDateTime desc",
        }
        if odata_filter:
            params["$filter"] = odata_filter

        emails: list[EmailMessage] = []
        url = f"{_GRAPH_BASE}/me/messages"

        async with httpx.AsyncClient(timeout=30.0) as client:
            while url and len(emails) < _MAX_EMAILS:
                data = await self._get(client, url, params)
                for item in data.get("value", []):
                    emails.append(self._parse_email(item))
                    if len(emails) >= _MAX_EMAILS:
                        break
                url = data.get("@odata.nextLink", "")
                # nextLink already carries its own query string; send no
                # params at all so the HTTP client cannot rewrite or drop it.
                params = None

        logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
        return emails

    async def fetch_messages(
        self,
        filter_config: dict[str, Any] | None = None,
        since: datetime | None = None,
    ) -> list[ChatMessage]:
        """Return up to ``_MAX_MESSAGES`` Teams messages matching *filter_config*.

        Fetches from ``/me/chats/getAllMessages`` (personal + group chats).
        The ``filter_config.channels`` key is applied post-fetch as a
        case-insensitive substring match against each message's channel
        value (the API doesn't support channel OData filter directly on
        ``getAllMessages``).  Messages that carry no channel value are kept.

        A 403/404 from the endpoint is logged and ends the fetch with
        whatever was collected so far; other HTTP errors propagate.
        """
        cfg = filter_config or {}
        channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
        params: dict[str, Any] | None = {"$top": 50}
        if since:
            params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}"

        messages: list[ChatMessage] = []
        url = f"{_GRAPH_BASE}/me/chats/getAllMessages"

        async with httpx.AsyncClient(timeout=30.0) as client:
            while url and len(messages) < _MAX_MESSAGES:
                try:
                    data = await self._get(client, url, params)
                except httpx.HTTPStatusError as exc:
                    # getAllMessages requires specific licensing; degrade gracefully.
                    if exc.response.status_code in (403, 404):
                        logger.warning(
                            "ms_graph: /me/chats/getAllMessages not available (%d) — "
                            "check Teams license or permissions",
                            exc.response.status_code,
                        )
                        break
                    raise

                for item in data.get("value", []):
                    msg = self._parse_teams_message(item)
                    if channel_filter and msg.channel:
                        if not any(c in msg.channel.lower() for c in channel_filter):
                            continue
                    messages.append(msg)
                    if len(messages) >= _MAX_MESSAGES:
                        break
                url = data.get("@odata.nextLink", "")
                # See fetch_emails: the nextLink is complete as returned.
                params = None

        logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
        return messages

    # ── Parsers ────────────────────────────────────────────────────────────

    @staticmethod
    def _parse_graph_datetime(value: str | None) -> datetime:
        """Parse a Graph ISO-8601 timestamp; fall back to "now" (UTC) if unusable."""
        try:
            return datetime.fromisoformat((value or "").replace("Z", "+00:00"))
        except (AttributeError, ValueError):
            return datetime.now(timezone.utc)

    @staticmethod
    def _parse_email(item: dict[str, Any]) -> EmailMessage:
        """Convert one ``/me/messages`` item into an ``EmailMessage``.

        HTML bodies are reduced to plain text.  When the body is empty the
        ``bodyPreview`` field is used instead.  JSON ``null`` values are
        treated like missing fields.
        """
        subject: str = item.get("subject") or "(no subject)"
        sender_block = item.get("from") or {}
        sender_addr = (sender_block.get("emailAddress") or {}).get("address", "unknown")
        date = MSGraphClient._parse_graph_datetime(item.get("receivedDateTime"))

        body_block = item.get("body") or {}
        content_type: str = body_block.get("contentType", "text")
        raw_body: str = body_block.get("content") or ""
        body_text = _strip_html(raw_body) if content_type == "html" else raw_body
        if not body_text:
            body_text = item.get("bodyPreview") or ""
        body_text = body_text[:_BODY_TRUNCATE]

        return EmailMessage(
            id=item.get("id", ""),
            subject=subject,
            sender=sender_addr,
            body_text=body_text,
            date=date,
        )

    @staticmethod
    def _parse_teams_message(item: dict[str, Any]) -> ChatMessage:
        """Convert one ``getAllMessages`` item into a ``ChatMessage``.

        ``channel`` is the ``channelIdentity.channelId`` value, or ``None``
        when the item has no channel identity.
        """
        msg_id: str = item.get("id", "")
        sender_block = (item.get("from") or {}).get("user") or {}
        sender: str = sender_block.get("displayName", "unknown")
        channel: str | None = (item.get("channelIdentity") or {}).get("channelId")
        date = MSGraphClient._parse_graph_datetime(item.get("createdDateTime"))

        body_block = item.get("body") or {}
        content_type: str = body_block.get("contentType", "text")
        raw_content: str = body_block.get("content") or ""
        content = _strip_html(raw_content) if content_type == "html" else raw_content
        content = content[:_BODY_TRUNCATE]

        return ChatMessage(
            id=msg_id,
            content=content,
            sender=sender,
            channel=channel,
            date=date,
        )
|
||||||
@@ -8,6 +8,9 @@ services:
|
|||||||
required: false
|
required: false
|
||||||
environment:
|
environment:
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
|
||||||
|
volumes:
|
||||||
|
- copilot_tokens:/root/.config/litellm/github_copilot
|
||||||
depends_on:
|
depends_on:
|
||||||
db:
|
db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
@@ -66,3 +69,4 @@ volumes:
|
|||||||
postgres_data:
|
postgres_data:
|
||||||
minio_data:
|
minio_data:
|
||||||
qdrant_data:
|
qdrant_data:
|
||||||
|
copilot_tokens:
|
||||||
|
|||||||
@@ -25,4 +25,10 @@ moto[s3]>=5.0.0
|
|||||||
pinecone>=5.0.0
|
pinecone>=5.0.0
|
||||||
qdrant-client>=1.7.0
|
qdrant-client>=1.7.0
|
||||||
croniter>=3.0.0
|
croniter>=3.0.0
|
||||||
|
google-api-python-client>=2.130.0
|
||||||
|
google-auth>=2.29.0
|
||||||
|
google-auth-oauthlib>=1.2.0
|
||||||
|
google-auth-httplib2>=0.2.0
|
||||||
|
msal>=1.28.0
|
||||||
|
cryptography>=42.0.0
|
||||||
ruff>=0.8.0
|
ruff>=0.8.0
|
||||||
|
|||||||
@@ -455,21 +455,232 @@ async def test_run_local_agent_llm_extraction_error():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_cloud_agent_stub_returns_error():
|
async def test_run_cloud_agent_device_offline():
|
||||||
"""Cloud agent stub immediately marks run as error with informative message."""
|
"""Cloud agent aborts immediately when no device is connected."""
|
||||||
config = _make_cloud_config()
|
config = _make_cloud_config()
|
||||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = DeviceConnectionManager() # empty — no devices registered
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_finalize.assert_called_once()
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("device" in e.lower() or "connected" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_no_oauth_token():
|
||||||
|
"""Cloud agent errors when no OAuth token is stored."""
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.oauth_token_encrypted = None
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
mgr = _make_manager()
|
mgr = _make_manager()
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
mock_finalize.assert_called_once()
|
_, kwargs = mock_finalize.call_args
|
||||||
_args, kwargs = mock_finalize.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
assert kwargs["status"] == "error"
|
||||||
assert len(kwargs["errors"]) == 1
|
assert any("oauth" in e.lower() or "token" in e.lower() for e in kwargs["errors"])
|
||||||
assert "gmail" in kwargs["errors"][0].lower()
|
|
||||||
assert "3.6" in kwargs["errors"][0]
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_cloud_agent_token_decrypt_failure():
    """Cloud agent errors gracefully when the stored token cannot be decrypted."""
    config = _make_cloud_config()
    # Deliberately not a Fernet token, so decryption must fail.
    config.oauth_token_encrypted = "this-is-not-valid-fernet-ciphertext"
    run_log = _make_run_log(config.id, agent_type="cloud")
    mgr = _make_manager()

    from cryptography.fernet import Fernet as _Fernet
    # A well-formed key, so the failure comes from the ciphertext rather
    # than from a missing or malformed encryption key.
    valid_key = _Fernet.generate_key().decode()

    with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
         patch("app.integrations.settings") as mock_settings:
        mock_settings.OAUTH_ENCRYPTION_KEY = valid_key
        await run_cloud_agent(_FREE_UID, config, run_log, mgr)

    # _finalize_run is inspected through the keyword arguments it received.
    _, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "error"
    assert any("decrypt" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_cloud_agent_happy_path_gmail():
    """Cloud agent happy path: Gmail fetch → LLM extraction → inserts → success."""
    from app.integrations import EmailMessage, encrypt_token
    from cryptography.fernet import Fernet as _Fernet

    fernet_key = _Fernet.generate_key().decode()
    credentials = {
        "token": "access_abc",
        "refresh_token": "refresh_xyz",
        "token_uri": "https://oauth2.googleapis.com/token",
        "client_id": "cid",
        "client_secret": "csec",
    }

    config = _make_cloud_config()
    config.provider = "gmail"
    config.prompt_template = "Extract tasks from this email."
    config.data_types = ["tasks"]

    # Encrypt the credentials with the throwaway key patched into settings.
    with patch("app.integrations.settings") as ms:
        ms.OAUTH_ENCRYPTION_KEY = fernet_key
        config.oauth_token_encrypted = encrypt_token(credentials)

    run_log = _make_run_log(config.id, agent_type="cloud")
    mgr = _make_manager()

    sample_email = EmailMessage(
        id="msg001",
        subject="Action required",
        sender="boss@company.com",
        body_text="Please fix the bug by Friday.",
        date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
    )

    # What the mocked extraction step returns for the single email.
    extracted_items = [{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}]

    with patch("app.integrations.settings") as mock_int_settings, \
         patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
         patch("app.core.agent_runner._extract_items_from_content", new_callable=AsyncMock, return_value=extracted_items) as mock_extract, \
         patch("app.core.agent_runner._send_insert_to_client", new_callable=AsyncMock, return_value={"ok": True}) as mock_insert, \
         patch("app.core.agent_runner.async_session"):
        mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key

        # Stand-in provider: returns one email and reports no token refresh.
        mock_gmail = AsyncMock()
        mock_gmail.fetch_messages = AsyncMock(return_value=[sample_email])
        mock_gmail.refreshed_credentials = None

        with patch("app.integrations.decrypt_token", return_value=credentials), \
             patch("app.integrations.get_provider", return_value=mock_gmail):
            await run_cloud_agent(_FREE_UID, config, run_log, mgr)

    # One email → one extraction call → one insert → success summary.
    mock_extract.assert_called_once()
    mock_insert.assert_called_once()
    _, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "success"
    assert kwargs["items_processed"] == 1
    assert kwargs["items_created"] == 1
    assert kwargs["config_type"] == "cloud"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_cloud_agent_provider_fetch_error():
    """Cloud agent records error status when provider fetch raises RuntimeError."""
    credentials = {"token": "abc"}
    config = _make_cloud_config()
    config.oauth_token_encrypted = "some_encrypted_value"  # non-empty so decrypt step is reached
    config.prompt_template = "Extract tasks."
    config.data_types = ["tasks"]
    run_log = _make_run_log(config.id, agent_type="cloud")
    mgr = _make_manager()

    # Stand-in provider whose fetch always fails.
    mock_provider = AsyncMock()
    mock_provider.fetch_messages = AsyncMock(side_effect=RuntimeError("API quota exceeded"))
    mock_provider.refreshed_credentials = None

    # decrypt_token is patched, so the placeholder ciphertext above is never
    # actually decrypted.
    with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
         patch("app.integrations.decrypt_token", return_value=credentials), \
         patch("app.integrations.get_provider", return_value=mock_provider), \
         patch("app.core.agent_runner.async_session"):
        await run_cloud_agent(_FREE_UID, config, run_log, mgr)

    _, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "error"
    # Either the provider's own message or a generic fetch-failure wording is accepted.
    assert any("quota" in e.lower() or "fetch" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_cloud_agent_refreshed_token_persisted():
    """When the provider refreshes its token, the new ciphertext is written to the config row.

    The provider mock exposes ``refreshed_credentials``; the test checks that
    ``encrypt_token`` is called once with those credentials and that the value
    it returns ends up on the config row returned by the mocked session.
    """
    from app.integrations import encrypt_token
    from cryptography.fernet import Fernet as _Fernet

    fernet_key = _Fernet.generate_key().decode()
    credentials = {"token": "old_token", "refresh_token": "rt_old"}
    fresh_credentials = {"token": "new_token", "refresh_token": "rt_new"}

    config = _make_cloud_config()
    config.prompt_template = "Extract tasks."
    config.data_types = ["tasks"]

    # Produce a real ciphertext for the stored token using a throwaway key.
    with patch("app.integrations.settings") as ms:
        ms.OAUTH_ENCRYPTION_KEY = fernet_key
        config.oauth_token_encrypted = encrypt_token(credentials)

    run_log = _make_run_log(config.id, agent_type="cloud")
    mgr = _make_manager()

    mock_provider = AsyncMock()
    mock_provider.fetch_messages = AsyncMock(return_value=[])
    mock_provider.refreshed_credentials = fresh_credentials  # token was refreshed

    # Track DB writes via mock async_session.
    mock_cfg_row = MagicMock()
    mock_cfg_row.oauth_token_encrypted = None

    mock_db = AsyncMock()
    mock_db.__aenter__ = AsyncMock(return_value=mock_db)
    mock_db.__aexit__ = AsyncMock(return_value=False)
    # The row is reachable both directly on the session mock and through the
    # execute() result; which path run_cloud_agent uses is not visible here.
    mock_db.scalar_one_or_none = AsyncMock(return_value=mock_cfg_row)
    cfg_result = MagicMock()
    cfg_result.scalar_one_or_none.return_value = mock_cfg_row
    mock_db.execute = AsyncMock(return_value=cfg_result)
    mock_db.commit = AsyncMock()

    with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock), \
         patch("app.integrations.decrypt_token", return_value=credentials), \
         patch("app.integrations.get_provider", return_value=mock_provider), \
         patch("app.integrations.encrypt_token", return_value="new_encrypted") as mock_encrypt, \
         patch("app.core.agent_runner.async_session", return_value=mock_db), \
         patch("app.integrations.settings") as mock_int_settings:
        mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
        await run_cloud_agent(_FREE_UID, config, run_log, mgr)

    # The new encrypted token should have been written to the config row.
    mock_encrypt.assert_called_once_with(fresh_credentials)
    assert mock_cfg_row.oauth_token_encrypted == "new_encrypted"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_finalize_run_updates_cloud_config_last_run_at():
    """``_finalize_run`` with ``config_type='cloud'`` stamps ``last_run_at`` on the config row."""
    from app.core.agent_runner import _finalize_run

    log_row = _make_run_log(str(uuid.uuid4()), agent_type="cloud")
    log_row.id = str(uuid.uuid4())

    # Config row handed back by the mocked query; starts without a timestamp.
    cloud_cfg_row = MagicMock()
    cloud_cfg_row.last_run_at = None
    lookup_result = MagicMock()
    lookup_result.scalar_one_or_none.return_value = cloud_cfg_row

    # Async session stand-in usable as ``async with async_session() as db``.
    session = AsyncMock()
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock(return_value=False)
    session.merge = AsyncMock(return_value=log_row)
    session.execute = AsyncMock(return_value=lookup_result)
    session.commit = AsyncMock()

    target_config_id = str(uuid.uuid4())

    with patch("app.core.agent_runner.async_session", return_value=session):
        await _finalize_run(
            log_row,
            status="success",
            update_config_last_run=True,
            config_id=target_config_id,
            config_type="cloud",
        )

    # CloudAgentConfig.last_run_at should have been set.
    assert cloud_cfg_row.last_run_at is not None
    session.commit.assert_called()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
729
tests/test_integrations.py
Normal file
729
tests/test_integrations.py
Normal file
@@ -0,0 +1,729 @@
|
|||||||
|
"""Tests for Step 3.6: cloud provider integration clients.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
Unit — app/integrations/__init__.py:
|
||||||
|
- encrypt_token / decrypt_token round-trip
|
||||||
|
- decrypt_token raises ValueError on invalid ciphertext
|
||||||
|
- encrypt_token raises ValueError on empty/non-dict input
|
||||||
|
- _get_fernet raises RuntimeError when OAUTH_ENCRYPTION_KEY not set
|
||||||
|
- get_provider returns GmailClient for 'gmail'
|
||||||
|
- get_provider returns MSGraphClient for 'outlook' and 'teams'
|
||||||
|
- get_provider raises ValueError for unknown provider
|
||||||
|
|
||||||
|
Unit — app/integrations/gmail.py:
|
||||||
|
- _build_gmail_query with no filter returns empty string
|
||||||
|
- _build_gmail_query with labels builds label: expr
|
||||||
|
- _build_gmail_query with senders builds from: expr
|
||||||
|
- _build_gmail_query with date_range builds after:/before: exprs
|
||||||
|
- _build_gmail_query since overrides date_range.from when more recent
|
||||||
|
- _build_gmail_query date_range.from overrides since when more recent
|
||||||
|
- _parse_body extracts text/plain part
|
||||||
|
- _parse_body extracts text/html part (stripped)
|
||||||
|
- _parse_body recurses into multipart, prefers text/plain
|
||||||
|
- GmailClient.fetch_messages: happy path with mocked service
|
||||||
|
- GmailClient.fetch_messages: no messages returns empty list
|
||||||
|
- GmailClient.fetch_messages: HTTP error on messages.list raises RuntimeError
|
||||||
|
- GmailClient.refreshed_credentials: None when token unchanged
|
||||||
|
- GmailClient.refreshed_credentials: returns dict when token changes
|
||||||
|
|
||||||
|
Unit — app/integrations/ms_graph.py:
|
||||||
|
- _build_email_filter with no filter returns empty string
|
||||||
|
- _build_email_filter with senders builds OData from clause
|
||||||
|
- _build_email_filter with since builds receivedDateTime ge clause
|
||||||
|
- MSGraphClient.fetch_emails: happy path with mocked httpx
|
||||||
|
- MSGraphClient.fetch_emails: 401 triggers token refresh and retries
|
||||||
|
- MSGraphClient.fetch_messages: happy path with mocked httpx
|
||||||
|
- MSGraphClient.fetch_messages: 403 from getAllMessages degrades gracefully
|
||||||
|
- MSGraphClient.refreshed_credentials: None when token unchanged
|
||||||
|
- MSGraphClient._refresh_access_token: MSAL error raises RuntimeError
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.integrations import (
|
||||||
|
ChatMessage,
|
||||||
|
EmailMessage,
|
||||||
|
decrypt_token,
|
||||||
|
encrypt_token,
|
||||||
|
get_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ───────────────────────────────────────────────────────────────────────────
# Helpers
# ───────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_FERNET_KEY = "eW91LXNob3VsZC1ub3QtdXNlLXRoaXMta2V5LWluLXByb2Q="
|
||||||
|
# ^ URL-safe base64 of a 35-byte placeholder string, kept for illustration only.
#   It is NOT a valid Fernet key (Fernet requires exactly 32 decoded bytes),
#   so a proper key is generated below for the tests to use.
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet as _Fernet # noqa: E402
|
||||||
|
|
||||||
|
_VALID_KEY = _Fernet.generate_key().decode("utf-8")
|
||||||
|
|
||||||
|
# Google-style OAuth credential payload passed to encrypt_token and GmailClient in these tests.
_TOKEN_DICT: dict[str, object] = {
    "token": "access_abc",
    "refresh_token": "refresh_xyz",
    "token_uri": "https://oauth2.googleapis.com/token",
    "client_id": "client_id_123",
    "client_secret": "client_secret_456",
    "scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
}

# Microsoft-style OAuth token payload passed to MSGraphClient / get_provider in these tests.
_MS_TOKEN_DICT: dict[str, str] = {
    "access_token": "ms_access_abc",
    "refresh_token": "ms_refresh_xyz",
    "token_type": "Bearer",
    "scope": "Mail.Read offline_access",
}
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────────────────────────────────────────────────────────
# encrypt_token / decrypt_token
# ───────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenEncryption:
    """encrypt_token / decrypt_token tests, plus ``as_text`` checks for the message dataclasses."""

    def test_round_trip(self):
        """Encrypting and then decrypting with the same key returns the original dict."""
        with patch("app.integrations.settings", OAUTH_ENCRYPTION_KEY=_VALID_KEY):
            ciphertext = encrypt_token(_TOKEN_DICT)
            assert isinstance(ciphertext, str)
            # must be ciphertext, not plaintext
            assert ciphertext != json.dumps(_TOKEN_DICT)
            assert decrypt_token(ciphertext) == _TOKEN_DICT

    def test_decrypt_invalid_ciphertext_raises_value_error(self):
        """Input that is not Fernet ciphertext is rejected with ValueError."""
        with patch("app.integrations.settings", OAUTH_ENCRYPTION_KEY=_VALID_KEY):
            with pytest.raises(ValueError, match="Failed to decrypt"):
                decrypt_token("this-is-not-valid-fernet-ciphertext")

    def test_decrypt_wrong_key_raises_value_error(self):
        """Decrypting with a different key must fail with ValueError."""
        unrelated_key = _Fernet.generate_key().decode("utf-8")

        # Encrypt under the shared test key ...
        with patch("app.integrations.settings", OAUTH_ENCRYPTION_KEY=_VALID_KEY):
            ciphertext = encrypt_token(_TOKEN_DICT)

        # ... then attempt to decrypt under a freshly generated one.
        with patch("app.integrations.settings", OAUTH_ENCRYPTION_KEY=unrelated_key):
            with pytest.raises(ValueError, match="Failed to decrypt"):
                decrypt_token(ciphertext)

    def test_encrypt_empty_dict_raises_value_error(self):
        """An empty dict is rejected."""
        with patch("app.integrations.settings", OAUTH_ENCRYPTION_KEY=_VALID_KEY):
            with pytest.raises(ValueError, match="non-empty dict"):
                encrypt_token({})

    def test_encrypt_non_dict_raises_value_error(self):
        """A non-dict argument is rejected."""
        with patch("app.integrations.settings", OAUTH_ENCRYPTION_KEY=_VALID_KEY):
            with pytest.raises(ValueError, match="non-empty dict"):
                encrypt_token("not-a-dict")  # type: ignore[arg-type]

    def test_missing_key_raises_runtime_error(self):
        """An empty OAUTH_ENCRYPTION_KEY setting surfaces as RuntimeError."""
        with patch("app.integrations.settings", OAUTH_ENCRYPTION_KEY=""):
            with pytest.raises(RuntimeError, match="OAUTH_ENCRYPTION_KEY"):
                encrypt_token(_TOKEN_DICT)

    def test_email_message_as_text(self):
        """EmailMessage.as_text includes sender, subject and body."""
        rendered = EmailMessage(
            id="m1",
            subject="Hello",
            sender="alice@example.com",
            body_text="Test body",
            date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
        ).as_text
        assert "From: alice@example.com" in rendered
        assert "Subject: Hello" in rendered
        assert "Test body" in rendered

    def test_chat_message_as_text(self):
        """ChatMessage.as_text includes sender, channel and content."""
        rendered = ChatMessage(
            id="c1",
            content="Buy milk",
            sender="bob",
            channel="general",
            date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
        ).as_text
        assert "From: bob" in rendered
        assert "channel: general" in rendered
        assert "Buy milk" in rendered
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────────────────────────────────────────────────────────
# get_provider factory
# ───────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetProvider:
    """get_provider maps provider names to client classes and rejects unknown names."""

    def test_gmail_returns_gmail_client(self):
        """'gmail' yields a GmailClient."""
        from app.integrations.gmail import GmailClient

        assert isinstance(get_provider("gmail", _TOKEN_DICT), GmailClient)

    def test_outlook_returns_ms_graph_client(self):
        """'outlook' yields an MSGraphClient."""
        from app.integrations.ms_graph import MSGraphClient

        assert isinstance(get_provider("outlook", _MS_TOKEN_DICT), MSGraphClient)

    def test_teams_returns_ms_graph_client(self):
        """'teams' yields an MSGraphClient as well."""
        from app.integrations.ms_graph import MSGraphClient

        assert isinstance(get_provider("teams", _MS_TOKEN_DICT), MSGraphClient)

    def test_unknown_provider_raises_value_error(self):
        """An unrecognised provider name raises ValueError."""
        with pytest.raises(ValueError, match="Unknown cloud provider"):
            get_provider("slack", {})
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────────────────────────────────────────────────────────
# Gmail client — query builder
# ───────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildGmailQuery:
    """Exercise ``gmail._build_gmail_query`` with label, sender and date filters."""

    def setup_method(self):
        # Looked up before each test rather than at module import time.
        from app.integrations.gmail import _build_gmail_query

        self._fn = _build_gmail_query

    def test_empty_returns_empty_string(self):
        """No filter and no ``since`` produce an empty query."""
        assert self._fn(None, None) == ""

    def test_single_label(self):
        """A single label becomes a ``label:`` term."""
        query = self._fn({"labels": ["INBOX"]}, None)
        assert "label:INBOX" in query

    def test_multiple_labels_joined_with_or(self):
        """Several labels are OR-ed together."""
        query = self._fn({"labels": ["INBOX", "work"]}, None)
        assert "label:INBOX OR label:work" in query

    def test_senders(self):
        """A sender becomes a ``from:`` term."""
        query = self._fn({"senders": ["alice@example.com"]}, None)
        assert "from:alice@example.com" in query

    def test_date_range_from(self):
        """``date_range.from`` becomes an ``after:`` term in YYYY/MM/DD form."""
        query = self._fn({"date_range": {"from": "2025-01-15"}}, None)
        assert "after:2025/01/15" in query

    def test_date_range_to(self):
        """``date_range.to`` becomes a ``before:`` term in YYYY/MM/DD form."""
        query = self._fn({"date_range": {"to": "2025-03-01"}}, None)
        assert "before:2025/03/01" in query

    def test_since_overrides_earlier_date_range_from(self):
        """since=Feb is more recent than date_range.from=Jan, so after: should be Feb."""
        feb_first = datetime(2025, 2, 1, tzinfo=timezone.utc)
        query = self._fn({"date_range": {"from": "2025-01-01"}}, feb_first)
        assert "after:2025/02/01" in query
        assert "after:2025/01/01" not in query

    def test_date_range_from_overrides_earlier_since(self):
        """date_range.from=Feb is more recent than since=Jan, so after: should be Feb."""
        jan_first = datetime(2025, 1, 1, tzinfo=timezone.utc)
        query = self._fn({"date_range": {"from": "2025-02-01"}}, jan_first)
        assert "after:2025/02/01" in query

    def test_invalid_date_ignored(self):
        """An invalid date string in filter_config must not raise, just be skipped."""
        query = self._fn({"date_range": {"from": "not-a-date"}}, None)
        assert "after:" not in query
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────────────────────────────────────────────────────────
# Gmail client — body parsing
# ───────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseBody:
    """Exercise ``gmail._parse_body`` on plain, HTML, multipart and empty payloads."""

    def setup_method(self):
        # Looked up before each test rather than at module import time.
        from app.integrations.gmail import _parse_body

        self._fn = _parse_body

    def _encode(self, text: str) -> str:
        """Return *text* as URL-safe base64, the form used for ``body.data`` in these payloads."""
        import base64

        raw = text.encode()
        return base64.urlsafe_b64encode(raw).decode()

    def test_text_plain_extracted(self):
        """A text/plain payload is decoded as-is."""
        plain_payload = {
            "mimeType": "text/plain",
            "body": {"data": self._encode("Hello world")},
        }
        assert self._fn(plain_payload) == "Hello world"

    def test_text_html_stripped(self):
        """A text/html payload keeps its text and loses its tags."""
        html_payload = {
            "mimeType": "text/html",
            "body": {"data": self._encode("<p>Hello <b>world</b></p>")},
        }
        extracted = self._fn(html_payload)
        assert "Hello" in extracted
        assert "<p>" not in extracted

    def test_multipart_prefers_plain_over_html(self):
        """In a multipart payload the text/plain part wins even when it comes second."""
        html_part = {"mimeType": "text/html", "body": {"data": self._encode("<p>HTML text</p>")}}
        plain_part = {"mimeType": "text/plain", "body": {"data": self._encode("Plain text")}}
        multipart_payload = {
            "mimeType": "multipart/alternative",
            "body": {},
            "parts": [html_part, plain_part],
        }
        assert self._fn(multipart_payload) == "Plain text"

    def test_empty_payload_returns_empty_string(self):
        """An empty payload dict yields an empty string."""
        assert self._fn({}) == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────────────────────────────────────────────────────────
# GmailClient.fetch_messages
# ───────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_gmail_message(
|
||||||
|
msg_id: str = "msg001",
|
||||||
|
subject: str = "Test email",
|
||||||
|
sender: str = "alice@example.com",
|
||||||
|
body_text: str = "Hello world",
|
||||||
|
date: str = "Mon, 01 Jan 2025 10:00:00 +0000",
|
||||||
|
) -> dict:
|
||||||
|
"""Build a minimal Gmail API message response dict."""
|
||||||
|
import base64
|
||||||
|
body_data = base64.urlsafe_b64encode(body_text.encode()).decode()
|
||||||
|
return {
|
||||||
|
"id": msg_id,
|
||||||
|
"labelIds": ["INBOX"],
|
||||||
|
"payload": {
|
||||||
|
"mimeType": "text/plain",
|
||||||
|
"headers": [
|
||||||
|
{"name": "Subject", "value": subject},
|
||||||
|
{"name": "From", "value": sender},
|
||||||
|
{"name": "Date", "value": date},
|
||||||
|
],
|
||||||
|
"body": {"data": body_data},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestGmailClientFetchMessages:
    """GmailClient.fetch_messages tests with mocked Google API."""

    def _make_client(self) -> "GmailClient":
        """Return a GmailClient built from the shared test credentials."""
        from app.integrations.gmail import GmailClient

        return GmailClient(_TOKEN_DICT)

    @staticmethod
    async def _run_inline(fn, *args, **kwargs):
        """Stand-in for ``asyncio.to_thread``: call *fn* directly and return its result."""
        return fn(*args, **kwargs)

    @staticmethod
    def _mock_service():
        """Return ``(service, messages)`` where *messages* is ``service.users().messages()``."""
        service = MagicMock()
        return service, service.users.return_value.messages.return_value

    @pytest.mark.asyncio
    async def test_happy_path_returns_email_messages(self):
        """One listed message is fetched and converted into an EmailMessage."""
        client = self._make_client()
        service, messages_api = self._mock_service()
        messages_api.list.return_value.execute.return_value = {
            "messages": [{"id": "msg001"}]
        }
        messages_api.get.return_value.execute.return_value = _make_gmail_message()

        with patch("app.integrations.gmail.asyncio.to_thread") as to_thread_mock:
            # Simulate to_thread running the sync function and returning results.
            to_thread_mock.side_effect = self._run_inline

            with patch("googleapiclient.discovery.build", return_value=service), \
                 patch("google.auth.transport.requests.Request"), \
                 patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
                fetched = await client.fetch_messages()

        assert len(fetched) == 1
        first = fetched[0]
        assert first.subject == "Test email"
        assert first.sender == "alice@example.com"
        assert first.body_text == "Hello world"

    @pytest.mark.asyncio
    async def test_no_messages_returns_empty_list(self):
        """An empty ``messages`` list from the API yields an empty result."""
        client = self._make_client()
        service, messages_api = self._mock_service()
        messages_api.list.return_value.execute.return_value = {"messages": []}

        with patch("app.integrations.gmail.asyncio.to_thread") as to_thread_mock:
            to_thread_mock.side_effect = self._run_inline

            with patch("googleapiclient.discovery.build", return_value=service), \
                 patch("google.auth.transport.requests.Request"), \
                 patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
                fetched = await client.fetch_messages()

        assert fetched == []

    @pytest.mark.asyncio
    async def test_list_http_error_raises_runtime_error(self):
        """An HttpError from ``messages.list`` is surfaced as RuntimeError."""
        import googleapiclient.errors

        client = self._make_client()
        service, messages_api = self._mock_service()

        # Minimal response object carrying the status and reason for the HttpError.
        forbidden = MagicMock()
        forbidden.status = 403
        forbidden.reason = "Forbidden"
        messages_api.list.return_value.execute.side_effect = (
            googleapiclient.errors.HttpError(forbidden, b"Forbidden")
        )

        with patch("app.integrations.gmail.asyncio.to_thread") as to_thread_mock:
            to_thread_mock.side_effect = self._run_inline

            with patch("googleapiclient.discovery.build", return_value=service), \
                 patch("google.auth.transport.requests.Request"), \
                 patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
                with pytest.raises(RuntimeError, match="Gmail messages.list failed"):
                    await client.fetch_messages()

    def test_refreshed_credentials_none_when_unchanged(self):
        """A freshly built client reports no refreshed credentials."""
        client = self._make_client()
        # Token unchanged — should return None.
        assert client.refreshed_credentials is None

    def test_refreshed_credentials_returns_dict_when_token_changes(self):
        """Changing the access token makes ``refreshed_credentials`` return the new token."""
        client = self._make_client()
        # Simulate a token refresh by changing the access token on the credentials object.
        client._credentials.token = "new_access_token_xyz"
        updated = client.refreshed_credentials
        assert updated is not None
        assert updated["token"] == "new_access_token_xyz"
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────────────────────────────────────────────────────────
# MS Graph client — email filter builder
# ───────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildEmailFilter:
    """Exercise ``ms_graph._build_email_filter`` with sender and date filters."""

    def setup_method(self):
        # Looked up before each test rather than at module import time.
        from app.integrations.ms_graph import _build_email_filter

        self._fn = _build_email_filter

    def test_empty_returns_empty_string(self):
        """No filter and no ``since`` produce an empty filter string."""
        assert self._fn(None, None) == ""

    def test_single_sender(self):
        """A sender becomes an OData equality clause on the from-address."""
        odata = self._fn({"senders": ["alice@example.com"]}, None)
        assert "from/emailAddress/address eq 'alice@example.com'" in odata

    def test_multiple_senders_joined_with_or(self):
        """Several senders are combined with ``or``."""
        odata = self._fn({"senders": ["a@x.com", "b@x.com"]}, None)
        assert " or " in odata
        assert "a@x.com" in odata
        assert "b@x.com" in odata

    def test_since_adds_received_date_ge_clause(self):
        """``since`` becomes a ``receivedDateTime ge`` clause."""
        march_first = datetime(2025, 3, 1, tzinfo=timezone.utc)
        odata = self._fn(None, march_first)
        assert "receivedDateTime ge 2025-03-01T00:00:00Z" in odata

    def test_date_range_to_adds_received_date_le_clause(self):
        """``date_range.to`` becomes a ``receivedDateTime le`` clause."""
        odata = self._fn({"date_range": {"to": "2025-06-30"}}, None)
        assert "receivedDateTime le" in odata

    def test_since_overrides_earlier_date_range_from(self):
        """A ``since`` later than ``date_range.from`` replaces it in the filter."""
        feb_first = datetime(2025, 2, 1, tzinfo=timezone.utc)
        odata = self._fn({"date_range": {"from": "2025-01-01"}}, feb_first)
        assert "2025-02-01T00:00:00Z" in odata
        assert "2025-01-01" not in odata

    def test_invalid_date_ignored(self):
        """An unparseable ``date_range.from`` adds no date clause."""
        odata = self._fn({"date_range": {"from": "bad-date"}}, None)
        assert "receivedDateTime" not in odata
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────────────────────────────────────────────────────────
# MSGraphClient.fetch_emails
# ───────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_graph_email(
|
||||||
|
msg_id: str = "email001",
|
||||||
|
subject: str = "Meeting tomorrow",
|
||||||
|
sender_address: str = "boss@company.com",
|
||||||
|
body_content: str = "Please prepare the report.",
|
||||||
|
received: str = "2025-06-01T10:00:00Z",
|
||||||
|
) -> dict:
|
||||||
|
"""Build a minimal MS Graph message item dict."""
|
||||||
|
return {
|
||||||
|
"id": msg_id,
|
||||||
|
"subject": subject,
|
||||||
|
"from": {"emailAddress": {"address": sender_address}},
|
||||||
|
"receivedDateTime": received,
|
||||||
|
"body": {"contentType": "text", "content": body_content},
|
||||||
|
"bodyPreview": body_content[:100],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_graph_teams_message(
    msg_id: str = "teams001",
    content: str = "Stand-up at 9am",
    sender: str = "alice",
    channel_id: str = "chan001",
    created: str = "2025-06-01T08:00:00Z",
) -> dict:
    """Build a minimal MS Graph Teams message item dict.

    The result carries ``id``, a text ``body``, the sender's display name
    under ``from.user``, the channel id under ``channelIdentity`` and the
    ``createdDateTime`` string, in that key order.
    """
    message: dict = {"id": msg_id}
    message["body"] = {"contentType": "text", "content": content}
    message["from"] = {"user": {"displayName": sender}}
    message["channelIdentity"] = {"channelId": channel_id}
    message["createdDateTime"] = created
    return message
|
||||||
|
|
||||||
|
|
||||||
|
class TestMSGraphClientFetchEmails:
    """Exercise MSGraphClient.fetch_emails with httpx.AsyncClient patched out."""

    def _make_client(self) -> "MSGraphClient":
        """Build an MSGraphClient from the shared ``_MS_TOKEN_DICT`` fixture."""
        from app.integrations.ms_graph import MSGraphClient

        graph_client = MSGraphClient(_MS_TOKEN_DICT)
        return graph_client

    @staticmethod
    def _ok_response(items):
        """Return a mocked 200 response whose JSON payload is ``{"value": items}``."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = {"value": items}
        response.raise_for_status = MagicMock()
        return response

    @staticmethod
    def _install_session(async_client_cls, session) -> None:
        """Make ``async with`` on the patched AsyncClient instance yield *session*."""
        context = async_client_cls.return_value
        context.__aenter__ = AsyncMock(return_value=session)
        context.__aexit__ = AsyncMock(return_value=False)

    @pytest.mark.asyncio
    async def test_happy_path_returns_email_messages(self):
        """One Graph item in a 200 response becomes one parsed email."""
        client = self._make_client()
        page = self._ok_response([_make_graph_email()])

        with patch("app.integrations.ms_graph.httpx.AsyncClient") as async_client_cls:
            session = AsyncMock()
            session.get = AsyncMock(return_value=page)
            self._install_session(async_client_cls, session)

            results = await client.fetch_emails()

        assert len(results) == 1
        first = results[0]
        assert first.subject == "Meeting tomorrow"
        assert first.sender == "boss@company.com"
        assert first.body_text == "Please prepare the report."

    @pytest.mark.asyncio
    async def test_pagination_stops_at_max_emails(self):
        """A first page without ``@odata.nextLink`` yields only that page's items."""
        client = self._make_client()
        batch = [_make_graph_email(msg_id=str(index)) for index in range(3)]
        # The payload deliberately carries no "@odata.nextLink" key.
        page = self._ok_response(batch)

        with patch("app.integrations.ms_graph.httpx.AsyncClient") as async_client_cls:
            session = AsyncMock()
            session.get = AsyncMock(return_value=page)
            self._install_session(async_client_cls, session)

            results = await client.fetch_emails()

        assert len(results) == 3

    @pytest.mark.asyncio
    async def test_401_triggers_token_refresh_and_retries(self):
        """A 401 on the first GET leads to one token refresh and a successful retry."""
        from app.integrations.ms_graph import MSGraphClient

        client = MSGraphClient(_MS_TOKEN_DICT)

        unauthorized = MagicMock()
        unauthorized.status_code = 401

        success = self._ok_response([_make_graph_email()])

        # Responses to hand out before falling back to the success response.
        queued = iter([unauthorized])

        async def fake_get(url, params=None, headers=None):
            # First call drains the queued 401; every later call gets the 200.
            return next(queued, success)

        with patch("app.integrations.ms_graph.httpx.AsyncClient") as async_client_cls:
            with patch.object(
                client, "_refresh_access_token", new_callable=AsyncMock
            ) as refresh_mock:
                session = AsyncMock()
                session.get = fake_get
                self._install_session(async_client_cls, session)

                results = await client.fetch_emails()

        refresh_mock.assert_called_once()
        assert len(results) == 1

    def test_refreshed_credentials_none_when_token_unchanged(self):
        """A freshly built client exposes no refreshed credentials."""
        untouched = self._make_client()
        assert untouched.refreshed_credentials is None

    def test_refreshed_credentials_returns_dict_when_token_changes(self):
        """Overwriting ``_access_token`` makes refreshed_credentials expose the new token."""
        client = self._make_client()
        client._access_token = "new_token_abc"

        assert client.refreshed_credentials is not None
        assert client.refreshed_credentials["access_token"] == "new_token_abc"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMSGraphClientFetchMessages:
    """Exercise MSGraphClient.fetch_messages (Teams) with httpx.AsyncClient patched out."""

    def _make_client(self) -> "MSGraphClient":
        """Build an MSGraphClient from the shared ``_MS_TOKEN_DICT`` fixture."""
        from app.integrations.ms_graph import MSGraphClient

        graph_client = MSGraphClient(_MS_TOKEN_DICT)
        return graph_client

    @staticmethod
    def _ok_response(items):
        """Return a mocked 200 response whose JSON payload is ``{"value": items}``."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = {"value": items}
        response.raise_for_status = MagicMock()
        return response

    @staticmethod
    def _install_session(async_client_cls, session) -> None:
        """Make ``async with`` on the patched AsyncClient instance yield *session*."""
        context = async_client_cls.return_value
        context.__aenter__ = AsyncMock(return_value=session)
        context.__aexit__ = AsyncMock(return_value=False)

    @pytest.mark.asyncio
    async def test_happy_path_returns_chat_messages(self):
        """One Teams item in a 200 response becomes one parsed chat message."""
        client = self._make_client()
        page = self._ok_response([_make_graph_teams_message()])

        with patch("app.integrations.ms_graph.httpx.AsyncClient") as async_client_cls:
            session = AsyncMock()
            session.get = AsyncMock(return_value=page)
            self._install_session(async_client_cls, session)

            results = await client.fetch_messages()

        assert len(results) == 1
        first = results[0]
        assert first.content == "Stand-up at 9am"
        assert first.sender == "alice"

    @pytest.mark.asyncio
    async def test_403_degrades_gracefully(self):
        """A 403 HTTPStatusError raised by the GET yields an empty list, not an exception."""
        import httpx as _httpx

        client = self._make_client()

        forbidden = MagicMock()
        forbidden.status_code = 403
        forbidden_error = _httpx.HTTPStatusError(
            "Forbidden", request=MagicMock(), response=forbidden
        )

        with patch("app.integrations.ms_graph.httpx.AsyncClient") as async_client_cls:
            session = AsyncMock()
            # Every GET raises the prepared 403 error.
            session.get = AsyncMock(side_effect=forbidden_error)
            self._install_session(async_client_cls, session)

            results = await client.fetch_messages()

        assert results == []

    @pytest.mark.asyncio
    async def test_channel_filter_applied(self):
        """Only messages whose channel is listed in ``channels`` are returned."""
        client = self._make_client()
        wanted = _make_graph_teams_message(channel_id="dev-channel", content="Deploy today")
        unwanted = _make_graph_teams_message(msg_id="t2", channel_id="random", content="Lunch?")
        page = self._ok_response([wanted, unwanted])

        with patch("app.integrations.ms_graph.httpx.AsyncClient") as async_client_cls:
            session = AsyncMock()
            session.get = AsyncMock(return_value=page)
            self._install_session(async_client_cls, session)

            channel_filter = {"channels": ["dev-channel"]}
            results = await client.fetch_messages(filter_config=channel_filter)

        assert len(results) == 1
        assert results[0].content == "Deploy today"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMSGraphClientRefreshToken:
    """Exercise MSGraphClient._refresh_access_token against a mocked MSAL application."""

    @pytest.mark.asyncio
    async def test_msal_error_raises_runtime_error(self):
        """An MSAL result carrying an ``error`` key surfaces as RuntimeError."""
        from app.integrations.ms_graph import MSGraphClient

        token_dict = dict(_MS_TOKEN_DICT, refresh_token="rt_test")
        client = MSGraphClient(token_dict)

        msal_app = MagicMock()
        msal_app.acquire_token_by_refresh_token.return_value = {
            "error": "invalid_grant",
            "error_description": "Refresh token expired",
        }

        with patch("msal.ConfidentialClientApplication", return_value=msal_app):
            with patch("app.integrations.ms_graph.settings") as settings_mock:
                settings_mock.configure_mock(
                    MS_CLIENT_ID="client_id",
                    MS_CLIENT_SECRET="secret",
                    MS_TENANT_ID="common",
                )
                with pytest.raises(RuntimeError, match="MS Graph token refresh failed"):
                    await client._refresh_access_token()

    @pytest.mark.asyncio
    async def test_successful_refresh_updates_access_token(self):
        """A successful MSAL result replaces both the access and refresh tokens."""
        from app.integrations.ms_graph import MSGraphClient

        token_dict = dict(_MS_TOKEN_DICT, refresh_token="rt_old")
        client = MSGraphClient(token_dict)

        msal_app = MagicMock()
        msal_app.acquire_token_by_refresh_token.return_value = {
            "access_token": "new_access_token",
            "refresh_token": "new_refresh_token",
        }

        with patch("msal.ConfidentialClientApplication", return_value=msal_app):
            with patch("app.integrations.ms_graph.settings") as settings_mock:
                settings_mock.configure_mock(
                    MS_CLIENT_ID="client_id",
                    MS_CLIENT_SECRET="secret",
                    MS_TENANT_ID="common",
                )
                await client._refresh_access_token()

        assert client._access_token == "new_access_token"
        assert client._refresh_token == "new_refresh_token"
|
||||||
Reference in New Issue
Block a user