class TestGraphCredentials:
    """Exercise ``GraphCredentials.from_env`` against explicit environment mappings."""

    def test_from_env_raises_for_missing_required_values(self):
        # An empty mapping lacks all three required settings, and the error
        # message is expected to name each of them.
        with pytest.raises(MicrosoftGraphConfigError) as raised:
            GraphCredentials.from_env({})

        message = str(raised.value)
        for variable in (
            "MSGRAPH_TENANT_ID",
            "MSGRAPH_CLIENT_ID",
            "MSGRAPH_CLIENT_SECRET",
        ):
            assert variable in message

    def test_from_env_optional_returns_none_when_not_configured(self):
        # With required=False a missing configuration is reported as None.
        outcome = GraphCredentials.from_env({}, required=False)
        assert outcome is None

    def test_from_env_builds_normalized_credentials(self):
        environment = {
            "MSGRAPH_TENANT_ID": "tenant-123",
            "MSGRAPH_CLIENT_ID": "client-456",
            "MSGRAPH_CLIENT_SECRET": "secret-789",
        }

        credentials = GraphCredentials.from_env(environment)

        assert credentials is not None
        # No scope was supplied, so the module default must be used.
        assert credentials.scope == DEFAULT_GRAPH_SCOPE
        assert credentials.token_url.endswith("/tenant-123/oauth2/v2.0/token")
async def test_force_refresh_bypasses_cache(self):
    """``force_refresh=True`` must hit the token endpoint despite a live cached token."""
    issued = 0

    def handler(request: httpx.Request) -> httpx.Response:
        # Number every token so the test can tell which request produced it.
        nonlocal issued
        issued += 1
        body = {
            "access_token": f"token-{issued}",
            "expires_in": 3600,
        }
        return httpx.Response(200, json=body)

    credentials = GraphCredentials("tenant", "client", "secret")
    provider = MicrosoftGraphTokenProvider(
        credentials,
        transport=httpx.MockTransport(handler),
    )

    tokens = [
        await provider.get_access_token(),
        await provider.get_access_token(force_refresh=True),
    ]

    assert tokens == ["token-1", "token-2"]
    assert issued == 2
def _make_provider() -> MicrosoftGraphTokenProvider:
    """Build a token provider pre-seeded with a fake token that never expires.

    The fake stands in for the provider's cached token so client tests never
    reach the token endpoint; it exposes only ``access_token``,
    ``is_expired`` and ``expires_in_seconds``.
    """

    def never_expired(self, skew_seconds=0):
        return False

    fake_token_type = type(
        "Token",
        (),
        {
            "access_token": "cached-token",
            "is_expired": never_expired,
            "expires_in_seconds": 3600,
        },
    )

    seeded = MicrosoftGraphTokenProvider(
        GraphCredentials("tenant", "client", "secret")
    )
    seeded._cached_token = fake_token_type()  # type: ignore[attr-defined]
    return seeded
async def test_collect_paginated_flattens_value_arrays(self):
    """Two pages linked by ``@odata.nextLink`` come back as one flat list."""
    first_page = {
        "value": [{"id": "1"}],
        "@odata.nextLink": "https://graph.microsoft.com/v1.0/items?page=2",
    }
    last_page = {"value": [{"id": "2"}]}

    def handler(request: httpx.Request) -> httpx.Response:
        # Only the initial request URL ends in "/items"; the follow-up
        # carries the "?page=2" query string from the next link.
        is_initial_request = str(request.url).endswith("/items")
        page = first_page if is_initial_request else last_page
        return httpx.Response(200, json=page)

    graph = MicrosoftGraphClient(
        _make_provider(),
        transport=httpx.MockTransport(handler),
    )

    collected = await graph.collect_paginated("/items")

    assert collected == [{"id": "1"}, {"id": "2"}]
@dataclass(frozen=True)
class GraphCredentials:
    """Normalized Microsoft Graph app-only credentials.

    Holds the tenant, client id and client secret sent with the token
    request, plus the requested scope and the authority base URL from which
    ``token_url`` is derived.
    """

    tenant_id: str
    client_id: str
    client_secret: str
    scope: str = DEFAULT_GRAPH_SCOPE
    authority_url: str = DEFAULT_GRAPH_AUTHORITY_URL

    @property
    def token_url(self) -> str:
        """Return ``<authority>/<tenant>/oauth2/v2.0/token`` without doubled slashes."""
        base = self.authority_url.rstrip("/")
        tenant = self.tenant_id.strip().strip("/")
        return f"{base}/{tenant}/oauth2/v2.0/token"

    @classmethod
    def from_env(
        cls,
        environ: dict[str, str] | None = None,
        *,
        required: bool = True,
    ) -> "GraphCredentials | None":
        """Build credentials from ``MSGRAPH_*`` environment variables.

        Args:
            environ: Mapping to read from; defaults to ``os.environ``.
            required: When false, a missing required variable yields ``None``
                instead of an exception.

        Returns:
            The credentials, or ``None`` when ``required`` is false and any of
            ``MSGRAPH_TENANT_ID``, ``MSGRAPH_CLIENT_ID`` or
            ``MSGRAPH_CLIENT_SECRET`` is missing or blank.

        Raises:
            MicrosoftGraphConfigError: If ``required`` is true and a required
                variable is missing or blank.
        """
        env = environ if environ is not None else os.environ

        def read(name: str) -> str:
            # Treat unset, empty and whitespace-only values alike.
            return (env.get(name) or "").strip()

        tenant_id = read("MSGRAPH_TENANT_ID")
        client_id = read("MSGRAPH_CLIENT_ID")
        client_secret = read("MSGRAPH_CLIENT_SECRET")
        # Strip before falling back: a whitespace-only value is truthy, so
        # applying the default first would leave an empty scope/authority.
        scope = read("MSGRAPH_SCOPE") or DEFAULT_GRAPH_SCOPE
        authority_url = read("MSGRAPH_AUTHORITY_URL") or DEFAULT_GRAPH_AUTHORITY_URL

        missing = [
            name
            for name, value in (
                ("MSGRAPH_TENANT_ID", tenant_id),
                ("MSGRAPH_CLIENT_ID", client_id),
                ("MSGRAPH_CLIENT_SECRET", client_secret),
            )
            if not value
        ]
        if missing:
            if not required:
                return None
            raise MicrosoftGraphConfigError(
                f"Missing Microsoft Graph configuration: {', '.join(missing)}"
            )

        return cls(
            tenant_id=tenant_id,
            client_id=client_id,
            client_secret=client_secret,
            scope=scope,
            authority_url=authority_url,
        )
async def get_access_token(self, *, force_refresh: bool = False) -> str:
    """Return an access token, fetching a new one when the cache cannot be used.

    The cache is consulted once without the lock and once more after the
    lock is acquired, so a caller that waited behind another caller's fetch
    can reuse the token stored by that fetch. With ``force_refresh`` the
    cache is never consulted and a new token is always fetched and stored.
    """

    def usable_cached_token():
        # Yield the cached token object when it may be served, else None.
        if force_refresh:
            return None
        candidate = self._cached_token
        if not candidate or candidate.is_expired(skew_seconds=self.skew_seconds):
            return None
        return candidate

    hit = usable_cached_token()
    if hit is not None:
        return hit.access_token

    async with self._lock:
        hit = usable_cached_token()
        if hit is not None:
            return hit.access_token

        fresh = await self._fetch_access_token()
        self._cached_token = fresh
        return fresh.access_token
_extract_error_detail(response) + raise MicrosoftGraphTokenError( + "Microsoft Graph token request failed with HTTP " + f"{response.status_code}: {detail}" + ) + + try: + payload = response.json() + except ValueError as exc: + raise MicrosoftGraphTokenError( + "Microsoft Graph token response was not valid JSON." + ) from exc + + access_token = str(payload.get("access_token") or "").strip() + token_type = str(payload.get("token_type") or "Bearer").strip() or "Bearer" + expires_in = payload.get("expires_in") + + if not access_token: + raise MicrosoftGraphTokenError( + "Microsoft Graph token response did not include access_token." + ) + + try: + expires_in_seconds = int(expires_in) + except (TypeError, ValueError) as exc: + raise MicrosoftGraphTokenError( + "Microsoft Graph token response did not include a valid expires_in." + ) from exc + + return CachedAccessToken( + access_token=access_token, + token_type=token_type, + expires_at=time.time() + max(0, expires_in_seconds), + ) + + +def _extract_error_detail(response: httpx.Response) -> str: + try: + payload = response.json() + except ValueError: + text = response.text.strip() + return text or "unknown error" + + if isinstance(payload, dict): + if isinstance(payload.get("error_description"), str): + return payload["error_description"] + error = payload.get("error") + if isinstance(error, dict): + message = error.get("message") + code = error.get("code") + if message and code: + return f"{code}: {message}" + if message: + return str(message) + if code: + return str(code) + if isinstance(error, str): + return error + return str(payload) diff --git a/tools/microsoft_graph_client.py b/tools/microsoft_graph_client.py new file mode 100644 index 0000000000..f92ba66c54 --- /dev/null +++ b/tools/microsoft_graph_client.py @@ -0,0 +1,327 @@ +"""Reusable Microsoft Graph REST client helpers.""" + +from __future__ import annotations + +import asyncio +import os +from pathlib import Path +from typing import Any, AsyncIterator, 
async def post_json(
    self,
    path: str,
    *,
    json_body: Any | None = None,
    headers: dict[str, str] | None = None,
) -> Any:
    """POST ``json_body`` to ``path`` and return the decoded JSON response.

    Args:
        path: Graph-relative path or absolute URL, resolved by ``_request``.
        json_body: Optional JSON-serialisable request body.
        headers: Optional extra request headers.

    Returns:
        The decoded JSON payload, or ``{}`` when the response is ``204`` or
        has no body.

    Raises:
        MicrosoftGraphClientError: If a non-empty body is not valid JSON, or
            if ``_request`` fails.
    """
    response = await self._request("POST", path, json_body=json_body, headers=headers)
    # Some Graph actions acknowledge a POST without a body (for example
    # 202 Accepted); decoding that as JSON would turn a success into an
    # error. This mirrors the handling in ``patch_json``.
    if response.status_code == 204 or not response.content:
        return {}
    return self._decode_json(response)
async def download_to_file(
    self,
    path: str,
    destination: str | Path,
    *,
    headers: dict[str, str] | None = None,
    chunk_size: int = 65536,
) -> dict[str, Any]:
    """Download the body of a GET request to ``destination``.

    The body is written to a sibling ``<name>.part`` file and moved over
    ``destination`` with ``os.replace`` only after every chunk has been
    written. If anything fails before the move completes, the partial file
    is removed so no stray ``.part`` file is left behind.

    Args:
        path: Graph-relative path or absolute URL, resolved by ``_request``.
        destination: Target file; missing parent directories are created.
        headers: Optional extra request headers.
        chunk_size: Chunk size passed to ``response.aiter_bytes``.

    Returns:
        A dict with ``path``, ``size_bytes`` and ``content_type``.

    NOTE(review): ``_request`` treats every status below 400 as success and
    builds its HTTP client without enabling redirect following. If a content
    endpoint answers with a redirect, this may write an empty file — confirm
    against the endpoints used.
    """
    response = await self._request("GET", path, headers=headers)
    target = Path(destination)
    target.parent.mkdir(parents=True, exist_ok=True)
    # Appending to the full name keeps the original suffix intact
    # ("a.mp4" -> "a.mp4.part") and works for names without a suffix.
    tmp_target = target.with_name(target.name + ".part")
    try:
        with tmp_target.open("wb") as handle:
            async for chunk in response.aiter_bytes(chunk_size=chunk_size):
                if chunk:
                    handle.write(chunk)
        os.replace(tmp_target, target)
    except BaseException:
        # BaseException so task cancellation also triggers cleanup. The
        # cleanup is best effort: the original failure is what callers need.
        try:
            tmp_target.unlink()
        except OSError:
            pass
        raise
    return {
        "path": str(target),
        "size_bytes": target.stat().st_size,
        "content_type": response.headers.get("content-type"),
    }
client.request( + method, + url, + params=params, + json=json_body, + headers=request_headers, + ) + except httpx.HTTPError as exc: + last_error = exc + if attempt >= self.max_retries: + raise MicrosoftGraphClientError( + f"Microsoft Graph request failed for {method} {url}: {exc}" + ) from exc + await self._sleep(self._retry_delay(None, attempt)) + attempt += 1 + continue + + if response.status_code < 400: + return response + + api_error = self._build_api_error(method, url, response) + last_error = api_error + + if response.status_code == 401 and attempt < self.max_retries: + self.token_provider.clear_cache() + await self._sleep(self._retry_delay(response, attempt)) + attempt += 1 + continue + + if self._should_retry(response) and attempt < self.max_retries: + await self._sleep(self._retry_delay(response, attempt)) + attempt += 1 + continue + + raise api_error + + raise MicrosoftGraphClientError( + f"Microsoft Graph request exhausted retries for {method} {url}." + ) + + def _resolve_url(self, path_or_url: str) -> str: + if path_or_url.startswith(("http://", "https://")): + return path_or_url + path = path_or_url if path_or_url.startswith("/") else f"/{path_or_url}" + return f"{self.base_url}{path}" + + @staticmethod + def _decode_json(response: httpx.Response) -> Any: + try: + return response.json() + except ValueError as exc: + raise MicrosoftGraphClientError( + "Microsoft Graph response was not valid JSON for " + f"{response.request.method} {response.request.url}" + ) from exc + + @staticmethod + def _should_retry(response: httpx.Response | None) -> bool: + if response is None: + return True + return response.status_code == 429 or 500 <= response.status_code < 600 + + @staticmethod + def _should_refresh_token(error: Exception | None) -> bool: + return isinstance(error, MicrosoftGraphAPIError) and error.status_code == 401 + + @staticmethod + def _retry_delay(response: httpx.Response | None, attempt: int) -> float: + if response is not None: + retry_after = 
response.headers.get("Retry-After") + if retry_after: + try: + return max(0.0, float(retry_after)) + except ValueError: + pass + return min(8.0, 0.5 * (2 ** attempt)) + + @staticmethod + def _build_api_error( + method: str, + url: str, + response: httpx.Response, + ) -> MicrosoftGraphAPIError: + payload: Any = None + message = response.text.strip() or "unknown error" + try: + payload = response.json() + except ValueError: + payload = None + + if isinstance(payload, dict): + error = payload.get("error") + if isinstance(error, dict): + code = error.get("code") + inner_message = error.get("message") + if code and inner_message: + message = f"{code}: {inner_message}" + elif inner_message: + message = str(inner_message) + elif isinstance(error, str): + message = error + + retry_after: float | None = None + header_value = response.headers.get("Retry-After") + if header_value: + try: + retry_after = float(header_value) + except ValueError: + retry_after = None + + return MicrosoftGraphAPIError( + response.status_code, + method, + url, + message, + retry_after_seconds=retry_after, + payload=payload, + )