diff --git a/AI_REFACTOR_PLAN.md b/AI_REFACTOR_PLAN.md index 9781fe2..66f09f4 100644 --- a/AI_REFACTOR_PLAN.md +++ b/AI_REFACTOR_PLAN.md @@ -437,21 +437,21 @@ Cloud Agent: - **Outcome:** Users configure AI prompts through guided conversation. Journey can refine an existing config when `agent_id` is provided. ✅ ### Step 3.6 — Cloud provider integrations -- [ ] Create `app/integrations/gmail.py`: +- [x] Create `app/integrations/gmail.py`: - `GmailClient`: - `__init__(oauth_token)` — initializes Google API client - `async fetch_messages(filter_config, since: datetime) -> list[EmailMessage]` - `EmailMessage`: `{ id, subject, sender, body_text, date, labels }` - Handles token refresh via Google OAuth2 refresh flow - Respects `filter_config.labels`, `filter_config.date_range`, `filter_config.senders` -- [ ] Create `app/integrations/ms_graph.py`: +- [x] Create `app/integrations/ms_graph.py`: - `MSGraphClient`: - `__init__(oauth_token)` — initializes MS Graph client - `async fetch_emails(filter_config, since: datetime) -> list[EmailMessage]` (Outlook) - `async fetch_messages(filter_config, since: datetime) -> list[ChatMessage]` (Teams) - `ChatMessage`: `{ id, content, sender, channel, date }` - Handles token refresh via MSAL -- [ ] Create `app/integrations/__init__.py` — factory: `get_provider(provider_name) -> GmailClient | MSGraphClient` +- [x] Create `app/integrations/__init__.py` — factory: `get_provider(provider_name) -> GmailClient | MSGraphClient` - **Dependencies:** `google-api-python-client`, `google-auth-oauthlib`, `msgraph-sdk`, `msal` - **Files:** `app/integrations/gmail.py`, `app/integrations/ms_graph.py`, `app/integrations/__init__.py` - **Outcome:** Backend can fetch emails/messages from Gmail, Outlook, and Teams. 
diff --git a/app/config/settings.py b/app/config/settings.py index b5e181b..886d2e5 100644 --- a/app/config/settings.py +++ b/app/config/settings.py @@ -29,6 +29,25 @@ class Settings(BaseSettings): LLM_MODEL: str = "gpt-4o" LLM_ROUTER_MODEL: str = "gpt-4o-mini" + LLM_EMBED_MODEL: str = "text-embedding-3-small" + + # GitHub Copilot OAuth token storage directory. + # Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot). + # In Docker, set this to a path backed by a named volume so tokens survive restarts. + GITHUB_COPILOT_TOKEN_DIR: str = "" + + # OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows. + GMAIL_CLIENT_ID: str = "" + GMAIL_CLIENT_SECRET: str = "" + MS_CLIENT_ID: str = "" + MS_CLIENT_SECRET: str = "" + # MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts). + MS_TENANT_ID: str = "common" + + # Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth + # tokens stored in cloud_agent_configs.oauth_token_encrypted. + # Generate with: from cryptography.fernet import Fernet; Fernet.generate_key() + OAUTH_ENCRYPTION_KEY: str = "" CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"] diff --git a/app/core/agent_runner.py b/app/core/agent_runner.py index d6e9cd5..b8b8242 100644 --- a/app/core/agent_runner.py +++ b/app/core/agent_runner.py @@ -29,7 +29,7 @@ import asyncio import json import logging import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import Any from croniter import croniter @@ -383,7 +383,10 @@ async def run_local_agent( ) -# ── Cloud agent runner (stub) ─────────────────────────────────────────────── +# ── Cloud agent runner ───────────────────────────────────────────────────── + +# Default lookback window when an agent has never run before. 
+_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7 async def run_cloud_agent( @@ -392,26 +395,199 @@ async def run_cloud_agent( run_log: AgentRunLog, device_mgr: DeviceConnectionManager, ) -> None: - """Execute a cloud connector agent run. + """Execute a cloud connector agent run end-to-end. - .. note:: - This is a **stub** — provider integrations (Gmail, Teams, Outlook) - are implemented in Step 3.6. The run is immediately marked as an - error with an informative message. + Steps: + + 1. Verify the user's device is online — results are pushed to Electron + via WS tool-call frames. If no device is connected, abort. + 2. Decrypt the stored OAuth token from ``config.oauth_token_encrypted``. + 3. Instantiate the provider client (Gmail or MS Graph). + 4. Fetch messages/emails since ``config.last_run_at`` (or 7 days ago for + the first run) applying ``config.filter_config`` filters. + 5. For each message/email call ``_extract_items_from_content`` with + ``config.prompt_template`` to get structured ``{table, data}`` items. + 6. Push each item to Electron as an ``insert`` tool-call. + 7. If the provider refreshed its access token, re-encrypt and write it + back to ``config.oauth_token_encrypted``. + 8. Persist the run outcome via ``_finalize_run``. """ + run_id = run_log.id + + # ── 1. Device online check ───────────────────────────────────────── + if not device_mgr.is_online(user_id): + logger.info( + "agent_runner: skip cloud run=%s — no device online for user=%s", + run_id, + user_id, + ) + await _finalize_run( + run_log, + status="error", + errors=["No connected device — cloud agent results cannot be delivered"], + ) + return + + # ── 2. 
Decrypt OAuth token ───────────────────────────────────────── + from app.integrations import decrypt_token, encrypt_token, get_provider + + if not config.oauth_token_encrypted: + await _finalize_run( + run_log, + status="error", + errors=[f"No OAuth token stored for cloud agent '{config.name}'"], + ) + return + + try: + credentials_info = decrypt_token(config.oauth_token_encrypted) + except ValueError as exc: + logger.error("agent_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc) + await _finalize_run( + run_log, + status="error", + errors=[f"Failed to decrypt OAuth token: {exc}"], + ) + return + + # ── 3. Instantiate provider client ──────────────────────────────── + try: + provider = get_provider(config.provider, credentials_info) + except ValueError as exc: + await _finalize_run( + run_log, + status="error", + errors=[str(exc)], + ) + return + + # ── 4. Fetch messages ───────────────────────────────────────────── + since: datetime | None = config.last_run_at + if since is None: + since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS) + if since.tzinfo is None: + since = since.replace(tzinfo=timezone.utc) + + errors: list[str] = [] + items_processed = 0 + items_created = 0 + + try: + if config.provider == "gmail": + raw_messages = await provider.fetch_messages( # type: ignore[union-attr] + filter_config=config.filter_config, + since=since, + ) + elif config.provider == "outlook": + raw_messages = await provider.fetch_emails( # type: ignore[union-attr] + filter_config=config.filter_config, + since=since, + ) + elif config.provider == "teams": + raw_messages = await provider.fetch_messages( # type: ignore[union-attr] + filter_config=config.filter_config, + since=since, + ) + else: + raw_messages = [] + except RuntimeError as exc: + logger.error( + "agent_runner: provider fetch failed for cloud agent %s: %s", + config.id, + exc, + ) + await _finalize_run( + run_log, + status="error", + errors=[f"Provider fetch 
failed: {exc}"], + update_config_last_run=True, + config_id=config.id, + config_type="cloud", + ) + return + logger.info( - "agent_runner: cloud agent %s (provider=%s) for user=%s — pending Step 3.6", + "agent_runner: cloud agent %s fetched %d item(s) from %s for user=%s", config.id, + len(raw_messages), config.provider, user_id, ) + + # ── 5–6. Extract + insert ───────────────────────────────────────── + for msg in raw_messages: + content_text = msg.as_text + if not content_text: + continue + items_processed += 1 + try: + extracted = await _extract_items_from_content( + config.prompt_template, content_text, config.data_types + ) + except Exception as exc: + errors.append(f"LLM extraction error for message {msg.id!r}: {exc}") + continue + + for item in extracted: + try: + result = await _send_insert_to_client( + user_id, item["table"], item["data"], device_mgr + ) + if result.get("error"): + errors.append( + f"Insert failed ({item['table']}, msg={msg.id!r}): {result['error']}" + ) + else: + items_created += 1 + except asyncio.TimeoutError: + errors.append( + f"Timed out awaiting insert ack ({item['table']}, msg={msg.id!r})" + ) + except RuntimeError as exc: + errors.append(f"Insert error ({item['table']}, msg={msg.id!r}): {exc}") + + # ── 7. Persist refreshed token (if any) ─────────────────────────── + refreshed = getattr(provider, "refreshed_credentials", None) + if refreshed: + try: + new_encrypted = encrypt_token(refreshed) + async with async_session() as db: + cfg_result = await db.execute( + select(CloudAgentConfig).where(CloudAgentConfig.id == config.id) + ) + cfg_row = cfg_result.scalar_one_or_none() + if cfg_row: + cfg_row.oauth_token_encrypted = new_encrypted + await db.commit() + logger.debug("agent_runner: refreshed OAuth token persisted for agent %s", config.id) + except Exception as exc: + logger.warning("agent_runner: failed to persist refreshed token for agent %s: %s", config.id, exc) + + # ── 8. 
Finalise ──────────────────────────────────────────────────── + if errors and items_created == 0: + final_status = "error" + elif errors: + final_status = "partial" + else: + final_status = "success" + await _finalize_run( run_log, - status="error", - errors=[ - f"Cloud provider integrations for '{config.provider}' are not yet " - "implemented. This feature arrives in Step 3.6." - ], + status=final_status, + items_processed=items_processed, + items_created=items_created, + errors=errors, + update_config_last_run=True, + config_id=config.id, + config_type="cloud", + ) + logger.info( + "agent_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d", + run_id, + final_status, + items_processed, + items_created, + len(errors), ) @@ -519,13 +695,21 @@ async def _finalize_run( managed.errors = errors or [] managed.completed_at = now - if update_config_last_run and config_id and config_type == "local": - cfg_result = await db.execute( - select(LocalAgentConfig).where(LocalAgentConfig.id == config_id) - ) - cfg = cfg_result.scalar_one_or_none() - if cfg: - cfg.last_run_at = now + if update_config_last_run and config_id: + if config_type == "local": + cfg_result = await db.execute( + select(LocalAgentConfig).where(LocalAgentConfig.id == config_id) + ) + cfg = cfg_result.scalar_one_or_none() + if cfg: + cfg.last_run_at = now + elif config_type == "cloud": + cfg_result = await db.execute( + select(CloudAgentConfig).where(CloudAgentConfig.id == config_id) + ) + cfg = cfg_result.scalar_one_or_none() + if cfg: + cfg.last_run_at = now await db.commit() except Exception as exc: diff --git a/app/core/llm.py b/app/core/llm.py index 0a717a2..80e14a5 100644 --- a/app/core/llm.py +++ b/app/core/llm.py @@ -17,7 +17,10 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env`` from __future__ import annotations +import os + from openai import AsyncOpenAI +import litellm from langchain_openai import ChatOpenAI from litellm import 
async def embed(text: str) -> list[float]:
    """Return an embedding vector for *text*.

    The model is taken from ``settings.LLM_EMBED_MODEL`` so the provider can
    be switched in ``.env`` (e.g. ``github_copilot/text-embedding-3-small``)
    without code changes.

    * Provider-prefixed names (anything containing ``/``) are routed through
      LiteLLM so the provider's own auth mechanism is applied.
    * Plain model names use the raw ``AsyncOpenAI`` client with
      ``settings.OPENAI_API_KEY``.

    Returns:
        The embedding of *text* as a list of floats. The dimensionality
        depends on the configured model.
    """
    model = settings.LLM_EMBED_MODEL

    # "github_copilot/..." also contains "/", so one test covers every
    # provider-prefixed model name.
    if "/" in model:
        response = await litellm.aembedding(model=model, input=[text])
        return response.data[0]["embedding"]

    # The client owns an HTTP connection pool; close it deterministically
    # instead of leaking one pool per call until garbage collection.
    async with AsyncOpenAI(api_key=settings.OPENAI_API_KEY) as client:
        response = await client.embeddings.create(model=model, input=text)
    return response.data[0].embedding
@dataclass
class EmailMessage:
    """A single email message fetched from Gmail or Outlook."""

    id: str
    subject: str
    sender: str
    body_text: str
    date: datetime
    labels: list[str] = field(default_factory=list)

    @property
    def as_text(self) -> str:
        """Return a human-readable text representation for LLM extraction."""
        date_str = self.date.strftime("%Y-%m-%d %H:%M")
        labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
        return (
            f"From: {self.sender}\n"
            f"Date: {date_str}{labels_str}\n"
            f"Subject: {self.subject}\n\n"
            f"{self.body_text}"
        )


@dataclass
class ChatMessage:
    """A single Teams chat or channel message fetched from MS Graph."""

    id: str
    content: str
    sender: str
    channel: str | None
    date: datetime

    @property
    def as_text(self) -> str:
        """Return a human-readable text representation for LLM extraction."""
        date_str = self.date.strftime("%Y-%m-%d %H:%M")
        channel_str = f" [channel: {self.channel}]" if self.channel else ""
        return (
            f"From: {self.sender}\n"
            f"Date: {date_str}{channel_str}\n\n"
            f"{self.content}"
        )


# ── Fernet helpers ────────────────────────────────────────────────────────


def _get_fernet() -> Fernet:
    """Return a ``Fernet`` instance built from ``settings.OAUTH_ENCRYPTION_KEY``.

    Raises:
        RuntimeError: ``OAUTH_ENCRYPTION_KEY`` is not set.
    """
    key = settings.OAUTH_ENCRYPTION_KEY
    if not key:
        raise RuntimeError(
            "OAUTH_ENCRYPTION_KEY is not set. "
            "Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
        )
    return Fernet(key.encode() if isinstance(key, str) else key)


def encrypt_token(token_info: dict) -> str:
    """Fernet-encrypt an OAuth credential dict and return the token as text.

    The whole credential dict is serialised as JSON before encryption, so
    whatever shape the provider uses (Google or MSAL) round-trips unchanged
    through ``decrypt_token``.

    Raises:
        RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
        ValueError: ``token_info`` is not a non-empty dict.
    """
    if not isinstance(token_info, dict) or not token_info:
        raise ValueError("token_info must be a non-empty dict")
    plaintext = json.dumps(token_info).encode("utf-8")
    return _get_fernet().encrypt(plaintext).decode("utf-8")


def decrypt_token(encrypted: str) -> dict:
    """Decrypt a Fernet-encrypted token string and return the credential dict.

    Raises:
        RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
        ValueError: ``encrypted`` is empty or not text, cannot be decrypted
            with the configured key, or does not decode to a JSON object.
    """
    # Reject unusable input up front so callers always get the documented
    # ValueError rather than an AttributeError from ``None.encode``.
    if not isinstance(encrypted, (str, bytes)) or not encrypted:
        raise ValueError("Failed to decrypt OAuth token: no ciphertext supplied")
    ciphertext = encrypted.encode("utf-8") if isinstance(encrypted, str) else encrypted

    fernet = _get_fernet()
    try:
        payload = json.loads(fernet.decrypt(ciphertext))
    except (InvalidToken, json.JSONDecodeError, UnicodeDecodeError) as exc:
        raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc

    # The provider clients call ``.get`` on the result; a JSON list or
    # scalar would only fail later and further from the cause.
    if not isinstance(payload, dict):
        raise ValueError("Failed to decrypt OAuth token: payload is not a JSON object")
    return payload


# ── Provider factory ──────────────────────────────────────────────────────


def get_provider(
    provider: str,
    credentials_info: dict,
) -> "GmailClient | MSGraphClient":
    """Return the correct provider client for *provider*.

    Parameters
    ----------
    provider:
        One of ``"gmail"``, ``"outlook"``, ``"teams"``.
    credentials_info:
        Decrypted OAuth credential dict (Google or Microsoft shape).

    Raises:
        ValueError: Unknown provider name.
    """
    # Provider modules are imported lazily so their SDKs are only loaded
    # for the provider actually requested.
    if provider == "gmail":
        from app.integrations.gmail import GmailClient
        return GmailClient(credentials_info)
    if provider in {"outlook", "teams"}:
        from app.integrations.ms_graph import MSGraphClient
        return MSGraphClient(credentials_info)
    raise ValueError(
        f"Unknown cloud provider {provider!r}. "
        "Supported: 'gmail', 'outlook', 'teams'."
    )
# Upper bound on messages retrieved per run (guards against runaway quota usage).
_MAX_MESSAGES = 200


def _to_utc(value: datetime) -> datetime:
    """Return *value* as an aware UTC datetime; naive values are taken as UTC."""
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)


def _parse_iso_utc(raw: str) -> datetime:
    """Parse an ISO-8601 string (``Z`` suffix allowed) into an aware UTC datetime."""
    return _to_utc(datetime.fromisoformat(raw.replace("Z", "+00:00")))


def _or_group(operator: str, values: list[str]) -> str:
    """Build ``op:v`` for one value or ``(op:a OR op:b)`` for several."""
    terms = [f"{operator}:{value}" for value in values]
    if len(terms) == 1:
        return terms[0]
    return "(" + " OR ".join(terms) + ")"


def _build_gmail_query(
    filter_config: dict[str, Any] | None,
    since: datetime | None,
) -> str:
    """Build a Gmail search query string from *filter_config* and *since*.

    Supported ``filter_config`` keys:
        labels (list[str]):  Gmail label names, e.g. ``["INBOX", "work"]``
        senders (list[str]): Sender addresses or domains to include
        date_range (dict):   ``{from: "", to: ""}`` ISO-8601 strings

    Multiple labels are OR-ed together, and so are multiple senders (a
    message has exactly one sender, so AND-ing them could never match).

    The lower bound is the **later** of *since* and ``date_range.from``.
    Bounds are emitted as Unix timestamps (``after:<seconds>``) because Gmail
    interprets ``YYYY/MM/DD`` dates as midnight PST, which would re-fetch or
    miss messages around the last-run boundary. Naive datetimes are treated
    as UTC.

    Returns:
        The query string; empty when no filter applies.
    """
    parts: list[str] = []
    cfg = filter_config or {}

    # ``or`` also covers keys that are present but null in stored JSON.
    labels: list[str] = cfg.get("labels") or []
    if labels:
        parts.append(_or_group("label", labels))

    senders: list[str] = cfg.get("senders") or []
    if senders:
        parts.append(_or_group("from", senders))

    date_range: dict = cfg.get("date_range") or {}
    from_str: str | None = date_range.get("from")
    to_str: str | None = date_range.get("to")

    # Effective lower bound: the most recent of date_range.from and since.
    effective_since: datetime | None = _to_utc(since) if since is not None else None
    if from_str:
        try:
            cfg_since = _parse_iso_utc(from_str)
        except (AttributeError, TypeError, ValueError):
            logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
        else:
            if effective_since is None or cfg_since > effective_since:
                effective_since = cfg_since

    if effective_since is not None:
        parts.append(f"after:{int(effective_since.timestamp())}")

    if to_str:
        try:
            to_dt = _parse_iso_utc(to_str)
        except (AttributeError, TypeError, ValueError):
            logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
        else:
            parts.append(f"before:{int(to_dt.timestamp())}")

    return " ".join(parts)


def _strip_html(raw_html: str) -> str:
    """Remove HTML tags, decode entities and collapse whitespace."""
    no_tags = re.sub(r"<[^>]+>", " ", raw_html)
    decoded = html.unescape(no_tags)
    return re.sub(r"\s+", " ", decoded).strip()


def _decode_body_data(body: dict[str, Any] | None) -> str:
    """Decode the base64url ``data`` field of a Gmail body dict to text."""
    data = (body or {}).get("data", "")
    if not data:
        return ""
    # Gmail may omit base64 padding; restore exactly what is missing.
    padded = data + "=" * (-len(data) % 4)
    return base64.urlsafe_b64decode(padded).decode("utf-8", errors="replace")


def _parse_body(payload: dict[str, Any]) -> str:
    """Recursively extract the plain-text body from a Gmail message payload.

    Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
    An empty ``text/plain`` part does not stop the search, so a message
    whose plain part carries no data still yields its HTML body.
    Returns an empty string if no body can be extracted.
    """
    mime_type: str = payload.get("mimeType", "")

    if mime_type == "text/plain":
        return _decode_body_data(payload.get("body"))

    if mime_type == "text/html":
        raw = _decode_body_data(payload.get("body"))
        return _strip_html(raw) if raw else ""

    # Multipart — walk the parts in order, remembering the first HTML body
    # in case no plain-text part turns up.
    html_fallback = ""
    for part in payload.get("parts") or []:
        part_mime = part.get("mimeType", "")
        if part_mime == "text/plain":
            text = _parse_body(part)
            if text:
                return text
        elif part_mime == "text/html":
            if not html_fallback:
                html_fallback = _parse_body(part)
        elif part_mime.startswith("multipart/"):
            nested = _parse_body(part)
            if nested:
                return nested
    return html_fallback
def _parse_date(raw: str) -> datetime:
    """Parse an RFC 2822 email date header into a UTC ``datetime``.

    Falls back to the current UTC time when the header cannot be parsed, so
    a malformed ``Date`` header never causes the message to be dropped.
    """
    # Import the submodule explicitly: a bare ``import email`` is not
    # guaranteed to load ``email.utils``, and the resulting AttributeError
    # would previously have been swallowed and turned into "now".
    from email.utils import parsedate_to_datetime

    try:
        parsed = parsedate_to_datetime(raw)
    except (TypeError, ValueError):
        return datetime.now(timezone.utc)
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)


class GmailClient:
    """Fetch email messages from a Gmail account via the Gmail REST API.

    Parameters
    ----------
    credentials_info:
        Decrypted OAuth2 credential dict.  Must contain at minimum
        ``token`` (access token) or ``refresh_token`` + ``token_uri`` +
        ``client_id`` + ``client_secret``.
    """

    def __init__(self, credentials_info: dict[str, Any]) -> None:
        from google.oauth2.credentials import Credentials

        self._credentials_info = credentials_info
        self._credentials = Credentials(
            token=credentials_info.get("token"),
            refresh_token=credentials_info.get("refresh_token"),
            token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
            client_id=credentials_info.get("client_id"),
            client_secret=credentials_info.get("client_secret"),
            scopes=credentials_info.get("scopes"),
            expiry=self._parse_expiry(credentials_info.get("expiry")),
        )

    @staticmethod
    def _parse_expiry(expiry_str: str | None) -> datetime | None:
        """Parse a stored ISO-8601 expiry into a **naive UTC** datetime.

        google-auth compares ``Credentials.expiry`` against a naive UTC
        "now"; handing it a timezone-aware value makes ``.expired`` raise
        ``TypeError``. Offsets are therefore converted to UTC and the
        tzinfo dropped. Returns ``None`` when the value is missing or
        unparseable (expiry is then treated as unknown).
        """
        if not expiry_str:
            return None
        try:
            parsed = datetime.fromisoformat(expiry_str.replace("Z", "+00:00"))
        except (AttributeError, TypeError, ValueError):
            logger.warning("gmail: invalid credential expiry %r — treating as unknown", expiry_str)
            return None
        if parsed.tzinfo is not None:
            parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
        return parsed

    # ── Public API ─────────────────────────────────────────────────────────

    async def fetch_messages(
        self,
        filter_config: dict[str, Any] | None = None,
        since: datetime | None = None,
    ) -> list[EmailMessage]:
        """Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.

        Runs the synchronous Google API calls inside ``asyncio.to_thread()``
        to avoid blocking the async event loop.

        Token refresh is performed automatically when the access token is
        missing or expired.  After the call, ``self.refreshed_credentials``
        may be consulted to detect whether new credentials should be
        persisted.

        Raises:
            RuntimeError: Token refresh or the message listing failed.
        """
        query = _build_gmail_query(filter_config, since)
        logger.debug("gmail: executing search query %r", query)
        return await asyncio.to_thread(self._fetch_sync, query)

    @property
    def refreshed_credentials(self) -> dict[str, Any] | None:
        """Return updated credential dict if the access token was refreshed.

        If the credentials were refreshed during ``fetch_messages()``, returns
        a new dict that should be re-encrypted and written back to the DB.
        Returns ``None`` if no refresh occurred.
        """
        creds = self._credentials
        if not creds.valid and creds.expired:
            return None
        # Unchanged token means nothing new to persist.
        if creds.token == self._credentials_info.get("token"):
            return None

        result: dict[str, Any] = {
            "token": creds.token,
            "refresh_token": creds.refresh_token,
            "token_uri": creds.token_uri,
            "client_id": creds.client_id,
            "client_secret": creds.client_secret,
            "scopes": list(creds.scopes or []),
        }
        expiry = creds.expiry
        if expiry:
            if expiry.tzinfo is not None:
                expiry = expiry.astimezone(timezone.utc).replace(tzinfo=None)
            # Stored as UTC with an explicit "Z", matching the documented
            # credential shape; ``_parse_expiry`` reads both forms.
            result["expiry"] = expiry.isoformat() + "Z"
        return result

    # ── Internal sync worker ───────────────────────────────────────────────

    def _fetch_sync(self, query: str) -> list[EmailMessage]:
        """Synchronous worker — called inside ``asyncio.to_thread()``.

        Lists matching message IDs (paged, capped at ``_MAX_MESSAGES``), then
        fetches each message in full. A message that fails to download or
        parse is logged and skipped; a failure to list is fatal.

        Raises:
            RuntimeError: Token refresh or ``messages.list`` failed.
        """
        import google.auth.exceptions
        import googleapiclient.discovery
        import googleapiclient.errors
        from google.auth.transport.requests import Request

        # Refresh up front when there is no usable access token (missing or
        # expired) and a refresh token is available.
        if not self._credentials.valid and self._credentials.refresh_token:
            try:
                self._credentials.refresh(Request())
            except Exception as exc:
                raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc

        service = googleapiclient.discovery.build(
            "gmail", "v1", credentials=self._credentials, cache_discovery=False
        )
        user_api = service.users()  # type: ignore[attr-defined]

        # ── List matching message IDs ──────────────────────────────────────
        ids: list[str] = []
        page_token: str | None = None
        while len(ids) < _MAX_MESSAGES:
            kwargs: dict[str, Any] = {
                "userId": "me",
                "maxResults": min(100, _MAX_MESSAGES - len(ids)),
            }
            if query:
                kwargs["q"] = query
            if page_token:
                kwargs["pageToken"] = page_token

            try:
                resp = user_api.messages().list(**kwargs).execute()
            except (
                googleapiclient.errors.Error,
                google.auth.exceptions.GoogleAuthError,
                OSError,
            ) as exc:
                # Surface auth/transport failures as RuntimeError too, so
                # callers have a single failure type to handle.
                raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc

            ids.extend(msg["id"] for msg in resp.get("messages", []))

            page_token = resp.get("nextPageToken")
            if not page_token:
                break

        if not ids:
            logger.debug("gmail: no messages matched query %r", query)
            return []

        logger.info("gmail: fetching %d message(s)", len(ids))

        # ── Fetch individual message details ──────────────────────────────
        messages: list[EmailMessage] = []
        for msg_id in ids:
            try:
                msg = user_api.messages().get(
                    userId="me", id=msg_id, format="full"
                ).execute()

                payload = msg.get("payload", {})
                headers: dict[str, str] = {
                    h["name"].lower(): h["value"]
                    for h in payload.get("headers", [])
                }
                date_raw = headers.get("date", "")

                messages.append(EmailMessage(
                    id=msg_id,
                    subject=headers.get("subject", "(no subject)"),
                    sender=headers.get("from", "unknown"),
                    body_text=_parse_body(payload)[:_BODY_TRUNCATE],
                    date=_parse_date(date_raw) if date_raw else datetime.now(timezone.utc),
                    labels=msg.get("labelIds", []),
                ))
            except googleapiclient.errors.HttpError as exc:
                logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
            except Exception as exc:
                # Deliberate best-effort: one bad message must not abort the run.
                logger.warning("gmail: skipping message %s — unexpected error: %s", msg_id, exc)

        logger.info("gmail: returned %d message(s)", len(messages))
        return messages
def _strip_html(raw: str) -> str:
    """Strip HTML tags, decode entities and collapse whitespace."""
    import html as _html

    no_tags = re.sub(r"<[^>]+>", " ", raw)
    decoded = _html.unescape(no_tags)
    return re.sub(r"\s+", " ", decoded).strip()


def _coerce_utc(dt: datetime) -> datetime:
    """Return *dt* as an aware UTC datetime; naive values are taken as UTC."""
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(timezone.utc)


def _odata_datetime(dt: datetime) -> str:
    """Format a datetime as an OData datetime literal (UTC, ISO 8601).

    Naive datetimes are interpreted as UTC (not as local time), matching how
    the filter builder treats naive ``date_range`` values.
    """
    return _coerce_utc(dt).strftime("%Y-%m-%dT%H:%M:%SZ")


def _odata_quote(value: str) -> str:
    """Return *value* as an OData string literal (single quotes doubled)."""
    return "'" + str(value).replace("'", "''") + "'"


def _build_email_filter(
    filter_config: dict[str, Any] | None,
    since: datetime | None,
) -> str:
    """Build an OData ``$filter`` expression for the ``/me/messages`` endpoint.

    Supported ``filter_config`` keys:
        senders (list[str]): Sender email addresses (OR-ed together).
        date_range (dict):   ``{from: "", to: ""}`` ISO-8601 strings.
        folders (list[str]): Folder display names — not filterable via
                             OData, so ignored here.

    The lower bound is the **later** of *since* and ``date_range.from``.
    Sender values are escaped before being embedded in the expression, so an
    address containing ``'`` cannot break or alter the filter.

    Returns:
        Clauses joined with ``and``; empty when no filter applies.
    """
    clauses: list[str] = []
    cfg = filter_config or {}

    # ``or`` also covers keys that are present but null in stored JSON.
    senders: list[str] = cfg.get("senders") or []
    if senders:
        sender_clauses = [
            f"from/emailAddress/address eq {_odata_quote(s)}" for s in senders
        ]
        clauses.append("(" + " or ".join(sender_clauses) + ")")

    date_range: dict = cfg.get("date_range") or {}
    from_str: str | None = date_range.get("from")
    to_str: str | None = date_range.get("to")

    # Normalise so an aware/naive mix cannot raise on comparison.
    effective_since: datetime | None = _coerce_utc(since) if since is not None else None
    if from_str:
        try:
            cfg_since = _coerce_utc(datetime.fromisoformat(from_str.replace("Z", "+00:00")))
        except (AttributeError, TypeError, ValueError):
            logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
        else:
            if effective_since is None or cfg_since > effective_since:
                effective_since = cfg_since

    if effective_since is not None:
        clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")

    if to_str:
        try:
            to_dt = _coerce_utc(datetime.fromisoformat(to_str.replace("Z", "+00:00")))
        except (AttributeError, TypeError, ValueError):
            logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
        else:
            clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")

    return " and ".join(clauses)
+ """ + + def __init__(self, credentials_info: dict[str, Any]) -> None: + self._credentials_info = credentials_info + self._access_token: str = credentials_info.get("access_token", "") + self._original_access_token: str = self._access_token + self._refresh_token: str | None = credentials_info.get("refresh_token") + + # ── Token management ─────────────────────────────────────────────────── + + def _auth_headers(self) -> dict[str, str]: + return {"Authorization": f"Bearer {self._access_token}"} + + async def _refresh_access_token(self) -> None: + """Use MSAL to exchange the refresh token for a fresh access token. + + Updates ``self._access_token`` and ``self._credentials_info`` in-place. + + Raises: + RuntimeError: MSAL reports an auth error. + """ + import msal + + app = msal.ConfidentialClientApplication( + client_id=settings.MS_CLIENT_ID, + client_credential=settings.MS_CLIENT_SECRET, + authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}", + ) + scopes: list[str] = self._credentials_info.get("scope", "").split() + if not scopes: + scopes = ["https://graph.microsoft.com/.default"] + + result = app.acquire_token_by_refresh_token( + self._refresh_token, + scopes=scopes, + ) + if "access_token" not in result: + error = result.get("error_description", result.get("error", "unknown")) + raise RuntimeError(f"MS Graph token refresh failed: {error}") + + self._access_token = result["access_token"] + # MSAL may issue a new refresh token. + if "refresh_token" in result: + self._refresh_token = result["refresh_token"] + self._credentials_info["refresh_token"] = result["refresh_token"] + self._credentials_info["access_token"] = self._access_token + + @property + def refreshed_credentials(self) -> dict[str, Any] | None: + """Return updated credential dict if the access token was refreshed. + + Returns ``None`` if no change was made. 
+ """ + if self._access_token != self._original_access_token: + return {**self._credentials_info, "access_token": self._access_token} + return None + + # ── HTTP helpers ─────────────────────────────────────────────────────── + + async def _get( + self, + client: httpx.AsyncClient, + url: str, + params: dict[str, Any] | None = None, + *, + retry_on_401: bool = True, + ) -> dict[str, Any]: + """GET *url* with auth; refresh token on 401 and retry once.""" + resp = await client.get(url, params=params, headers=self._auth_headers()) + if resp.status_code == 401 and retry_on_401 and self._refresh_token: + logger.debug("ms_graph: 401 on %s — refreshing token", url) + await self._refresh_access_token() + resp = await client.get(url, params=params, headers=self._auth_headers()) + if resp.status_code == 429: + raise RuntimeError("MS Graph rate limit hit (429). Try again later.") + resp.raise_for_status() + return resp.json() + + # ── Public API ───────────────────────────────────────────────────────── + + async def fetch_emails( + self, + filter_config: dict[str, Any] | None = None, + since: datetime | None = None, + ) -> list[EmailMessage]: + """Return up to ``_MAX_EMAILS`` Outlook messages matching *filter_config*. + + Parameters + ---------- + filter_config: + Optional dict with ``senders``, ``date_range``, ``folders`` keys. + since: + Hard lower-bound on email date (from last agent run). 
+ """ + odata_filter = _build_email_filter(filter_config, since) + params: dict[str, Any] = { + "$top": 50, + "$select": "id,subject,from,receivedDateTime,body,bodyPreview", + "$orderby": "receivedDateTime desc", + } + if odata_filter: + params["$filter"] = odata_filter + + emails: list[EmailMessage] = [] + url = f"{_GRAPH_BASE}/me/messages" + + async with httpx.AsyncClient(timeout=30.0) as client: + while url and len(emails) < _MAX_EMAILS: + data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None) + for item in data.get("value", []): + emails.append(self._parse_email(item)) + if len(emails) >= _MAX_EMAILS: + break + url = data.get("@odata.nextLink", "") + params = {} # nextLink already contains encoded params. + + logger.info("ms_graph: fetched %d Outlook email(s)", len(emails)) + return emails + + async def fetch_messages( + self, + filter_config: dict[str, Any] | None = None, + since: datetime | None = None, + ) -> list[ChatMessage]: + """Return up to ``_MAX_MESSAGES`` Teams messages matching *filter_config*. + + Fetches from ``/me/chats/getAllMessages`` (personal + group chats). + The ``filter_config.channels`` key is checked as a text-filter on + the channel name post-fetch (the API doesn't support channel OData + filter directly on ``getAllMessages``). + """ + cfg = filter_config or {} + channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])] + params: dict[str, Any] = {"$top": 50} + if since: + params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}" + + messages: list[ChatMessage] = [] + url = f"{_GRAPH_BASE}/me/chats/getAllMessages" + + async with httpx.AsyncClient(timeout=30.0) as client: + while url and len(messages) < _MAX_MESSAGES: + try: + data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None) + except httpx.HTTPStatusError as exc: + # getAllMessages requires specific licensing; degrade gracefully. 
+ if exc.response.status_code in (403, 404): + logger.warning( + "ms_graph: /me/chats/getAllMessages not available (%d) — " + "check Teams license or permissions", + exc.response.status_code, + ) + break + raise + + for item in data.get("value", []): + msg = self._parse_teams_message(item) + if channel_filter and msg.channel: + if not any(c in msg.channel.lower() for c in channel_filter): + continue + messages.append(msg) + if len(messages) >= _MAX_MESSAGES: + break + url = data.get("@odata.nextLink", "") + params = {} + + logger.info("ms_graph: fetched %d Teams message(s)", len(messages)) + return messages + + # ── Parsers ──────────────────────────────────────────────────────────── + + @staticmethod + def _parse_email(item: dict[str, Any]) -> EmailMessage: + subject: str = item.get("subject", "(no subject)") or "(no subject)" + sender_block = item.get("from", {}) or {} + sender_addr = ( + (sender_block.get("emailAddress") or {}).get("address", "unknown") + ) + date_str: str = item.get("receivedDateTime", "") + try: + date = datetime.fromisoformat(date_str.replace("Z", "+00:00")) + except Exception: + date = datetime.now(timezone.utc) + + body_block = item.get("body", {}) or {} + content_type: str = body_block.get("contentType", "text") + raw_body: str = body_block.get("content", "") + if content_type == "html": + body_text = _strip_html(raw_body) + else: + body_text = raw_body or item.get("bodyPreview", "") + body_text = body_text[:_BODY_TRUNCATE] + + return EmailMessage( + id=item.get("id", ""), + subject=subject, + sender=sender_addr, + body_text=body_text, + date=date, + ) + + @staticmethod + def _parse_teams_message(item: dict[str, Any]) -> ChatMessage: + msg_id: str = item.get("id", "") + sender_block = (item.get("from") or {}).get("user") or {} + sender: str = sender_block.get("displayName", "unknown") + channel: str | None = (item.get("channelIdentity") or {}).get("channelId") + + date_str: str = item.get("createdDateTime", "") + try: + date = 
datetime.fromisoformat(date_str.replace("Z", "+00:00")) + except Exception: + date = datetime.now(timezone.utc) + + body_block = item.get("body", {}) or {} + content_type: str = body_block.get("contentType", "text") + raw_content: str = body_block.get("content", "") + content = _strip_html(raw_content) if content_type == "html" else raw_content + content = content[:_BODY_TRUNCATE] + + return ChatMessage( + id=msg_id, + content=content, + sender=sender, + channel=channel, + date=date, + ) diff --git a/docker-compose.yml b/docker-compose.yml index 0d40152..07b33c6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,6 +8,9 @@ services: required: false environment: DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva + GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot + volumes: + - copilot_tokens:/root/.config/litellm/github_copilot depends_on: db: condition: service_healthy @@ -66,3 +69,4 @@ volumes: postgres_data: minio_data: qdrant_data: + copilot_tokens: diff --git a/requirements.txt b/requirements.txt index 0650450..7e2fbcd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,4 +25,10 @@ moto[s3]>=5.0.0 pinecone>=5.0.0 qdrant-client>=1.7.0 croniter>=3.0.0 +google-api-python-client>=2.130.0 +google-auth>=2.29.0 +google-auth-oauthlib>=1.2.0 +google-auth-httplib2>=0.2.0 +msal>=1.28.0 +cryptography>=42.0.0 ruff>=0.8.0 diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index 46b748d..d1d58d5 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -455,21 +455,232 @@ async def test_run_local_agent_llm_extraction_error(): @pytest.mark.asyncio -async def test_run_cloud_agent_stub_returns_error(): - """Cloud agent stub immediately marks run as error with informative message.""" +async def test_run_cloud_agent_device_offline(): + """Cloud agent aborts immediately when no device is connected.""" config = _make_cloud_config() run_log = _make_run_log(config.id, agent_type="cloud") + mgr = 
DeviceConnectionManager() # empty — no devices registered + + with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize: + await run_cloud_agent(_FREE_UID, config, run_log, mgr) + + mock_finalize.assert_called_once() + _, kwargs = mock_finalize.call_args + assert kwargs["status"] == "error" + assert any("device" in e.lower() or "connected" in e.lower() for e in kwargs["errors"]) + + +@pytest.mark.asyncio +async def test_run_cloud_agent_no_oauth_token(): + """Cloud agent errors when no OAuth token is stored.""" + config = _make_cloud_config() + config.oauth_token_encrypted = None + run_log = _make_run_log(config.id, agent_type="cloud") mgr = _make_manager() with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize: await run_cloud_agent(_FREE_UID, config, run_log, mgr) - mock_finalize.assert_called_once() - _args, kwargs = mock_finalize.call_args + _, kwargs = mock_finalize.call_args assert kwargs["status"] == "error" - assert len(kwargs["errors"]) == 1 - assert "gmail" in kwargs["errors"][0].lower() - assert "3.6" in kwargs["errors"][0] + assert any("oauth" in e.lower() or "token" in e.lower() for e in kwargs["errors"]) + + +@pytest.mark.asyncio +async def test_run_cloud_agent_token_decrypt_failure(): + """Cloud agent errors gracefully when the stored token cannot be decrypted.""" + config = _make_cloud_config() + config.oauth_token_encrypted = "this-is-not-valid-fernet-ciphertext" + run_log = _make_run_log(config.id, agent_type="cloud") + mgr = _make_manager() + + from cryptography.fernet import Fernet as _Fernet + valid_key = _Fernet.generate_key().decode() + + with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \ + patch("app.integrations.settings") as mock_settings: + mock_settings.OAUTH_ENCRYPTION_KEY = valid_key + await run_cloud_agent(_FREE_UID, config, run_log, mgr) + + _, kwargs = mock_finalize.call_args + assert kwargs["status"] == "error" + assert 
any("decrypt" in e.lower() for e in kwargs["errors"]) + + +@pytest.mark.asyncio +async def test_run_cloud_agent_happy_path_gmail(): + """Cloud agent happy path: Gmail fetch → LLM extraction → inserts → success.""" + from app.integrations import EmailMessage, encrypt_token + from cryptography.fernet import Fernet as _Fernet + + fernet_key = _Fernet.generate_key().decode() + credentials = { + "token": "access_abc", + "refresh_token": "refresh_xyz", + "token_uri": "https://oauth2.googleapis.com/token", + "client_id": "cid", + "client_secret": "csec", + } + + config = _make_cloud_config() + config.provider = "gmail" + config.prompt_template = "Extract tasks from this email." + config.data_types = ["tasks"] + + with patch("app.integrations.settings") as ms: + ms.OAUTH_ENCRYPTION_KEY = fernet_key + config.oauth_token_encrypted = encrypt_token(credentials) + + run_log = _make_run_log(config.id, agent_type="cloud") + mgr = _make_manager() + + sample_email = EmailMessage( + id="msg001", + subject="Action required", + sender="boss@company.com", + body_text="Please fix the bug by Friday.", + date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc), + ) + + extracted_items = [{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}] + + with patch("app.integrations.settings") as mock_int_settings, \ + patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \ + patch("app.core.agent_runner._extract_items_from_content", new_callable=AsyncMock, return_value=extracted_items) as mock_extract, \ + patch("app.core.agent_runner._send_insert_to_client", new_callable=AsyncMock, return_value={"ok": True}) as mock_insert, \ + patch("app.core.agent_runner.async_session"): + mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key + + mock_gmail = AsyncMock() + mock_gmail.fetch_messages = AsyncMock(return_value=[sample_email]) + mock_gmail.refreshed_credentials = None + + with patch("app.integrations.decrypt_token", return_value=credentials), \ + 
patch("app.integrations.get_provider", return_value=mock_gmail): + await run_cloud_agent(_FREE_UID, config, run_log, mgr) + + mock_extract.assert_called_once() + mock_insert.assert_called_once() + _, kwargs = mock_finalize.call_args + assert kwargs["status"] == "success" + assert kwargs["items_processed"] == 1 + assert kwargs["items_created"] == 1 + assert kwargs["config_type"] == "cloud" + + +@pytest.mark.asyncio +async def test_run_cloud_agent_provider_fetch_error(): + """Cloud agent records error status when provider fetch raises RuntimeError.""" + credentials = {"token": "abc"} + config = _make_cloud_config() + config.oauth_token_encrypted = "some_encrypted_value" # non-empty so decrypt step is reached + config.prompt_template = "Extract tasks." + config.data_types = ["tasks"] + run_log = _make_run_log(config.id, agent_type="cloud") + mgr = _make_manager() + + mock_provider = AsyncMock() + mock_provider.fetch_messages = AsyncMock(side_effect=RuntimeError("API quota exceeded")) + mock_provider.refreshed_credentials = None + + with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \ + patch("app.integrations.decrypt_token", return_value=credentials), \ + patch("app.integrations.get_provider", return_value=mock_provider), \ + patch("app.core.agent_runner.async_session"): + await run_cloud_agent(_FREE_UID, config, run_log, mgr) + + _, kwargs = mock_finalize.call_args + assert kwargs["status"] == "error" + assert any("quota" in e.lower() or "fetch" in e.lower() for e in kwargs["errors"]) + + +@pytest.mark.asyncio +async def test_run_cloud_agent_refreshed_token_persisted(): + """When the provider refreshes its token, the new ciphertext is written to DB.""" + from app.integrations import EmailMessage, encrypt_token + from cryptography.fernet import Fernet as _Fernet + + fernet_key = _Fernet.generate_key().decode() + credentials = {"token": "old_token", "refresh_token": "rt_old"} + fresh_credentials = {"token": "new_token", 
"refresh_token": "rt_new"} + + config = _make_cloud_config() + config.prompt_template = "Extract tasks." + config.data_types = ["tasks"] + + with patch("app.integrations.settings") as ms: + ms.OAUTH_ENCRYPTION_KEY = fernet_key + config.oauth_token_encrypted = encrypt_token(credentials) + + run_log = _make_run_log(config.id, agent_type="cloud") + mgr = _make_manager() + + mock_provider = AsyncMock() + mock_provider.fetch_messages = AsyncMock(return_value=[]) + mock_provider.refreshed_credentials = fresh_credentials # token was refreshed + + # Track DB writes via mock async_session. + mock_cfg_row = MagicMock() + mock_cfg_row.oauth_token_encrypted = None + + mock_db = AsyncMock() + mock_db.__aenter__ = AsyncMock(return_value=mock_db) + mock_db.__aexit__ = AsyncMock(return_value=False) + mock_db.scalar_one_or_none = AsyncMock(return_value=mock_cfg_row) + cfg_result = MagicMock() + cfg_result.scalar_one_or_none.return_value = mock_cfg_row + mock_db.execute = AsyncMock(return_value=cfg_result) + mock_db.commit = AsyncMock() + + with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock), \ + patch("app.integrations.decrypt_token", return_value=credentials), \ + patch("app.integrations.get_provider", return_value=mock_provider), \ + patch("app.integrations.encrypt_token", return_value="new_encrypted") as mock_encrypt, \ + patch("app.core.agent_runner.async_session", return_value=mock_db), \ + patch("app.integrations.settings") as mock_int_settings: + mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key + await run_cloud_agent(_FREE_UID, config, run_log, mgr) + + # The new encrypted token should have been written to the config row. 
+ mock_encrypt.assert_called_once_with(fresh_credentials) + assert mock_cfg_row.oauth_token_encrypted == "new_encrypted" + + +@pytest.mark.asyncio +async def test_finalize_run_updates_cloud_config_last_run_at(): + """_finalize_run with config_type='cloud' updates CloudAgentConfig.last_run_at.""" + from app.core.agent_runner import _finalize_run + + run_log = _make_run_log(str(uuid.uuid4()), agent_type="cloud") + run_log.id = str(uuid.uuid4()) + + mock_cfg = MagicMock() + mock_cfg.last_run_at = None + + cfg_result = MagicMock() + cfg_result.scalar_one_or_none.return_value = mock_cfg + + mock_db = AsyncMock() + mock_db.__aenter__ = AsyncMock(return_value=mock_db) + mock_db.__aexit__ = AsyncMock(return_value=False) + mock_db.merge = AsyncMock(return_value=run_log) + mock_db.execute = AsyncMock(return_value=cfg_result) + mock_db.commit = AsyncMock() + + config_id = str(uuid.uuid4()) + + with patch("app.core.agent_runner.async_session", return_value=mock_db): + await _finalize_run( + run_log, + status="success", + update_config_last_run=True, + config_id=config_id, + config_type="cloud", + ) + + # CloudAgentConfig.last_run_at should have been set. + assert mock_cfg.last_run_at is not None + mock_db.commit.assert_called() # --------------------------------------------------------------------------- diff --git a/tests/test_integrations.py b/tests/test_integrations.py new file mode 100644 index 0000000..79abccd --- /dev/null +++ b/tests/test_integrations.py @@ -0,0 +1,729 @@ +"""Tests for Step 3.6: cloud provider integration clients. 
+ +Coverage: + Unit \u2014 app/integrations/__init__.py: + - encrypt_token / decrypt_token round-trip + - decrypt_token raises ValueError on invalid ciphertext + - encrypt_token raises ValueError on empty/non-dict input + - _get_fernet raises RuntimeError when OAUTH_ENCRYPTION_KEY not set + - get_provider returns GmailClient for 'gmail' + - get_provider returns MSGraphClient for 'outlook' and 'teams' + - get_provider raises ValueError for unknown provider + + Unit \u2014 app/integrations/gmail.py: + - _build_gmail_query with no filter returns empty string + - _build_gmail_query with labels builds label: expr + - _build_gmail_query with senders builds from: expr + - _build_gmail_query with date_range builds after:/before: exprs + - _build_gmail_query since overrides date_range.from when more recent + - _build_gmail_query date_range.from overrides since when more recent + - _parse_body extracts text/plain part + - _parse_body extracts text/html part (stripped) + - _parse_body recurses into multipart, prefers text/plain + - GmailClient.fetch_messages: happy path with mocked service + - GmailClient.fetch_messages: no messages returns empty list + - GmailClient.fetch_messages: HTTP error on messages.list raises RuntimeError + - GmailClient.refreshed_credentials: None when token unchanged + - GmailClient.refreshed_credentials: returns dict when token changes + + Unit \u2014 app/integrations/ms_graph.py: + - _build_email_filter with no filter returns empty string + - _build_email_filter with senders builds OData from clause + - _build_email_filter with since builds receivedDateTime ge clause + - MSGraphClient.fetch_emails: happy path with mocked httpx + - MSGraphClient.fetch_emails: 401 triggers token refresh and retries + - MSGraphClient.fetch_messages: happy path with mocked httpx + - MSGraphClient.fetch_messages: 403 from getAllMessages degrades gracefully + - MSGraphClient.refreshed_credentials: None when token unchanged + - MSGraphClient._refresh_access_token: MSAL 
error raises RuntimeError +""" + +from __future__ import annotations + +import asyncio +import json +import uuid +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch + +import pytest + +from app.integrations import ( + ChatMessage, + EmailMessage, + decrypt_token, + encrypt_token, + get_provider, +) + +# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 +# Helpers +# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 + +_FERNET_KEY = "eW91LXNob3VsZC1ub3QtdXNlLXRoaXMta2V5LWluLXByb2Q=" +# ^ 32-char URL-safe base64 (generated for tests only; not a real Fernet key length, +# so we generate a proper one below) + +from cryptography.fernet import Fernet as _Fernet # noqa: E402 + +_VALID_KEY = _Fernet.generate_key().decode("utf-8") + +_TOKEN_DICT = { + "token": "access_abc", + "refresh_token": "refresh_xyz", + "token_uri": "https://oauth2.googleapis.com/token", + "client_id": "client_id_123", + "client_secret": "client_secret_456", + "scopes": ["https://www.googleapis.com/auth/gmail.readonly"], +} + +_MS_TOKEN_DICT = { + "access_token": "ms_access_abc", + "refresh_token": "ms_refresh_xyz", + "token_type": "Bearer", + "scope": 
"Mail.Read offline_access", +} + + +# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 +# encrypt_token / decrypt_token +# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 + + +class TestTokenEncryption: + """encrypt_token / decrypt_token round-trip tests.""" + + def test_round_trip(self): + with patch("app.integrations.settings") as mock_settings: + mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY + encrypted = encrypt_token(_TOKEN_DICT) + assert isinstance(encrypted, str) + assert encrypted != json.dumps(_TOKEN_DICT) # must be ciphertext, not plaintext + recovered = decrypt_token(encrypted) + assert recovered == _TOKEN_DICT + + def test_decrypt_invalid_ciphertext_raises_value_error(self): + with patch("app.integrations.settings") as mock_settings: + mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY + with pytest.raises(ValueError, match="Failed to decrypt"): + decrypt_token("this-is-not-valid-fernet-ciphertext") + + def test_decrypt_wrong_key_raises_value_error(self): + """Decrypting with a different key must fail with ValueError.""" + other_key = _Fernet.generate_key().decode("utf-8") + with patch("app.integrations.settings") as mock_settings: + 
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY + encrypted = encrypt_token(_TOKEN_DICT) + with patch("app.integrations.settings") as mock_settings2: + mock_settings2.OAUTH_ENCRYPTION_KEY = other_key + with pytest.raises(ValueError, match="Failed to decrypt"): + decrypt_token(encrypted) + + def test_encrypt_empty_dict_raises_value_error(self): + with patch("app.integrations.settings") as mock_settings: + mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY + with pytest.raises(ValueError, match="non-empty dict"): + encrypt_token({}) + + def test_encrypt_non_dict_raises_value_error(self): + with patch("app.integrations.settings") as mock_settings: + mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY + with pytest.raises(ValueError, match="non-empty dict"): + encrypt_token("not-a-dict") # type: ignore[arg-type] + + def test_missing_key_raises_runtime_error(self): + with patch("app.integrations.settings") as mock_settings: + mock_settings.OAUTH_ENCRYPTION_KEY = "" + with pytest.raises(RuntimeError, match="OAUTH_ENCRYPTION_KEY"): + encrypt_token(_TOKEN_DICT) + + def test_email_message_as_text(self): + msg = EmailMessage( + id="m1", + subject="Hello", + sender="alice@example.com", + body_text="Test body", + date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc), + ) + text = msg.as_text + assert "From: alice@example.com" in text + assert "Subject: Hello" in text + assert "Test body" in text + + def test_chat_message_as_text(self): + msg = ChatMessage( + id="c1", + content="Buy milk", + sender="bob", + channel="general", + date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc), + ) + text = msg.as_text + assert "From: bob" in text + assert "channel: general" in text + assert "Buy milk" in text + + +# 
\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 +# get_provider factory +# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 + + +class TestGetProvider: + def test_gmail_returns_gmail_client(self): + from app.integrations.gmail import GmailClient + + client = get_provider("gmail", _TOKEN_DICT) + assert isinstance(client, GmailClient) + + def test_outlook_returns_ms_graph_client(self): + from app.integrations.ms_graph import MSGraphClient + + client = get_provider("outlook", _MS_TOKEN_DICT) + assert isinstance(client, MSGraphClient) + + def test_teams_returns_ms_graph_client(self): + from app.integrations.ms_graph import MSGraphClient + + client = get_provider("teams", _MS_TOKEN_DICT) + assert isinstance(client, MSGraphClient) + + def test_unknown_provider_raises_value_error(self): + with pytest.raises(ValueError, match="Unknown cloud provider"): + get_provider("slack", {}) + + +# 
\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 +# Gmail client \u2014 query builder +# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 + + +class TestBuildGmailQuery: + """Unit tests for gmail._build_gmail_query.""" + + def setup_method(self): + from app.integrations.gmail import _build_gmail_query + self._fn = _build_gmail_query + + def test_empty_returns_empty_string(self): + assert self._fn(None, None) == "" + + def test_single_label(self): + q = self._fn({"labels": ["INBOX"]}, None) + assert "label:INBOX" in q + + def test_multiple_labels_joined_with_or(self): + q = self._fn({"labels": ["INBOX", "work"]}, None) + assert "label:INBOX OR label:work" in q + + def test_senders(self): + q = self._fn({"senders": ["alice@example.com"]}, None) + assert "from:alice@example.com" in q + + def test_date_range_from(self): + q = self._fn({"date_range": {"from": "2025-01-15"}}, None) + assert "after:2025/01/15" in q + + def test_date_range_to(self): + q = self._fn({"date_range": {"to": "2025-03-01"}}, None) + assert "before:2025/03/01" in q + + def test_since_overrides_earlier_date_range_from(self): + """since=Feb is more recent than date_range.from=Jan, so after: should be Feb.""" + 
since = datetime(2025, 2, 1, tzinfo=timezone.utc) + q = self._fn({"date_range": {"from": "2025-01-01"}}, since) + assert "after:2025/02/01" in q + assert "after:2025/01/01" not in q + + def test_date_range_from_overrides_earlier_since(self): + """date_range.from=Feb is more recent than since=Jan, so after: should be Feb.""" + since = datetime(2025, 1, 1, tzinfo=timezone.utc) + q = self._fn({"date_range": {"from": "2025-02-01"}}, since) + assert "after:2025/02/01" in q + + def test_invalid_date_ignored(self): + """An invalid date string in filter_config must not raise, just be skipped.""" + q = self._fn({"date_range": {"from": "not-a-date"}}, None) + assert "after:" not in q + + +# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 +# Gmail client \u2014 body parsing +# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 + + +class TestParseBody: + """Unit tests for gmail._parse_body.""" + + def setup_method(self): + from app.integrations.gmail import _parse_body + self._fn = _parse_body + + def _encode(self, text: str) -> str: + import base64 + return base64.urlsafe_b64encode(text.encode()).decode() + + def test_text_plain_extracted(self): + payload = { + "mimeType": "text/plain", 
+ "body": {"data": self._encode("Hello world")}, + } + assert self._fn(payload) == "Hello world" + + def test_text_html_stripped(self): + payload = { + "mimeType": "text/html", + "body": {"data": self._encode("

Hello world

")}, + } + result = self._fn(payload) + assert "Hello" in result + assert "

" not in result + + def test_multipart_prefers_plain_over_html(self): + plain_data = self._encode("Plain text") + html_data = self._encode("

HTML text

") + payload = { + "mimeType": "multipart/alternative", + "body": {}, + "parts": [ + {"mimeType": "text/html", "body": {"data": html_data}}, + {"mimeType": "text/plain", "body": {"data": plain_data}}, + ], + } + result = self._fn(payload) + assert result == "Plain text" + + def test_empty_payload_returns_empty_string(self): + assert self._fn({}) == "" + + +# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 +# GmailClient.fetch_messages +# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 + + +def _make_gmail_message( + msg_id: str = "msg001", + subject: str = "Test email", + sender: str = "alice@example.com", + body_text: str = "Hello world", + date: str = "Mon, 01 Jan 2025 10:00:00 +0000", +) -> dict: + """Build a minimal Gmail API message response dict.""" + import base64 + body_data = base64.urlsafe_b64encode(body_text.encode()).decode() + return { + "id": msg_id, + "labelIds": ["INBOX"], + "payload": { + "mimeType": "text/plain", + "headers": [ + {"name": "Subject", "value": subject}, + {"name": "From", "value": sender}, + {"name": "Date", "value": date}, + ], + "body": {"data": body_data}, + }, + } + + +class TestGmailClientFetchMessages: + """GmailClient.fetch_messages 
tests with mocked Google API.""" + + def _make_client(self) -> "GmailClient": + from app.integrations.gmail import GmailClient + return GmailClient(_TOKEN_DICT) + + @pytest.mark.asyncio + async def test_happy_path_returns_email_messages(self): + client = self._make_client() + msg = _make_gmail_message() + + mock_service = MagicMock() + mock_users = mock_service.users.return_value + mock_messages = mock_users.messages.return_value + mock_messages.list.return_value.execute.return_value = { + "messages": [{"id": "msg001"}] + } + mock_messages.get.return_value.execute.return_value = msg + + with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread: + # Simulate to_thread running the sync function and returning results. + async def fake_to_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + mock_thread.side_effect = fake_to_thread + + with patch("googleapiclient.discovery.build", return_value=mock_service), \ + patch("google.auth.transport.requests.Request"), \ + patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False): + results = await client.fetch_messages() + + assert len(results) == 1 + assert results[0].subject == "Test email" + assert results[0].sender == "alice@example.com" + assert results[0].body_text == "Hello world" + + @pytest.mark.asyncio + async def test_no_messages_returns_empty_list(self): + client = self._make_client() + + mock_service = MagicMock() + mock_users = mock_service.users.return_value + mock_messages = mock_users.messages.return_value + mock_messages.list.return_value.execute.return_value = {"messages": []} + + with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread: + async def fake_to_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + mock_thread.side_effect = fake_to_thread + + with patch("googleapiclient.discovery.build", return_value=mock_service), \ + patch("google.auth.transport.requests.Request"), \ + patch.object(type(client._credentials), "expired", 
new_callable=PropertyMock, return_value=False): + results = await client.fetch_messages() + + assert results == [] + + @pytest.mark.asyncio + async def test_list_http_error_raises_runtime_error(self): + import googleapiclient.errors + client = self._make_client() + + mock_service = MagicMock() + mock_users = mock_service.users.return_value + mock_messages = mock_users.messages.return_value + mock_resp = MagicMock() + mock_resp.status = 403 + mock_resp.reason = "Forbidden" + mock_messages.list.return_value.execute.side_effect = ( + googleapiclient.errors.HttpError(mock_resp, b"Forbidden") + ) + + with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread: + async def fake_to_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + mock_thread.side_effect = fake_to_thread + + with patch("googleapiclient.discovery.build", return_value=mock_service), \ + patch("google.auth.transport.requests.Request"), \ + patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False): + with pytest.raises(RuntimeError, match="Gmail messages.list failed"): + await client.fetch_messages() + + def test_refreshed_credentials_none_when_unchanged(self): + client = self._make_client() + # Token unchanged — should return None. + assert client.refreshed_credentials is None + + def test_refreshed_credentials_returns_dict_when_token_changes(self): + client = self._make_client() + # Simulate a token refresh by changing the access token on the credentials object. 
+ client._credentials.token = "new_access_token_xyz" + refreshed = client.refreshed_credentials + assert refreshed is not None + assert refreshed["token"] == "new_access_token_xyz" + + +# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 +# MS Graph client \u2014 email filter builder +# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 + + +class TestBuildEmailFilter: + """Unit tests for ms_graph._build_email_filter.""" + + def setup_method(self): + from app.integrations.ms_graph import _build_email_filter + self._fn = _build_email_filter + + def test_empty_returns_empty_string(self): + assert self._fn(None, None) == "" + + def test_single_sender(self): + result = self._fn({"senders": ["alice@example.com"]}, None) + assert "from/emailAddress/address eq 'alice@example.com'" in result + + def test_multiple_senders_joined_with_or(self): + result = self._fn({"senders": ["a@x.com", "b@x.com"]}, None) + assert " or " in result + assert "a@x.com" in result + assert "b@x.com" in result + + def test_since_adds_received_date_ge_clause(self): + since = datetime(2025, 3, 1, tzinfo=timezone.utc) + result = self._fn(None, since) + assert "receivedDateTime ge 2025-03-01T00:00:00Z" in result + 
+ def test_date_range_to_adds_received_date_le_clause(self): + result = self._fn({"date_range": {"to": "2025-06-30"}}, None) + assert "receivedDateTime le" in result + + def test_since_overrides_earlier_date_range_from(self): + since = datetime(2025, 2, 1, tzinfo=timezone.utc) + result = self._fn({"date_range": {"from": "2025-01-01"}}, since) + assert "2025-02-01T00:00:00Z" in result + assert "2025-01-01" not in result + + def test_invalid_date_ignored(self): + result = self._fn({"date_range": {"from": "bad-date"}}, None) + assert "receivedDateTime" not in result + + +# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 +# MSGraphClient.fetch_emails +# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 + + +def _make_graph_email( + msg_id: str = "email001", + subject: str = "Meeting tomorrow", + sender_address: str = "boss@company.com", + body_content: str = "Please prepare the report.", + received: str = "2025-06-01T10:00:00Z", +) -> dict: + """Build a minimal MS Graph message item dict.""" + return { + "id": msg_id, + "subject": subject, + "from": {"emailAddress": {"address": sender_address}}, + "receivedDateTime": received, + "body": {"contentType": "text", "content": 
body_content}, + "bodyPreview": body_content[:100], + } + + +def _make_graph_teams_message( + msg_id: str = "teams001", + content: str = "Stand-up at 9am", + sender: str = "alice", + channel_id: str = "chan001", + created: str = "2025-06-01T08:00:00Z", +) -> dict: + return { + "id": msg_id, + "body": {"contentType": "text", "content": content}, + "from": {"user": {"displayName": sender}}, + "channelIdentity": {"channelId": channel_id}, + "createdDateTime": created, + } + + +class TestMSGraphClientFetchEmails: + """MSGraphClient.fetch_emails tests with mocked httpx.""" + + def _make_client(self) -> "MSGraphClient": + from app.integrations.ms_graph import MSGraphClient + return MSGraphClient(_MS_TOKEN_DICT) + + @pytest.mark.asyncio + async def test_happy_path_returns_email_messages(self): + client = self._make_client() + graph_email = _make_graph_email() + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"value": [graph_email]} + mock_response.raise_for_status = MagicMock() + + with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_response) + mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http) + mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) + + results = await client.fetch_emails() + + assert len(results) == 1 + assert results[0].subject == "Meeting tomorrow" + assert results[0].sender == "boss@company.com" + assert results[0].body_text == "Please prepare the report." 
+ + @pytest.mark.asyncio + async def test_pagination_stops_at_max_emails(self): + """No nextLink in first page \u2014 only one batch returned.""" + client = self._make_client() + emails_batch = [_make_graph_email(msg_id=str(i)) for i in range(3)] + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"value": emails_batch} # no @odata.nextLink + mock_response.raise_for_status = MagicMock() + + with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_response) + mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http) + mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) + + results = await client.fetch_emails() + + assert len(results) == 3 + + @pytest.mark.asyncio + async def test_401_triggers_token_refresh_and_retries(self): + """On first 401, token refresh is attempted and the request retried.""" + from app.integrations.ms_graph import MSGraphClient + client = MSGraphClient(_MS_TOKEN_DICT) + + graph_email = _make_graph_email() + + response_401 = MagicMock() + response_401.status_code = 401 + + response_200 = MagicMock() + response_200.status_code = 200 + response_200.json.return_value = {"value": [graph_email]} + response_200.raise_for_status = MagicMock() + + call_count = 0 + + async def fake_get(url, params=None, headers=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return response_401 + return response_200 + + with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls, \ + patch.object(client, "_refresh_access_token", new_callable=AsyncMock) as mock_refresh: + mock_http = AsyncMock() + mock_http.get = fake_get + mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http) + mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) + + results = await client.fetch_emails() + + mock_refresh.assert_called_once() + 
assert len(results) == 1 + + def test_refreshed_credentials_none_when_token_unchanged(self): + client = self._make_client() + assert client.refreshed_credentials is None + + def test_refreshed_credentials_returns_dict_when_token_changes(self): + client = self._make_client() + client._access_token = "new_token_abc" + assert client.refreshed_credentials is not None + assert client.refreshed_credentials["access_token"] == "new_token_abc" + + +class TestMSGraphClientFetchMessages: + """MSGraphClient.fetch_messages (Teams) tests.""" + + def _make_client(self) -> "MSGraphClient": + from app.integrations.ms_graph import MSGraphClient + return MSGraphClient(_MS_TOKEN_DICT) + + @pytest.mark.asyncio + async def test_happy_path_returns_chat_messages(self): + client = self._make_client() + teams_msg = _make_graph_teams_message() + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"value": [teams_msg]} + mock_response.raise_for_status = MagicMock() + + with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_response) + mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http) + mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) + + results = await client.fetch_messages() + + assert len(results) == 1 + assert results[0].content == "Stand-up at 9am" + assert results[0].sender == "alice" + + @pytest.mark.asyncio + async def test_403_degrades_gracefully(self): + """getAllMessages returning 403 (license issue) returns empty list, no exception.""" + import httpx as _httpx + + client = self._make_client() + + error_response = MagicMock() + error_response.status_code = 403 + http_error = _httpx.HTTPStatusError( + "Forbidden", request=MagicMock(), response=error_response + ) + + with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls: + mock_http = AsyncMock() + mock_http.get = 
AsyncMock(side_effect=http_error) + mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http) + mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) + + results = await client.fetch_messages() + + assert results == [] + + @pytest.mark.asyncio + async def test_channel_filter_applied(self): + """Messages from non-matching channels are filtered out.""" + client = self._make_client() + matching = _make_graph_teams_message(channel_id="dev-channel", content="Deploy today") + non_matching = _make_graph_teams_message(msg_id="t2", channel_id="random", content="Lunch?") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"value": [matching, non_matching]} + mock_response.raise_for_status = MagicMock() + + with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_response) + mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http) + mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) + + results = await client.fetch_messages( + filter_config={"channels": ["dev-channel"]} + ) + + assert len(results) == 1 + assert results[0].content == "Deploy today" + + +class TestMSGraphClientRefreshToken: + """MSGraphClient._refresh_access_token with mocked MSAL.""" + + @pytest.mark.asyncio + async def test_msal_error_raises_runtime_error(self): + from app.integrations.ms_graph import MSGraphClient + client = MSGraphClient({**_MS_TOKEN_DICT, "refresh_token": "rt_test"}) + + mock_app = MagicMock() + mock_app.acquire_token_by_refresh_token.return_value = { + "error": "invalid_grant", + "error_description": "Refresh token expired", + } + + with patch("msal.ConfidentialClientApplication", return_value=mock_app), \ + patch("app.integrations.ms_graph.settings") as mock_settings: + mock_settings.MS_CLIENT_ID = "client_id" + mock_settings.MS_CLIENT_SECRET = "secret" + 
mock_settings.MS_TENANT_ID = "common" + with pytest.raises(RuntimeError, match="MS Graph token refresh failed"): + await client._refresh_access_token() + + @pytest.mark.asyncio + async def test_successful_refresh_updates_access_token(self): + from app.integrations.ms_graph import MSGraphClient + client = MSGraphClient({**_MS_TOKEN_DICT, "refresh_token": "rt_old"}) + + mock_app = MagicMock() + mock_app.acquire_token_by_refresh_token.return_value = { + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + } + + with patch("msal.ConfidentialClientApplication", return_value=mock_app), \ + patch("app.integrations.ms_graph.settings") as mock_settings: + mock_settings.MS_CLIENT_ID = "client_id" + mock_settings.MS_CLIENT_SECRET = "secret" + mock_settings.MS_TENANT_ID = "common" + await client._refresh_access_token() + + assert client._access_token == "new_access_token" + assert client._refresh_token == "new_refresh_token"