diff --git a/tests/tools/test_microsoft_graph_client.py b/tests/tools/test_microsoft_graph_client.py
index e788856ff8..b0f6ba31e3 100644
--- a/tests/tools/test_microsoft_graph_client.py
+++ b/tests/tools/test_microsoft_graph_client.py
@@ -135,6 +135,151 @@ class TestMicrosoftGraphClient:
         assert result["content_type"] == "video/mp4"
         assert result["size_bytes"] == len(b"meeting-recording")
 
+    async def test_download_to_file_streams_large_payload_in_chunks(
+        self, tmp_path: Path, monkeypatch
+    ):
+        """Recordings can be hundreds of MB; verify the body is streamed.
+
+        The handler serves the payload from an async generator so that
+        nothing is buffered up front, and ``httpx.Response.aread`` is
+        wrapped to record calls: a streamed download iterates
+        ``aiter_bytes`` and never calls ``aread`` on a successful response.
+        Counting ``aiter_bytes`` chunks alone is not treated as proof,
+        because httpx appears to slice an already-buffered body into
+        ``chunk_size`` pieces too (assumption; confirm against httpx docs).
+        """
+        payload = b"x" * (512 * 1024)  # 512 KiB
+        piece_size = 16 * 1024
+
+        async def body():
+            for offset in range(0, len(payload), piece_size):
+                yield payload[offset : offset + piece_size]
+
+        def handler(request: httpx.Request) -> httpx.Response:
+            return httpx.Response(
+                200,
+                content=body(),
+                headers={"content-type": "video/mp4"},
+            )
+
+        chunk_calls: list[int] = []
+        aread_calls: list[int] = []
+        original_aiter_bytes = httpx.Response.aiter_bytes
+        original_aread = httpx.Response.aread
+
+        async def counting_aiter_bytes(self, chunk_size: int | None = None):
+            async for chunk in original_aiter_bytes(self, chunk_size):
+                chunk_calls.append(len(chunk))
+                yield chunk
+
+        async def recording_aread(self):
+            aread_calls.append(1)
+            return await original_aread(self)
+
+        monkeypatch.setattr(httpx.Response, "aiter_bytes", counting_aiter_bytes)
+        monkeypatch.setattr(httpx.Response, "aread", recording_aread)
+
+        client = MicrosoftGraphClient(
+            _make_provider(),
+            transport=httpx.MockTransport(handler),
+        )
+        destination = tmp_path / "big-recording.mp4"
+        result = await client.download_to_file(
+            "/drive/item/content", destination, chunk_size=65536
+        )
+
+        assert destination.read_bytes() == payload
+        assert result["size_bytes"] == len(payload)
+        assert not aread_calls, "Response body was buffered via aread()"
+        assert len(chunk_calls) >= 2, (
+            "Expected multiple chunks; got a single chunk "
+            f"which suggests the body was buffered: {chunk_calls}"
+        )
+        assert not (tmp_path / "big-recording.mp4.part").exists()
+
+    async def test_download_to_file_retries_on_transient_server_error(
+        self, tmp_path: Path
+    ):
+        calls: list[int] = []
+        sleeps: list[float] = []
+
+        def handler(request: httpx.Request) -> httpx.Response:
+            calls.append(1)
+            if len(calls) == 1:
+                return httpx.Response(
+                    503, json={"error": {"message": "unavailable"}}
+                )
+            return httpx.Response(
+                200,
+                content=b"payload",
+                headers={"content-type": "application/octet-stream"},
+            )
+
+        async def fake_sleep(delay: float) -> None:
+            sleeps.append(delay)
+
+        client = MicrosoftGraphClient(
+            _make_provider(),
+            transport=httpx.MockTransport(handler),
+            sleep=fake_sleep,
+            max_retries=2,
+        )
+        destination = tmp_path / "artifact.bin"
+        result = await client.download_to_file("/drive/item/content", destination)
+
+        assert destination.read_bytes() == b"payload"
+        assert result["size_bytes"] == len(b"payload")
+        assert len(calls) == 2
+        assert sleeps == [0.5]
+        assert not (tmp_path / "artifact.bin.part").exists()
+
+    async def test_download_to_file_cleans_partial_file_on_exhausted_retries(
+        self, tmp_path: Path
+    ):
+        def handler(request: httpx.Request) -> httpx.Response:
+            return httpx.Response(503, json={"error": {"message": "unavailable"}})
+
+        async def fake_sleep(delay: float) -> None:
+            return None
+
+        client = MicrosoftGraphClient(
+            _make_provider(),
+            transport=httpx.MockTransport(handler),
+            sleep=fake_sleep,
+            max_retries=1,
+        )
+        destination = tmp_path / "artifact.bin"
+
+        with pytest.raises(MicrosoftGraphAPIError):
+            await client.download_to_file("/drive/item/content", destination)
+
+        assert not destination.exists()
+        assert not (tmp_path / "artifact.bin.part").exists()
+
+    async def test_download_to_file_cleans_partial_file_on_unexpected_error(
+        self, tmp_path: Path
+    ):
+        """Non-``httpx.HTTPError`` failures must not leave a ``.part`` file."""
+
+        async def broken_body():
+            yield b"partial"
+            raise RuntimeError("stream broke")
+
+        def handler(request: httpx.Request) -> httpx.Response:
+            return httpx.Response(200, content=broken_body())
+
+        client = MicrosoftGraphClient(
+            _make_provider(),
+            transport=httpx.MockTransport(handler),
+        )
+        destination = tmp_path / "artifact.bin"
+
+        with pytest.raises(RuntimeError):
+            await client.download_to_file("/drive/item/content", destination)
+
+        assert not destination.exists()
+        assert not (tmp_path / "artifact.bin.part").exists()
+
     async def test_invalid_json_response_raises_client_error(self):
         def handler(request: httpx.Request) -> httpx.Response:
             return httpx.Response(
diff --git a/tools/microsoft_graph_client.py b/tools/microsoft_graph_client.py
index f92ba66c54..dbdf211f6e 100644
--- a/tools/microsoft_graph_client.py
+++ b/tools/microsoft_graph_client.py
@@ -160,20 +160,126 @@ class MicrosoftGraphClient:
         headers: dict[str, str] | None = None,
         chunk_size: int = 65536,
     ) -> dict[str, Any]:
-        response = await self._request("GET", path, headers=headers)
+        """Download a Graph resource to disk, streaming the response body.
+
+        The body is written chunk-by-chunk via ``response.aiter_bytes`` with
+        the ``httpx.AsyncClient`` kept open for the duration of the iteration,
+        so recordings and other large artifacts do not need to fit in memory.
+
+        Data is staged in a sibling ``.part`` file and moved onto
+        ``destination`` with ``os.replace`` only after the whole body has been
+        written. The staging file is removed on every failure path, including
+        errors that are not ``httpx.HTTPError`` (``OSError`` while writing,
+        task cancellation, ...).
+
+        A 401 response, a response accepted by ``_should_retry`` and an
+        ``httpx.HTTPError`` are each retried while ``attempt < max_retries``,
+        sleeping for ``_retry_delay`` between attempts.
+
+        Returns a dict with the keys ``path``, ``size_bytes`` and
+        ``content_type``.
+
+        Raises the error built by ``_build_api_error`` for a status >= 400
+        that is not retried, and ``MicrosoftGraphClientError`` when an
+        ``httpx.HTTPError`` occurs on the last attempt.
+        """
+        url = self._resolve_url(path)
         target = Path(destination)
         target.parent.mkdir(parents=True, exist_ok=True)
         tmp_target = target.with_suffix(target.suffix + ".part")
-        with tmp_target.open("wb") as handle:
-            async for chunk in response.aiter_bytes(chunk_size=chunk_size):
-                if chunk:
-                    handle.write(chunk)
-        os.replace(tmp_target, target)
-        return {
-            "path": str(target),
-            "size_bytes": target.stat().st_size,
-            "content_type": response.headers.get("content-type"),
-        }
+
+        attempt = 0
+        last_error: Exception | None = None
+
+        try:
+            while attempt <= self.max_retries:
+                token = await self.token_provider.get_access_token(
+                    force_refresh=attempt > 0
+                    and self._should_refresh_token(last_error)
+                )
+                request_headers = {
+                    "Authorization": f"Bearer {token}",
+                    "Accept": "application/json",
+                    "User-Agent": self.user_agent,
+                }
+                if headers:
+                    request_headers.update(headers)
+
+                # Set when this attempt must be retried. The back-off sleep
+                # runs after the ``async with`` blocks exit, so no connection
+                # is held open while waiting.
+                retry = False
+                retry_delay = 0.0
+                try:
+                    async with httpx.AsyncClient(
+                        timeout=httpx.Timeout(self.timeout),
+                        transport=self._transport,
+                    ) as client:
+                        async with client.stream(
+                            "GET",
+                            url,
+                            headers=request_headers,
+                        ) as response:
+                            if response.status_code >= 400:
+                                # Materialize error body so we can surface a
+                                # meaningful message; error bodies are small.
+                                await response.aread()
+                                api_error = self._build_api_error(
+                                    "GET", url, response
+                                )
+                                last_error = api_error
+                                can_retry = attempt < self.max_retries
+
+                                if response.status_code == 401 and can_retry:
+                                    self.token_provider.clear_cache()
+                                elif not (
+                                    self._should_retry(response) and can_retry
+                                ):
+                                    raise api_error
+                                retry = True
+                                retry_delay = self._retry_delay(response, attempt)
+                            else:
+                                content_type = response.headers.get("content-type")
+                                with tmp_target.open("wb") as handle:
+                                    async for chunk in response.aiter_bytes(
+                                        chunk_size=chunk_size
+                                    ):
+                                        if chunk:
+                                            handle.write(chunk)
+                except httpx.HTTPError as exc:
+                    last_error = exc
+                    if attempt >= self.max_retries:
+                        raise MicrosoftGraphClientError(
+                            f"Microsoft Graph download failed for GET {url}: {exc}"
+                        ) from exc
+                    # Discard partial data; the next attempt starts over.
+                    tmp_target.unlink(missing_ok=True)
+                    retry = True
+                    retry_delay = self._retry_delay(None, attempt)
+
+                if retry:
+                    await self._sleep(retry_delay)
+                    attempt += 1
+                    continue
+
+                os.replace(tmp_target, target)
+                return {
+                    "path": str(target),
+                    "size_bytes": target.stat().st_size,
+                    "content_type": content_type,
+                }
+        finally:
+            # No-op once ``os.replace`` has moved the staging file; otherwise
+            # removes whatever was partially written, whatever the failure.
+            tmp_target.unlink(missing_ok=True)
+
+        raise MicrosoftGraphClientError(
+            f"Microsoft Graph download exhausted retries for GET {url}."
+        )
 
     async def _request(
         self,