diff --git a/tests/tools/test_microsoft_graph_auth.py b/tests/tools/test_microsoft_graph_auth.py index b969afe47c..4c45ca2c29 100644 --- a/tests/tools/test_microsoft_graph_auth.py +++ b/tests/tools/test_microsoft_graph_auth.py @@ -2,10 +2,13 @@ from __future__ import annotations +import asyncio + import httpx import pytest from tools.microsoft_graph_auth import ( + CachedAccessToken, DEFAULT_GRAPH_SCOPE, GraphCredentials, MicrosoftGraphConfigError, @@ -66,6 +69,33 @@ class TestMicrosoftGraphTokenProvider: assert second == "token-1" assert len(calls) == 1 + async def test_concurrent_calls_share_one_token_fetch(self): + calls: list[int] = [] + + provider = MicrosoftGraphTokenProvider( + GraphCredentials("tenant", "client", "secret"), + ) + + async def _fake_fetch(): + calls.append(1) + await asyncio.sleep(0) + return CachedAccessToken( + access_token="token-1", + token_type="Bearer", + expires_at=9_999_999_999, + ) + + provider._fetch_access_token = _fake_fetch # type: ignore[method-assign] + + first, second = await asyncio.gather( + provider.get_access_token(), + provider.get_access_token(), + ) + + assert first == "token-1" + assert second == "token-1" + assert len(calls) == 1 + async def test_refreshes_when_cached_token_is_expired(self): calls: list[int] = []