hermes-agent/tools/microsoft_graph_client.py
Teknium 45d860d424 fix(msgraph): stream download_to_file body instead of buffering
The prior implementation routed download_to_file through the shared
_request() path, which uses httpx.AsyncClient.request() inside a
context manager that closes before aiter_bytes() iterates. The body
was read into memory first and the chunked write loop replayed it
from buffer. On small test payloads this was invisible; on real
Teams meeting recordings (hundreds of MB) it would force the full
artifact into RAM per download.

Rewrites download_to_file to open its own AsyncClient and use
client.stream(), keeping the context open across the aiter_bytes
iteration so the body is actually streamed chunk-by-chunk to disk.
Retry/token-refresh/Retry-After semantics are preserved by handling
them inline on the stream path. Partial .part files are cleaned up
on transport errors and on exhausted retries.

Adds three tests: large-payload streaming verifies the chunk loop
runs multiple times (discriminator: 512 KiB at chunk_size=65536
yields 8 chunks under streaming, 1 under buffering), transient-5xx
retry recovers after a single retry, and exhausted-retry cleans up
the partial file.
2026-05-08 09:27:26 -07:00

408 lines
14 KiB
Python

"""Reusable Microsoft Graph REST client helpers."""
from __future__ import annotations
import asyncio
import os
from pathlib import Path
from typing import Any, AsyncIterator, Awaitable, Callable
import httpx
from tools.microsoft_graph_auth import GraphCredentials, MicrosoftGraphTokenProvider
DEFAULT_GRAPH_BASE_URL = "https://graph.microsoft.com/v1.0"
class MicrosoftGraphClientError(RuntimeError):
    """Base class for Graph client failures.

    Raised directly for transport errors, exhausted retries, and
    non-JSON response bodies; API-level errors use the
    ``MicrosoftGraphAPIError`` subclass.
    """
class MicrosoftGraphAPIError(MicrosoftGraphClientError):
    """Raised when a Graph API request returns an error status.

    Carries the HTTP status, the request method/URL, any parsed
    ``Retry-After`` value, and the decoded error payload so callers can
    implement their own retry or reporting policies.
    """

    def __init__(
        self,
        status_code: int,
        method: str,
        url: str,
        message: str,
        *,
        retry_after_seconds: float | None = None,
        payload: Any = None,
    ) -> None:
        # Build the human-readable message first, then stash the
        # structured fields for programmatic inspection.
        super().__init__(
            f"Microsoft Graph API error {status_code} for {method} {url}: {message}"
        )
        self.status_code = status_code
        self.method = method
        self.url = url
        self.retry_after_seconds = retry_after_seconds
        self.payload = payload
class MicrosoftGraphClient:
    """Minimal async Microsoft Graph client with retries and pagination.

    A fresh ``httpx.AsyncClient`` is opened per attempt; ``transport`` and
    ``sleep`` are injectable so tests can stub the network and skip real
    backoff delays.
    """

    def __init__(
        self,
        token_provider: MicrosoftGraphTokenProvider,
        *,
        base_url: str = DEFAULT_GRAPH_BASE_URL,
        timeout: float = 60.0,
        max_retries: int = 3,
        transport: httpx.AsyncBaseTransport | None = None,
        sleep: Callable[[float], Awaitable[None]] | None = None,
        user_agent: str = "Hermes-Agent/graph-client",
    ) -> None:
        self.token_provider = token_provider
        # Normalize so _resolve_url can always append a "/path" suffix.
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # Clamp to >= 0; max_retries counts retries *after* the first attempt.
        self.max_retries = max(0, int(max_retries))
        self._transport = transport
        # Injectable for tests; defaults to asyncio.sleep.
        self._sleep = sleep or asyncio.sleep
        self.user_agent = user_agent

    @classmethod
    def from_env(cls, **kwargs: Any) -> "MicrosoftGraphClient":
        """Build a client from environment-variable credentials."""
        credentials = GraphCredentials.from_env()
        provider = MicrosoftGraphTokenProvider(credentials)
        return cls(provider, **kwargs)

    async def get_json(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
    ) -> Any:
        """GET *path* and return the decoded JSON body."""
        response = await self._request("GET", path, params=params, headers=headers)
        return self._decode_json(response)

    async def post_json(
        self,
        path: str,
        *,
        json_body: Any | None = None,
        headers: dict[str, str] | None = None,
    ) -> Any:
        """POST *json_body* to *path* and return the decoded JSON body."""
        response = await self._request("POST", path, json_body=json_body, headers=headers)
        return self._decode_json(response)

    async def patch_json(
        self,
        path: str,
        *,
        json_body: Any | None = None,
        headers: dict[str, str] | None = None,
    ) -> Any:
        """PATCH *json_body* to *path*; 204 / empty bodies decode to ``{}``."""
        response = await self._request("PATCH", path, json_body=json_body, headers=headers)
        if response.status_code == 204 or not response.content:
            return {}
        return self._decode_json(response)

    async def delete(
        self,
        path: str,
        *,
        headers: dict[str, str] | None = None,
    ) -> dict[str, Any]:
        """DELETE *path*; synthesizes a status dict for empty responses."""
        response = await self._request("DELETE", path, headers=headers)
        if response.status_code == 204 or not response.content:
            return {"deleted": True, "status_code": response.status_code}
        return self._decode_json(response)

    async def iterate_pages(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
    ) -> AsyncIterator[dict[str, Any]]:
        """Yield each page dict, following ``@odata.nextLink`` until exhausted.

        Raises ``MicrosoftGraphClientError`` if a page body is not a dict.
        """
        next_url: str | None = self._resolve_url(path)
        next_params = dict(params or {})
        while next_url:
            response = await self._request(
                "GET",
                next_url,
                params=next_params or None,
                headers=headers,
            )
            payload = self._decode_json(response)
            if not isinstance(payload, dict):
                raise MicrosoftGraphClientError(
                    f"Expected paginated Graph response dict, got {type(payload).__name__}."
                )
            yield payload
            next_url = payload.get("@odata.nextLink")
            # nextLink already encodes the query string; do not re-send params.
            next_params = {}

    async def collect_paginated(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
    ) -> list[Any]:
        """Concatenate every page's ``value`` list into a single list."""
        items: list[Any] = []
        async for page in self.iterate_pages(path, params=params, headers=headers):
            value = page.get("value")
            # Pages without a list-valued "value" key are skipped silently.
            if isinstance(value, list):
                items.extend(value)
        return items

    async def download_to_file(
        self,
        path: str,
        destination: str | Path,
        *,
        headers: dict[str, str] | None = None,
        chunk_size: int = 65536,
    ) -> dict[str, Any]:
        """Download a Graph resource to disk, streaming the response body.

        The body is written chunk-by-chunk via ``response.aiter_bytes`` with
        the ``httpx.AsyncClient`` kept open for the duration of the iteration,
        so recordings and other large artifacts do not need to fit in memory.

        Writes to a ``<destination>.part`` temp file and atomically renames it
        on success; partial files are removed on transport errors and on
        exhausted retries.  401 responses clear the token cache before the
        retry; 429/5xx responses retry with ``Retry-After``-aware backoff.
        Returns ``{"path", "size_bytes", "content_type"}``.
        """
        url = self._resolve_url(path)
        target = Path(destination)
        target.parent.mkdir(parents=True, exist_ok=True)
        tmp_target = target.with_suffix(target.suffix + ".part")
        attempt = 0
        last_error: Exception | None = None
        while attempt <= self.max_retries:
            # Force a fresh token on retries that followed a 401.
            token = await self.token_provider.get_access_token(
                force_refresh=attempt > 0 and self._should_refresh_token(last_error)
            )
            request_headers = {
                "Authorization": f"Bearer {token}",
                "Accept": "application/json",
                "User-Agent": self.user_agent,
            }
            if headers:
                request_headers.update(headers)
            try:
                async with httpx.AsyncClient(
                    timeout=httpx.Timeout(self.timeout),
                    transport=self._transport,
                ) as client:
                    # client.stream keeps the connection open across the
                    # aiter_bytes loop below — this is what makes the
                    # download actually stream instead of buffering.
                    async with client.stream(
                        "GET",
                        url,
                        headers=request_headers,
                    ) as response:
                        if response.status_code >= 400:
                            # Materialize error body so we can surface a meaningful
                            # message; error bodies are small.
                            await response.aread()
                            api_error = self._build_api_error("GET", url, response)
                            last_error = api_error
                            if (
                                response.status_code == 401
                                and attempt < self.max_retries
                            ):
                                # Token likely expired/revoked: drop the cached
                                # token so the next attempt refreshes it.
                                self.token_provider.clear_cache()
                                await self._sleep(
                                    self._retry_delay(response, attempt)
                                )
                                attempt += 1
                                continue
                            if (
                                self._should_retry(response)
                                and attempt < self.max_retries
                            ):
                                await self._sleep(
                                    self._retry_delay(response, attempt)
                                )
                                attempt += 1
                                continue
                            # Non-retryable status, or retries exhausted.
                            raise api_error
                        content_type = response.headers.get("content-type")
                        # Stream chunks straight to the temp file.
                        with tmp_target.open("wb") as handle:
                            async for chunk in response.aiter_bytes(
                                chunk_size=chunk_size
                            ):
                                if chunk:
                                    handle.write(chunk)
            except httpx.HTTPError as exc:
                # Transport-level failure (connect/read/timeout): clean up any
                # partial file, then retry or give up.
                last_error = exc
                tmp_target.unlink(missing_ok=True)
                if attempt >= self.max_retries:
                    raise MicrosoftGraphClientError(
                        f"Microsoft Graph download failed for GET {url}: {exc}"
                    ) from exc
                await self._sleep(self._retry_delay(None, attempt))
                attempt += 1
                continue
            # Success: atomically publish the completed file.
            os.replace(tmp_target, target)
            return {
                "path": str(target),
                "size_bytes": target.stat().st_size,
                "content_type": content_type,
            }
        tmp_target.unlink(missing_ok=True)
        raise MicrosoftGraphClientError(
            f"Microsoft Graph download exhausted retries for GET {url}."
        )

    async def _request(
        self,
        method: str,
        path_or_url: str,
        *,
        params: dict[str, Any] | None = None,
        json_body: Any | None = None,
        headers: dict[str, str] | None = None,
    ) -> httpx.Response:
        """Issue one Graph request with retry/backoff; return the response.

        Buffers the full body (fine for JSON APIs); ``download_to_file``
        has its own streaming path.  Raises ``MicrosoftGraphAPIError`` on
        non-retryable or exhausted HTTP errors, ``MicrosoftGraphClientError``
        on transport failure.
        """
        url = self._resolve_url(path_or_url)
        attempt = 0
        last_error: Exception | None = None
        while attempt <= self.max_retries:
            # Force a fresh token on retries that followed a 401.
            token = await self.token_provider.get_access_token(
                force_refresh=attempt > 0 and self._should_refresh_token(last_error)
            )
            request_headers = {
                "Authorization": f"Bearer {token}",
                "Accept": "application/json",
                "User-Agent": self.user_agent,
            }
            if json_body is not None:
                request_headers["Content-Type"] = "application/json"
            if headers:
                request_headers.update(headers)
            try:
                async with httpx.AsyncClient(
                    timeout=httpx.Timeout(self.timeout),
                    transport=self._transport,
                ) as client:
                    response = await client.request(
                        method,
                        url,
                        params=params,
                        json=json_body,
                        headers=request_headers,
                    )
            except httpx.HTTPError as exc:
                last_error = exc
                if attempt >= self.max_retries:
                    raise MicrosoftGraphClientError(
                        f"Microsoft Graph request failed for {method} {url}: {exc}"
                    ) from exc
                await self._sleep(self._retry_delay(None, attempt))
                attempt += 1
                continue
            if response.status_code < 400:
                return response
            api_error = self._build_api_error(method, url, response)
            last_error = api_error
            if response.status_code == 401 and attempt < self.max_retries:
                # Drop the cached token so the next attempt refreshes it.
                self.token_provider.clear_cache()
                await self._sleep(self._retry_delay(response, attempt))
                attempt += 1
                continue
            if self._should_retry(response) and attempt < self.max_retries:
                await self._sleep(self._retry_delay(response, attempt))
                attempt += 1
                continue
            raise api_error
        raise MicrosoftGraphClientError(
            f"Microsoft Graph request exhausted retries for {method} {url}."
        )

    def _resolve_url(self, path_or_url: str) -> str:
        """Return *path_or_url* absolute URLs as-is; prefix paths with base_url."""
        if path_or_url.startswith(("http://", "https://")):
            return path_or_url
        path = path_or_url if path_or_url.startswith("/") else f"/{path_or_url}"
        return f"{self.base_url}{path}"

    @staticmethod
    def _decode_json(response: httpx.Response) -> Any:
        """Decode a JSON body, wrapping parse failures in a client error."""
        try:
            return response.json()
        except ValueError as exc:
            raise MicrosoftGraphClientError(
                "Microsoft Graph response was not valid JSON for "
                f"{response.request.method} {response.request.url}"
            ) from exc

    @staticmethod
    def _should_retry(response: httpx.Response | None) -> bool:
        """Retry on transport failure (None), 429 throttling, or any 5xx."""
        if response is None:
            return True
        return response.status_code == 429 or 500 <= response.status_code < 600

    @staticmethod
    def _should_refresh_token(error: Exception | None) -> bool:
        """True when the previous attempt failed with an HTTP 401."""
        return isinstance(error, MicrosoftGraphAPIError) and error.status_code == 401

    @staticmethod
    def _retry_delay(response: httpx.Response | None, attempt: int) -> float:
        """Delay before the next attempt, in seconds.

        Honors a numeric ``Retry-After`` header when present; a non-numeric
        value (e.g. the HTTP-date form) falls through to exponential backoff
        (0.5s doubling per attempt, capped at 8s).
        """
        if response is not None:
            retry_after = response.headers.get("Retry-After")
            if retry_after:
                try:
                    return max(0.0, float(retry_after))
                except ValueError:
                    pass
        return min(8.0, 0.5 * (2 ** attempt))

    @staticmethod
    def _build_api_error(
        method: str,
        url: str,
        response: httpx.Response,
    ) -> MicrosoftGraphAPIError:
        """Convert an error response into a ``MicrosoftGraphAPIError``.

        Prefers the Graph ``{"error": {"code", "message"}}`` envelope for the
        message, falling back to the raw body text; carries a numeric
        ``Retry-After`` through when the header parses as a float.
        """
        payload: Any = None
        message = response.text.strip() or "unknown error"
        try:
            payload = response.json()
        except ValueError:
            payload = None
        if isinstance(payload, dict):
            error = payload.get("error")
            if isinstance(error, dict):
                code = error.get("code")
                inner_message = error.get("message")
                if code and inner_message:
                    message = f"{code}: {inner_message}"
                elif inner_message:
                    message = str(inner_message)
            elif isinstance(error, str):
                message = error
        retry_after: float | None = None
        header_value = response.headers.get("Retry-After")
        if header_value:
            try:
                retry_after = float(header_value)
            except ValueError:
                retry_after = None
        return MicrosoftGraphAPIError(
            response.status_code,
            method,
            url,
            message,
            retry_after_seconds=retry_after,
            payload=payload,
        )