mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-12 03:42:08 +00:00
feat(msgraph): add auth and client foundation
This commit is contained in:
parent
ea8e608821
commit
a152c706b7
4 changed files with 873 additions and 0 deletions
149
tests/tools/test_microsoft_graph_auth.py
Normal file
149
tests/tools/test_microsoft_graph_auth.py
Normal file
|
|
@ -0,0 +1,149 @@
|
||||||
|
"""Tests for tools/microsoft_graph_auth.py."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tools.microsoft_graph_auth import (
|
||||||
|
DEFAULT_GRAPH_SCOPE,
|
||||||
|
GraphCredentials,
|
||||||
|
MicrosoftGraphConfigError,
|
||||||
|
MicrosoftGraphTokenError,
|
||||||
|
MicrosoftGraphTokenProvider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGraphCredentials:
|
||||||
|
def test_from_env_raises_for_missing_required_values(self):
|
||||||
|
with pytest.raises(MicrosoftGraphConfigError) as exc:
|
||||||
|
GraphCredentials.from_env({})
|
||||||
|
assert "MSGRAPH_TENANT_ID" in str(exc.value)
|
||||||
|
assert "MSGRAPH_CLIENT_ID" in str(exc.value)
|
||||||
|
assert "MSGRAPH_CLIENT_SECRET" in str(exc.value)
|
||||||
|
|
||||||
|
def test_from_env_optional_returns_none_when_not_configured(self):
|
||||||
|
assert GraphCredentials.from_env({}, required=False) is None
|
||||||
|
|
||||||
|
def test_from_env_builds_normalized_credentials(self):
|
||||||
|
creds = GraphCredentials.from_env(
|
||||||
|
{
|
||||||
|
"MSGRAPH_TENANT_ID": "tenant-123",
|
||||||
|
"MSGRAPH_CLIENT_ID": "client-456",
|
||||||
|
"MSGRAPH_CLIENT_SECRET": "secret-789",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert creds is not None
|
||||||
|
assert creds.scope == DEFAULT_GRAPH_SCOPE
|
||||||
|
assert creds.token_url.endswith("/tenant-123/oauth2/v2.0/token")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
class TestMicrosoftGraphTokenProvider:
|
||||||
|
async def test_reuses_cached_token_until_expiry(self):
|
||||||
|
calls: list[int] = []
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
calls.append(1)
|
||||||
|
return httpx.Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"access_token": f"token-{len(calls)}",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = MicrosoftGraphTokenProvider(
|
||||||
|
GraphCredentials("tenant", "client", "secret"),
|
||||||
|
transport=httpx.MockTransport(handler),
|
||||||
|
)
|
||||||
|
|
||||||
|
first = await provider.get_access_token()
|
||||||
|
second = await provider.get_access_token()
|
||||||
|
|
||||||
|
assert first == "token-1"
|
||||||
|
assert second == "token-1"
|
||||||
|
assert len(calls) == 1
|
||||||
|
|
||||||
|
async def test_refreshes_when_cached_token_is_expired(self):
|
||||||
|
calls: list[int] = []
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
calls.append(1)
|
||||||
|
expires_in = 0 if len(calls) == 1 else 3600
|
||||||
|
return httpx.Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"access_token": f"token-{len(calls)}",
|
||||||
|
"expires_in": expires_in,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = MicrosoftGraphTokenProvider(
|
||||||
|
GraphCredentials("tenant", "client", "secret"),
|
||||||
|
transport=httpx.MockTransport(handler),
|
||||||
|
skew_seconds=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
first = await provider.get_access_token()
|
||||||
|
second = await provider.get_access_token()
|
||||||
|
|
||||||
|
assert first == "token-1"
|
||||||
|
assert second == "token-2"
|
||||||
|
assert len(calls) == 2
|
||||||
|
|
||||||
|
async def test_force_refresh_bypasses_cache(self):
|
||||||
|
calls: list[int] = []
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
calls.append(1)
|
||||||
|
return httpx.Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"access_token": f"token-{len(calls)}",
|
||||||
|
"expires_in": 3600,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = MicrosoftGraphTokenProvider(
|
||||||
|
GraphCredentials("tenant", "client", "secret"),
|
||||||
|
transport=httpx.MockTransport(handler),
|
||||||
|
)
|
||||||
|
|
||||||
|
first = await provider.get_access_token()
|
||||||
|
second = await provider.get_access_token(force_refresh=True)
|
||||||
|
|
||||||
|
assert first == "token-1"
|
||||||
|
assert second == "token-2"
|
||||||
|
assert len(calls) == 2
|
||||||
|
|
||||||
|
async def test_invalid_token_response_raises(self):
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(200, json={"expires_in": 3600})
|
||||||
|
|
||||||
|
provider = MicrosoftGraphTokenProvider(
|
||||||
|
GraphCredentials("tenant", "client", "secret"),
|
||||||
|
transport=httpx.MockTransport(handler),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(MicrosoftGraphTokenError) as exc:
|
||||||
|
await provider.get_access_token()
|
||||||
|
assert "access_token" in str(exc.value)
|
||||||
|
|
||||||
|
async def test_http_error_includes_server_message(self):
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(
|
||||||
|
401,
|
||||||
|
json={"error": "invalid_client", "error_description": "bad secret"},
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = MicrosoftGraphTokenProvider(
|
||||||
|
GraphCredentials("tenant", "client", "secret"),
|
||||||
|
transport=httpx.MockTransport(handler),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(MicrosoftGraphTokenError) as exc:
|
||||||
|
await provider.get_access_token()
|
||||||
|
assert "bad secret" in str(exc.value)
|
||||||
152
tests/tools/test_microsoft_graph_client.py
Normal file
152
tests/tools/test_microsoft_graph_client.py
Normal file
|
|
@ -0,0 +1,152 @@
|
||||||
|
"""Tests for tools/microsoft_graph_client.py."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tools.microsoft_graph_auth import GraphCredentials, MicrosoftGraphTokenProvider
|
||||||
|
from tools.microsoft_graph_client import (
|
||||||
|
MicrosoftGraphAPIError,
|
||||||
|
MicrosoftGraphClient,
|
||||||
|
MicrosoftGraphClientError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_provider() -> MicrosoftGraphTokenProvider:
|
||||||
|
provider = MicrosoftGraphTokenProvider(GraphCredentials("tenant", "client", "secret"))
|
||||||
|
provider._cached_token = type( # type: ignore[attr-defined]
|
||||||
|
"Token",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"access_token": "cached-token",
|
||||||
|
"is_expired": lambda self, skew_seconds=0: False,
|
||||||
|
"expires_in_seconds": 3600,
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
class TestMicrosoftGraphClient:
|
||||||
|
async def test_attaches_bearer_token_header(self):
|
||||||
|
captured_auth: list[str] = []
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
captured_auth.append(request.headers["Authorization"])
|
||||||
|
return httpx.Response(200, json={"ok": True})
|
||||||
|
|
||||||
|
client = MicrosoftGraphClient(
|
||||||
|
_make_provider(),
|
||||||
|
transport=httpx.MockTransport(handler),
|
||||||
|
)
|
||||||
|
payload = await client.get_json("/me")
|
||||||
|
assert payload == {"ok": True}
|
||||||
|
assert captured_auth == ["Bearer cached-token"]
|
||||||
|
|
||||||
|
async def test_retries_on_rate_limit_and_uses_retry_after(self):
|
||||||
|
calls: list[int] = []
|
||||||
|
sleeps: list[float] = []
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
calls.append(1)
|
||||||
|
if len(calls) == 1:
|
||||||
|
return httpx.Response(
|
||||||
|
429,
|
||||||
|
json={"error": {"code": "TooManyRequests", "message": "slow down"}},
|
||||||
|
headers={"Retry-After": "3"},
|
||||||
|
)
|
||||||
|
return httpx.Response(200, json={"ok": True})
|
||||||
|
|
||||||
|
async def fake_sleep(delay: float) -> None:
|
||||||
|
sleeps.append(delay)
|
||||||
|
|
||||||
|
client = MicrosoftGraphClient(
|
||||||
|
_make_provider(),
|
||||||
|
transport=httpx.MockTransport(handler),
|
||||||
|
sleep=fake_sleep,
|
||||||
|
max_retries=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = await client.get_json("/me")
|
||||||
|
|
||||||
|
assert payload == {"ok": True}
|
||||||
|
assert len(calls) == 2
|
||||||
|
assert sleeps == [3.0]
|
||||||
|
|
||||||
|
async def test_raises_api_error_after_retry_budget_exhausted(self):
|
||||||
|
sleeps: list[float] = []
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(503, json={"error": {"message": "unavailable"}})
|
||||||
|
|
||||||
|
async def fake_sleep(delay: float) -> None:
|
||||||
|
sleeps.append(delay)
|
||||||
|
|
||||||
|
client = MicrosoftGraphClient(
|
||||||
|
_make_provider(),
|
||||||
|
transport=httpx.MockTransport(handler),
|
||||||
|
sleep=fake_sleep,
|
||||||
|
max_retries=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(MicrosoftGraphAPIError) as exc:
|
||||||
|
await client.get_json("/me")
|
||||||
|
assert exc.value.status_code == 503
|
||||||
|
assert sleeps == [0.5]
|
||||||
|
|
||||||
|
async def test_collect_paginated_flattens_value_arrays(self):
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
if str(request.url).endswith("/items"):
|
||||||
|
return httpx.Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"value": [{"id": "1"}],
|
||||||
|
"@odata.nextLink": "https://graph.microsoft.com/v1.0/items?page=2",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return httpx.Response(200, json={"value": [{"id": "2"}]})
|
||||||
|
|
||||||
|
client = MicrosoftGraphClient(
|
||||||
|
_make_provider(),
|
||||||
|
transport=httpx.MockTransport(handler),
|
||||||
|
)
|
||||||
|
items = await client.collect_paginated("/items")
|
||||||
|
assert items == [{"id": "1"}, {"id": "2"}]
|
||||||
|
|
||||||
|
async def test_download_to_file_writes_binary_content(self, tmp_path: Path):
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(
|
||||||
|
200,
|
||||||
|
content=b"meeting-recording",
|
||||||
|
headers={"content-type": "video/mp4"},
|
||||||
|
)
|
||||||
|
|
||||||
|
client = MicrosoftGraphClient(
|
||||||
|
_make_provider(),
|
||||||
|
transport=httpx.MockTransport(handler),
|
||||||
|
)
|
||||||
|
destination = tmp_path / "recording.mp4"
|
||||||
|
result = await client.download_to_file("/drive/item/content", destination)
|
||||||
|
|
||||||
|
assert destination.read_bytes() == b"meeting-recording"
|
||||||
|
assert result["content_type"] == "video/mp4"
|
||||||
|
assert result["size_bytes"] == len(b"meeting-recording")
|
||||||
|
|
||||||
|
async def test_invalid_json_response_raises_client_error(self):
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(
|
||||||
|
200,
|
||||||
|
content=b"not-json",
|
||||||
|
headers={"content-type": "application/json"},
|
||||||
|
)
|
||||||
|
|
||||||
|
client = MicrosoftGraphClient(
|
||||||
|
_make_provider(),
|
||||||
|
transport=httpx.MockTransport(handler),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(MicrosoftGraphClientError):
|
||||||
|
await client.get_json("/me")
|
||||||
245
tools/microsoft_graph_auth.py
Normal file
245
tools/microsoft_graph_auth.py
Normal file
|
|
@ -0,0 +1,245 @@
|
||||||
|
"""Microsoft Graph app-only authentication helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_GRAPH_SCOPE = "https://graph.microsoft.com/.default"
|
||||||
|
DEFAULT_GRAPH_AUTHORITY_URL = "https://login.microsoftonline.com"
|
||||||
|
DEFAULT_TOKEN_SKEW_SECONDS = 120
|
||||||
|
|
||||||
|
|
||||||
|
class MicrosoftGraphAuthError(RuntimeError):
|
||||||
|
"""Base class for Microsoft Graph auth failures."""
|
||||||
|
|
||||||
|
|
||||||
|
class MicrosoftGraphConfigError(MicrosoftGraphAuthError):
|
||||||
|
"""Raised when Graph credentials are missing or invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class MicrosoftGraphTokenError(MicrosoftGraphAuthError):
|
||||||
|
"""Raised when token acquisition fails."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GraphCredentials:
|
||||||
|
"""Normalized Microsoft Graph app-only credentials."""
|
||||||
|
|
||||||
|
tenant_id: str
|
||||||
|
client_id: str
|
||||||
|
client_secret: str
|
||||||
|
scope: str = DEFAULT_GRAPH_SCOPE
|
||||||
|
authority_url: str = DEFAULT_GRAPH_AUTHORITY_URL
|
||||||
|
|
||||||
|
@property
|
||||||
|
def token_url(self) -> str:
|
||||||
|
base = self.authority_url.rstrip("/")
|
||||||
|
tenant = self.tenant_id.strip().strip("/")
|
||||||
|
return f"{base}/{tenant}/oauth2/v2.0/token"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(
|
||||||
|
cls,
|
||||||
|
environ: dict[str, str] | None = None,
|
||||||
|
*,
|
||||||
|
required: bool = True,
|
||||||
|
) -> "GraphCredentials | None":
|
||||||
|
env = environ if environ is not None else os.environ
|
||||||
|
tenant_id = (env.get("MSGRAPH_TENANT_ID") or "").strip()
|
||||||
|
client_id = (env.get("MSGRAPH_CLIENT_ID") or "").strip()
|
||||||
|
client_secret = (env.get("MSGRAPH_CLIENT_SECRET") or "").strip()
|
||||||
|
scope = (env.get("MSGRAPH_SCOPE") or DEFAULT_GRAPH_SCOPE).strip()
|
||||||
|
authority_url = (
|
||||||
|
env.get("MSGRAPH_AUTHORITY_URL") or DEFAULT_GRAPH_AUTHORITY_URL
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
missing = [
|
||||||
|
name
|
||||||
|
for name, value in (
|
||||||
|
("MSGRAPH_TENANT_ID", tenant_id),
|
||||||
|
("MSGRAPH_CLIENT_ID", client_id),
|
||||||
|
("MSGRAPH_CLIENT_SECRET", client_secret),
|
||||||
|
)
|
||||||
|
if not value
|
||||||
|
]
|
||||||
|
if missing:
|
||||||
|
if not required:
|
||||||
|
return None
|
||||||
|
raise MicrosoftGraphConfigError(
|
||||||
|
f"Missing Microsoft Graph configuration: {', '.join(missing)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
client_id=client_id,
|
||||||
|
client_secret=client_secret,
|
||||||
|
scope=scope,
|
||||||
|
authority_url=authority_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CachedAccessToken:
|
||||||
|
"""Cached app-only Graph access token."""
|
||||||
|
|
||||||
|
access_token: str
|
||||||
|
expires_at: float
|
||||||
|
token_type: str = "Bearer"
|
||||||
|
|
||||||
|
def is_expired(self, *, skew_seconds: int = DEFAULT_TOKEN_SKEW_SECONDS) -> bool:
|
||||||
|
return self.expires_at <= (time.time() + max(0, int(skew_seconds)))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expires_in_seconds(self) -> int:
|
||||||
|
return max(0, int(self.expires_at - time.time()))
|
||||||
|
|
||||||
|
|
||||||
|
class MicrosoftGraphTokenProvider:
|
||||||
|
"""Acquire and cache Microsoft Graph app-only access tokens."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
credentials: GraphCredentials,
|
||||||
|
*,
|
||||||
|
timeout: float = 20.0,
|
||||||
|
skew_seconds: int = DEFAULT_TOKEN_SKEW_SECONDS,
|
||||||
|
transport: httpx.AsyncBaseTransport | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.credentials = credentials
|
||||||
|
self.timeout = timeout
|
||||||
|
self.skew_seconds = max(0, int(skew_seconds))
|
||||||
|
self._transport = transport
|
||||||
|
self._cached_token: CachedAccessToken | None = None
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(
|
||||||
|
cls,
|
||||||
|
environ: dict[str, str] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "MicrosoftGraphTokenProvider":
|
||||||
|
credentials = GraphCredentials.from_env(environ)
|
||||||
|
return cls(credentials, **kwargs)
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
self._cached_token = None
|
||||||
|
|
||||||
|
def inspect_token_health(self) -> dict[str, Any]:
|
||||||
|
cached = self._cached_token
|
||||||
|
return {
|
||||||
|
"configured": True,
|
||||||
|
"tenant_id": self.credentials.tenant_id,
|
||||||
|
"client_id": self.credentials.client_id,
|
||||||
|
"scope": self.credentials.scope,
|
||||||
|
"authority_url": self.credentials.authority_url,
|
||||||
|
"token_url": self.credentials.token_url,
|
||||||
|
"cached": bool(cached),
|
||||||
|
"expires_in_seconds": cached.expires_in_seconds if cached else None,
|
||||||
|
"is_expired": cached.is_expired(skew_seconds=0) if cached else None,
|
||||||
|
"refresh_skew_seconds": self.skew_seconds,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_access_token(self, *, force_refresh: bool = False) -> str:
|
||||||
|
cached = self._cached_token
|
||||||
|
if not force_refresh and cached and not cached.is_expired(
|
||||||
|
skew_seconds=self.skew_seconds
|
||||||
|
):
|
||||||
|
return cached.access_token
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
cached = self._cached_token
|
||||||
|
if not force_refresh and cached and not cached.is_expired(
|
||||||
|
skew_seconds=self.skew_seconds
|
||||||
|
):
|
||||||
|
return cached.access_token
|
||||||
|
|
||||||
|
token = await self._fetch_access_token()
|
||||||
|
self._cached_token = token
|
||||||
|
return token.access_token
|
||||||
|
|
||||||
|
async def _fetch_access_token(self) -> CachedAccessToken:
|
||||||
|
data = {
|
||||||
|
"grant_type": "client_credentials",
|
||||||
|
"client_id": self.credentials.client_id,
|
||||||
|
"client_secret": self.credentials.client_secret,
|
||||||
|
"scope": self.credentials.scope,
|
||||||
|
}
|
||||||
|
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
timeout=httpx.Timeout(self.timeout),
|
||||||
|
transport=self._transport,
|
||||||
|
) as client:
|
||||||
|
response = await client.post(
|
||||||
|
self.credentials.token_url,
|
||||||
|
data=data,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code >= 400:
|
||||||
|
detail = _extract_error_detail(response)
|
||||||
|
raise MicrosoftGraphTokenError(
|
||||||
|
"Microsoft Graph token request failed with HTTP "
|
||||||
|
f"{response.status_code}: {detail}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = response.json()
|
||||||
|
except ValueError as exc:
|
||||||
|
raise MicrosoftGraphTokenError(
|
||||||
|
"Microsoft Graph token response was not valid JSON."
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
access_token = str(payload.get("access_token") or "").strip()
|
||||||
|
token_type = str(payload.get("token_type") or "Bearer").strip() or "Bearer"
|
||||||
|
expires_in = payload.get("expires_in")
|
||||||
|
|
||||||
|
if not access_token:
|
||||||
|
raise MicrosoftGraphTokenError(
|
||||||
|
"Microsoft Graph token response did not include access_token."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
expires_in_seconds = int(expires_in)
|
||||||
|
except (TypeError, ValueError) as exc:
|
||||||
|
raise MicrosoftGraphTokenError(
|
||||||
|
"Microsoft Graph token response did not include a valid expires_in."
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
return CachedAccessToken(
|
||||||
|
access_token=access_token,
|
||||||
|
token_type=token_type,
|
||||||
|
expires_at=time.time() + max(0, expires_in_seconds),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_error_detail(response: httpx.Response) -> str:
|
||||||
|
try:
|
||||||
|
payload = response.json()
|
||||||
|
except ValueError:
|
||||||
|
text = response.text.strip()
|
||||||
|
return text or "unknown error"
|
||||||
|
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
if isinstance(payload.get("error_description"), str):
|
||||||
|
return payload["error_description"]
|
||||||
|
error = payload.get("error")
|
||||||
|
if isinstance(error, dict):
|
||||||
|
message = error.get("message")
|
||||||
|
code = error.get("code")
|
||||||
|
if message and code:
|
||||||
|
return f"{code}: {message}"
|
||||||
|
if message:
|
||||||
|
return str(message)
|
||||||
|
if code:
|
||||||
|
return str(code)
|
||||||
|
if isinstance(error, str):
|
||||||
|
return error
|
||||||
|
return str(payload)
|
||||||
327
tools/microsoft_graph_client.py
Normal file
327
tools/microsoft_graph_client.py
Normal file
|
|
@ -0,0 +1,327 @@
|
||||||
|
"""Reusable Microsoft Graph REST client helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, AsyncIterator, Awaitable, Callable
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from tools.microsoft_graph_auth import GraphCredentials, MicrosoftGraphTokenProvider
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_GRAPH_BASE_URL = "https://graph.microsoft.com/v1.0"
|
||||||
|
|
||||||
|
|
||||||
|
class MicrosoftGraphClientError(RuntimeError):
|
||||||
|
"""Base class for Graph client failures."""
|
||||||
|
|
||||||
|
|
||||||
|
class MicrosoftGraphAPIError(MicrosoftGraphClientError):
|
||||||
|
"""Raised when a Graph API request fails."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
status_code: int,
|
||||||
|
method: str,
|
||||||
|
url: str,
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
retry_after_seconds: float | None = None,
|
||||||
|
payload: Any = None,
|
||||||
|
) -> None:
|
||||||
|
self.status_code = status_code
|
||||||
|
self.method = method
|
||||||
|
self.url = url
|
||||||
|
self.retry_after_seconds = retry_after_seconds
|
||||||
|
self.payload = payload
|
||||||
|
super().__init__(
|
||||||
|
f"Microsoft Graph API error {status_code} for {method} {url}: {message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MicrosoftGraphClient:
|
||||||
|
"""Minimal async Microsoft Graph client with retries and pagination."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
token_provider: MicrosoftGraphTokenProvider,
|
||||||
|
*,
|
||||||
|
base_url: str = DEFAULT_GRAPH_BASE_URL,
|
||||||
|
timeout: float = 60.0,
|
||||||
|
max_retries: int = 3,
|
||||||
|
transport: httpx.AsyncBaseTransport | None = None,
|
||||||
|
sleep: Callable[[float], Awaitable[None]] | None = None,
|
||||||
|
user_agent: str = "Hermes-Agent/graph-client",
|
||||||
|
) -> None:
|
||||||
|
self.token_provider = token_provider
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.timeout = timeout
|
||||||
|
self.max_retries = max(0, int(max_retries))
|
||||||
|
self._transport = transport
|
||||||
|
self._sleep = sleep or asyncio.sleep
|
||||||
|
self.user_agent = user_agent
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(cls, **kwargs: Any) -> "MicrosoftGraphClient":
|
||||||
|
credentials = GraphCredentials.from_env()
|
||||||
|
provider = MicrosoftGraphTokenProvider(credentials)
|
||||||
|
return cls(provider, **kwargs)
|
||||||
|
|
||||||
|
async def get_json(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
*,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
) -> Any:
|
||||||
|
response = await self._request("GET", path, params=params, headers=headers)
|
||||||
|
return self._decode_json(response)
|
||||||
|
|
||||||
|
async def post_json(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
*,
|
||||||
|
json_body: Any | None = None,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
) -> Any:
|
||||||
|
response = await self._request("POST", path, json_body=json_body, headers=headers)
|
||||||
|
return self._decode_json(response)
|
||||||
|
|
||||||
|
async def patch_json(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
*,
|
||||||
|
json_body: Any | None = None,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
) -> Any:
|
||||||
|
response = await self._request("PATCH", path, json_body=json_body, headers=headers)
|
||||||
|
if response.status_code == 204 or not response.content:
|
||||||
|
return {}
|
||||||
|
return self._decode_json(response)
|
||||||
|
|
||||||
|
async def delete(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
*,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
response = await self._request("DELETE", path, headers=headers)
|
||||||
|
if response.status_code == 204 or not response.content:
|
||||||
|
return {"deleted": True, "status_code": response.status_code}
|
||||||
|
return self._decode_json(response)
|
||||||
|
|
||||||
|
async def iterate_pages(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
*,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
) -> AsyncIterator[dict[str, Any]]:
|
||||||
|
next_url: str | None = self._resolve_url(path)
|
||||||
|
next_params = dict(params or {})
|
||||||
|
while next_url:
|
||||||
|
response = await self._request(
|
||||||
|
"GET",
|
||||||
|
next_url,
|
||||||
|
params=next_params or None,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
payload = self._decode_json(response)
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
raise MicrosoftGraphClientError(
|
||||||
|
f"Expected paginated Graph response dict, got {type(payload).__name__}."
|
||||||
|
)
|
||||||
|
yield payload
|
||||||
|
next_url = payload.get("@odata.nextLink")
|
||||||
|
next_params = {}
|
||||||
|
|
||||||
|
async def collect_paginated(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
*,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
items: list[Any] = []
|
||||||
|
async for page in self.iterate_pages(path, params=params, headers=headers):
|
||||||
|
value = page.get("value")
|
||||||
|
if isinstance(value, list):
|
||||||
|
items.extend(value)
|
||||||
|
return items
|
||||||
|
|
||||||
|
async def download_to_file(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
destination: str | Path,
|
||||||
|
*,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
chunk_size: int = 65536,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
response = await self._request("GET", path, headers=headers)
|
||||||
|
target = Path(destination)
|
||||||
|
target.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
tmp_target = target.with_suffix(target.suffix + ".part")
|
||||||
|
with tmp_target.open("wb") as handle:
|
||||||
|
async for chunk in response.aiter_bytes(chunk_size=chunk_size):
|
||||||
|
if chunk:
|
||||||
|
handle.write(chunk)
|
||||||
|
os.replace(tmp_target, target)
|
||||||
|
return {
|
||||||
|
"path": str(target),
|
||||||
|
"size_bytes": target.stat().st_size,
|
||||||
|
"content_type": response.headers.get("content-type"),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path_or_url: str,
|
||||||
|
*,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
json_body: Any | None = None,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
) -> httpx.Response:
|
||||||
|
url = self._resolve_url(path_or_url)
|
||||||
|
attempt = 0
|
||||||
|
last_error: Exception | None = None
|
||||||
|
|
||||||
|
while attempt <= self.max_retries:
|
||||||
|
token = await self.token_provider.get_access_token(
|
||||||
|
force_refresh=attempt > 0 and self._should_refresh_token(last_error)
|
||||||
|
)
|
||||||
|
request_headers = {
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
"Accept": "application/json",
|
||||||
|
"User-Agent": self.user_agent,
|
||||||
|
}
|
||||||
|
if json_body is not None:
|
||||||
|
request_headers["Content-Type"] = "application/json"
|
||||||
|
if headers:
|
||||||
|
request_headers.update(headers)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
timeout=httpx.Timeout(self.timeout),
|
||||||
|
transport=self._transport,
|
||||||
|
) as client:
|
||||||
|
response = await client.request(
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
params=params,
|
||||||
|
json=json_body,
|
||||||
|
headers=request_headers,
|
||||||
|
)
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
last_error = exc
|
||||||
|
if attempt >= self.max_retries:
|
||||||
|
raise MicrosoftGraphClientError(
|
||||||
|
f"Microsoft Graph request failed for {method} {url}: {exc}"
|
||||||
|
) from exc
|
||||||
|
await self._sleep(self._retry_delay(None, attempt))
|
||||||
|
attempt += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if response.status_code < 400:
|
||||||
|
return response
|
||||||
|
|
||||||
|
api_error = self._build_api_error(method, url, response)
|
||||||
|
last_error = api_error
|
||||||
|
|
||||||
|
if response.status_code == 401 and attempt < self.max_retries:
|
||||||
|
self.token_provider.clear_cache()
|
||||||
|
await self._sleep(self._retry_delay(response, attempt))
|
||||||
|
attempt += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self._should_retry(response) and attempt < self.max_retries:
|
||||||
|
await self._sleep(self._retry_delay(response, attempt))
|
||||||
|
attempt += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise api_error
|
||||||
|
|
||||||
|
raise MicrosoftGraphClientError(
|
||||||
|
f"Microsoft Graph request exhausted retries for {method} {url}."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _resolve_url(self, path_or_url: str) -> str:
|
||||||
|
if path_or_url.startswith(("http://", "https://")):
|
||||||
|
return path_or_url
|
||||||
|
path = path_or_url if path_or_url.startswith("/") else f"/{path_or_url}"
|
||||||
|
return f"{self.base_url}{path}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _decode_json(response: httpx.Response) -> Any:
|
||||||
|
try:
|
||||||
|
return response.json()
|
||||||
|
except ValueError as exc:
|
||||||
|
raise MicrosoftGraphClientError(
|
||||||
|
"Microsoft Graph response was not valid JSON for "
|
||||||
|
f"{response.request.method} {response.request.url}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _should_retry(response: httpx.Response | None) -> bool:
|
||||||
|
if response is None:
|
||||||
|
return True
|
||||||
|
return response.status_code == 429 or 500 <= response.status_code < 600
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _should_refresh_token(error: Exception | None) -> bool:
|
||||||
|
return isinstance(error, MicrosoftGraphAPIError) and error.status_code == 401
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _retry_delay(response: httpx.Response | None, attempt: int) -> float:
|
||||||
|
if response is not None:
|
||||||
|
retry_after = response.headers.get("Retry-After")
|
||||||
|
if retry_after:
|
||||||
|
try:
|
||||||
|
return max(0.0, float(retry_after))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return min(8.0, 0.5 * (2 ** attempt))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_api_error(
|
||||||
|
method: str,
|
||||||
|
url: str,
|
||||||
|
response: httpx.Response,
|
||||||
|
) -> MicrosoftGraphAPIError:
|
||||||
|
payload: Any = None
|
||||||
|
message = response.text.strip() or "unknown error"
|
||||||
|
try:
|
||||||
|
payload = response.json()
|
||||||
|
except ValueError:
|
||||||
|
payload = None
|
||||||
|
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
error = payload.get("error")
|
||||||
|
if isinstance(error, dict):
|
||||||
|
code = error.get("code")
|
||||||
|
inner_message = error.get("message")
|
||||||
|
if code and inner_message:
|
||||||
|
message = f"{code}: {inner_message}"
|
||||||
|
elif inner_message:
|
||||||
|
message = str(inner_message)
|
||||||
|
elif isinstance(error, str):
|
||||||
|
message = error
|
||||||
|
|
||||||
|
retry_after: float | None = None
|
||||||
|
header_value = response.headers.get("Retry-After")
|
||||||
|
if header_value:
|
||||||
|
try:
|
||||||
|
retry_after = float(header_value)
|
||||||
|
except ValueError:
|
||||||
|
retry_after = None
|
||||||
|
|
||||||
|
return MicrosoftGraphAPIError(
|
||||||
|
response.status_code,
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
message,
|
||||||
|
retry_after_seconds=retry_after,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue