diff --git a/tests/tools/test_mcp_oauth_cold_load_expiry.py b/tests/tools/test_mcp_oauth_cold_load_expiry.py new file mode 100644 index 000000000..a9fb19106 --- /dev/null +++ b/tests/tools/test_mcp_oauth_cold_load_expiry.py @@ -0,0 +1,546 @@ +"""Tests for cold-load token expiry tracking in MCP OAuth. + +PR #11383's consolidation fixed external-refresh reloading (mtime disk-watch) +and 401 dedup, but left two underlying latent bugs in place: + +1. ``HermesTokenStorage.set_tokens`` persisted only relative ``expires_in``, + which is meaningless after a process restart. +2. The MCP SDK's ``OAuthContext._initialize`` loads ``current_tokens`` from + storage but does NOT call ``update_token_expiry``, so + ``token_expiry_time`` stays None. ``is_token_valid()`` then returns True + for any loaded token regardless of actual age, and the SDK's preemptive + refresh branch at ``oauth2.py:491`` is never taken. + +Consequence: a token that expired while the process was down ships to the +server with a stale Bearer header. The server's response is provider-specific +— some return HTTP 401 (caught by the consolidation's 401 handler, which +surfaces a ``needs_reauth`` error), others return HTTP 200 with an +application-level auth failure in the body (e.g. BetterStack's "No teams +found. Please check your authentication."), which the consolidation cannot +detect. + +These tests pin the contract for Fix A: +- ``set_tokens`` persists an absolute ``expires_at`` wall-clock timestamp. +- ``get_tokens`` reconstructs ``expires_in`` from ``expires_at - now`` so + the SDK's ``update_token_expiry`` computes the correct absolute expiry. +- ``HermesMCPOAuthProvider._initialize`` seeds ``context.token_expiry_time`` + after loading, so ``is_token_valid()`` reports True only for tokens that + are actually still valid, and the SDK's preemptive refresh fires for + expired tokens with a live refresh_token. + +Reference: Claude Code solves this via an ``OAuthTokens.expiresAt`` absolute +timestamp persisted alongside the access_token (``auth.ts:~180``). +""" +from __future__ import annotations + +import asyncio +import json +import time + +import pytest + + +pytest.importorskip("mcp.client.auth.oauth2", reason="MCP SDK 1.26.0+ required") + + +# --------------------------------------------------------------------------- +# HermesTokenStorage — absolute expiry persistence +# --------------------------------------------------------------------------- + + +class TestSetTokensAbsoluteExpiry: + def test_set_tokens_persists_absolute_expires_at(self, tmp_path, monkeypatch): + """Tokens round-tripped through disk must encode absolute expiry.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthToken + + from tools.mcp_oauth import HermesTokenStorage + + storage = HermesTokenStorage("srv") + before = time.time() + asyncio.run( + storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + expires_in=3600, + refresh_token="r", + ) + ) + ) + after = time.time() + + on_disk = json.loads( + (tmp_path / "mcp-tokens" / "srv.json").read_text() + ) + assert "expires_at" in on_disk, ( + "Fix A: set_tokens must record an absolute expires_at wall-clock " + "timestamp alongside the SDK's serialized token so cold-loads " + "can compute correct remaining TTL." + ) + assert before + 3600 <= on_disk["expires_at"] <= after + 3600 + + def test_set_tokens_without_expires_in_omits_expires_at( + self, tmp_path, monkeypatch + ): + """Tokens without a TTL must not gain a fabricated expires_at.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthToken + + from tools.mcp_oauth import HermesTokenStorage + + storage = HermesTokenStorage("srv") + asyncio.run( + storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + refresh_token="r", + ) + ) + ) + + on_disk = json.loads( + (tmp_path / "mcp-tokens" / "srv.json").read_text() + ) + assert "expires_at" not in on_disk + + +class TestGetTokensReconstructsExpiresIn: + def test_get_tokens_uses_expires_at_for_remaining_ttl( + self, tmp_path, monkeypatch + ): + """Round-trip: expires_in on read must reflect time remaining.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthToken + + from tools.mcp_oauth import HermesTokenStorage + + storage = HermesTokenStorage("srv") + asyncio.run( + storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + expires_in=3600, + refresh_token="r", + ) + ) + ) + + # Wait briefly so the remaining TTL is measurably less than 3600. + time.sleep(0.05) + + reloaded = asyncio.run(storage.get_tokens()) + assert reloaded is not None + assert reloaded.expires_in is not None + # Should be slightly less than 3600 after the 50ms sleep. + assert 3500 < reloaded.expires_in <= 3600 + + def test_get_tokens_returns_zero_ttl_for_expired_token( + self, tmp_path, monkeypatch + ): + """An already-expired token reloaded from disk must report expires_in=0.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth import HermesTokenStorage, _get_token_dir + + token_dir = _get_token_dir() + token_dir.mkdir(parents=True, exist_ok=True) + # Write an already-expired token file directly. + (token_dir / "srv.json").write_text( + json.dumps( + { + "access_token": "a", + "token_type": "Bearer", + "expires_in": 3600, + "expires_at": time.time() - 60, # expired 1 min ago + "refresh_token": "r", + } + ) + ) + + storage = HermesTokenStorage("srv") + reloaded = asyncio.run(storage.get_tokens()) + assert reloaded is not None + assert reloaded.expires_in == 0, ( + "Expired token must reload with expires_in=0 so the SDK's " + "is_token_valid() returns False and preemptive refresh fires." + ) + + def test_get_tokens_legacy_file_without_expires_at_is_loadable( + self, tmp_path, monkeypatch + ): + """Existing on-disk files (pre-Fix-A) must still load without crashing. + + Pre-existing token files have ``expires_in`` but no ``expires_at``. + Fix A falls back to the file's mtime as a best-effort wall-clock + proxy: a file whose (mtime + expires_in) is in the past clamps + expires_in to zero so the SDK refreshes on next request. A fresh + legacy-format file (mtime = now) keeps most of its TTL. + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth import HermesTokenStorage, _get_token_dir + + token_dir = _get_token_dir() + token_dir.mkdir(parents=True, exist_ok=True) + # Legacy-shape file (no expires_at). Make it stale by backdating mtime + # well past its nominal expires_in. + legacy_path = token_dir / "srv.json" + legacy_path.write_text( + json.dumps( + { + "access_token": "a", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "r", + } + ) + ) + stale_time = time.time() - 7200 # 2hr ago, exceeds 3600s TTL + import os + + os.utime(legacy_path, (stale_time, stale_time)) + + storage = HermesTokenStorage("srv") + reloaded = asyncio.run(storage.get_tokens()) + assert reloaded is not None + assert reloaded.expires_in == 0, ( + "Legacy file whose mtime + expires_in is in the past must report " + "expires_in=0 so the SDK refreshes on next request." + ) + + +# --------------------------------------------------------------------------- +# HermesMCPOAuthProvider._initialize — seed token_expiry_time +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_initialize_seeds_token_expiry_time_from_stored_tokens( + tmp_path, monkeypatch +): + """Cold-load must populate context.token_expiry_time. + + The SDK's base ``_initialize`` loads current_tokens but doesn't seed + token_expiry_time. Our subclass must do it so ``is_token_valid()`` + reports correctly and the preemptive-refresh path fires when needed. + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + from pydantic import AnyUrl + + from tools.mcp_oauth import HermesTokenStorage + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + + assert _HERMES_PROVIDER_CLS is not None + reset_manager_for_tests() + + storage = HermesTokenStorage("srv") + await storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + expires_in=7200, + refresh_token="r", + ) + ) + await storage.set_client_info( + OAuthClientInformationFull( + client_id="test-client", + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + ) + + from mcp.shared.auth import OAuthClientMetadata + + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://example.com/mcp", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + await provider._initialize() + + assert provider.context.token_expiry_time is not None, ( + "Fix A: _initialize must seed context.token_expiry_time so " + "is_token_valid() correctly reports expiry on cold-load." + ) + # Should be ~7200s in the future (fresh write). + assert provider.context.token_expiry_time > time.time() + 7000 + assert provider.context.token_expiry_time <= time.time() + 7200 + 5 + + +@pytest.mark.asyncio +async def test_initialize_flags_expired_token_as_invalid(tmp_path, monkeypatch): + """After _initialize, an expired-on-disk token must report is_token_valid=False. + + This is the end-to-end assertion: cold-load an expired token, verify the + SDK's own ``is_token_valid()`` now returns False (the consequence of + seeding token_expiry_time correctly), so the SDK's ``async_auth_flow`` + will take the ``can_refresh_token()`` branch on the next request and + silently refresh instead of sending the stale Bearer. + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata + from pydantic import AnyUrl + + from tools.mcp_oauth import HermesTokenStorage, _get_token_dir + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + + assert _HERMES_PROVIDER_CLS is not None + reset_manager_for_tests() + + # Write an already-expired token directly so we control the wall-clock. + token_dir = _get_token_dir() + token_dir.mkdir(parents=True, exist_ok=True) + (token_dir / "srv.json").write_text( + json.dumps( + { + "access_token": "stale", + "token_type": "Bearer", + "expires_in": 3600, + "expires_at": time.time() - 60, + "refresh_token": "fresh", + } + ) + ) + + storage = HermesTokenStorage("srv") + await storage.set_client_info( + OAuthClientInformationFull( + client_id="test-client", + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + ) + + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://example.com/mcp", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + await provider._initialize() + + assert provider.context.is_token_valid() is False, ( + "After _initialize with an expired-on-disk token, is_token_valid() " + "must return False so the SDK's async_auth_flow takes the " + "preemptive refresh path." + ) + assert provider.context.can_refresh_token() is True, ( + "Refresh should remain possible because refresh_token + client_info " + "are both present." + ) + + +async def _noop_redirect(_url: str) -> None: + return None + + +async def _noop_callback() -> tuple[str, str | None]: + raise AssertionError("callback handler should not be invoked in these tests") + + +# --------------------------------------------------------------------------- +# Pre-flight OAuth metadata discovery +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_initialize_prefetches_oauth_metadata_when_missing( + tmp_path, monkeypatch +): + """Cold-load must pre-flight PRM + ASM discovery so ``_refresh_token`` + has the correct ``token_endpoint`` before the first refresh attempt. + + Without this, the SDK's ``_refresh_token`` falls back to + ``{server_url}/token`` which is wrong for providers whose AS is at + a different origin. BetterStack specifically: MCP at + ``mcp.betterstack.com`` but token_endpoint at + ``betterstack.com/oauth/token``. Without pre-flight the refresh 404s + and we drop into full browser re-auth — visible to the user as an + unwanted OAuth browser prompt every time the process restarts. + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + import httpx + from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthToken, + ) + from pydantic import AnyUrl + + from tools.mcp_oauth import HermesTokenStorage + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + + assert _HERMES_PROVIDER_CLS is not None + reset_manager_for_tests() + + storage = HermesTokenStorage("srv") + await storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + expires_in=3600, + refresh_token="r", + ) + ) + await storage.set_client_info( + OAuthClientInformationFull( + client_id="test-client", + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + ) + + # Route the AsyncClient used inside _prefetch_oauth_metadata through a + # MockTransport that mimics BetterStack's split-origin discovery: + # PRM at mcp.example.com/.well-known/oauth-protected-resource -> points to auth.example.com + # ASM at auth.example.com/.well-known/oauth-authorization-server -> token_endpoint at auth.example.com/oauth/token + def mock_handler(request: httpx.Request) -> httpx.Response: + url = str(request.url) + if url.endswith("/.well-known/oauth-protected-resource"): + return httpx.Response( + 200, + json={ + "resource": "https://mcp.example.com", + "authorization_servers": ["https://auth.example.com"], + "scopes_supported": ["read", "write"], + "bearer_methods_supported": ["header"], + }, + ) + if url.endswith("/.well-known/oauth-authorization-server"): + return httpx.Response( + 200, + json={ + "issuer": "https://auth.example.com", + "authorization_endpoint": "https://auth.example.com/oauth/authorize", + "token_endpoint": "https://auth.example.com/oauth/token", + "registration_endpoint": "https://auth.example.com/oauth/register", + "response_types_supported": ["code"], + "grant_types_supported": ["authorization_code", "refresh_token"], + "code_challenge_methods_supported": ["S256"], + "token_endpoint_auth_methods_supported": ["none"], + "scopes_supported": ["read", "write"], + }, + ) + return httpx.Response(404) + + transport = httpx.MockTransport(mock_handler) + + # Patch the AsyncClient constructor used by _prefetch_oauth_metadata so + # it uses our mock transport instead of the real network. + import httpx as real_httpx + + original_async_client = real_httpx.AsyncClient + + def patched_async_client(*args, **kwargs): + kwargs["transport"] = transport + return original_async_client(*args, **kwargs) + + monkeypatch.setattr(real_httpx, "AsyncClient", patched_async_client) + + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://mcp.example.com", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + await provider._initialize() + + assert provider.context.protected_resource_metadata is not None, ( + "Pre-flight must cache PRM for the SDK to reference later." + ) + assert provider.context.oauth_metadata is not None, ( + "Pre-flight must cache ASM so _refresh_token builds the correct " + "token_endpoint URL." + ) + assert str(provider.context.oauth_metadata.token_endpoint) == ( + "https://auth.example.com/oauth/token" + ) + + +@pytest.mark.asyncio +async def test_initialize_skips_prefetch_when_no_tokens(tmp_path, monkeypatch): + """Pre-flight must not run when there are no stored tokens yet. + + Without this guard, every fresh-install ``_initialize`` would do two + extra network roundtrips that gain nothing (the SDK's 401-branch + discovery will run on the first real request anyway). + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + import httpx + from mcp.shared.auth import OAuthClientMetadata + from pydantic import AnyUrl + + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + from tools.mcp_oauth import HermesTokenStorage + + assert _HERMES_PROVIDER_CLS is not None + reset_manager_for_tests() + + calls: list[str] = [] + + def mock_handler(request: httpx.Request) -> httpx.Response: + calls.append(str(request.url)) + return httpx.Response(404) + + transport = httpx.MockTransport(mock_handler) + import httpx as real_httpx + + original = real_httpx.AsyncClient + + def patched(*args, **kwargs): + kwargs["transport"] = transport + return original(*args, **kwargs) + + monkeypatch.setattr(real_httpx, "AsyncClient", patched) + + storage = HermesTokenStorage("srv") # empty — no tokens on disk + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://mcp.example.com", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + await provider._initialize() + + assert calls == [], ( + f"Pre-flight must not fire when no tokens are stored, but got {calls}" + ) diff --git a/tools/mcp_oauth.py b/tools/mcp_oauth.py index 6e1d7f5fb..7910c3cdc 100644 --- a/tools/mcp_oauth.py +++ b/tools/mcp_oauth.py @@ -40,6 +40,7 @@ import re import socket import sys import threading +import time import webbrowser from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path @@ -196,6 +197,35 @@ class HermesTokenStorage: data = _read_json(self._tokens_path()) if data is None: return None + # Hermes records an absolute wall-clock ``expires_at`` alongside the + # SDK's serialized token (see ``set_tokens``). On read we rewrite + # ``expires_in`` to the remaining seconds so the SDK's downstream + # ``update_token_expiry`` computes the correct absolute time and + # ``is_token_valid()`` correctly reports False for tokens that + # expired while the process was down. + # + # Legacy token files (pre-Fix-A) have ``expires_in`` but no + # ``expires_at``. We fall back to the file's mtime as a best-effort + # wall-clock proxy for when the token was written: if (mtime + + # expires_in) is in the past, clamp ``expires_in`` to zero so the + # SDK refreshes before the first request. This self-heals one-time + # on the next successful ``set_tokens``, which writes the new + # ``expires_at`` field. The stored ``expires_at`` is stripped before + # model_validate because it's not part of the SDK's OAuthToken schema. + absolute_expiry = data.pop("expires_at", None) + if absolute_expiry is not None: + data["expires_in"] = int(max(absolute_expiry - time.time(), 0)) + elif data.get("expires_in") is not None: + try: + file_mtime = self._tokens_path().stat().st_mtime + except OSError: + file_mtime = None + if file_mtime is not None: + try: + implied_expiry = file_mtime + int(data["expires_in"]) + data["expires_in"] = int(max(implied_expiry - time.time(), 0)) + except (TypeError, ValueError): + pass try: return OAuthToken.model_validate(data) except (ValueError, TypeError, KeyError) as exc: @@ -203,7 +233,23 @@ class HermesTokenStorage: return None async def set_tokens(self, tokens: "OAuthToken") -> None: - _write_json(self._tokens_path(), tokens.model_dump(exclude_none=True)) + payload = tokens.model_dump(exclude_none=True) + # Persist an absolute ``expires_at`` so a process restart can + # reconstruct the correct remaining TTL. Without this the MCP SDK's + # ``_initialize`` reloads a relative ``expires_in`` which has no + # wall-clock reference, leaving ``context.token_expiry_time=None`` + # and ``is_token_valid()`` falsely reporting True. See Fix A in + # ``mcp-oauth-token-diagnosis`` skill + Claude Code's + # ``OAuthTokens.expiresAt`` persistence (auth.ts ~180). + expires_in = payload.get("expires_in") + if expires_in is not None: + try: + payload["expires_at"] = time.time() + int(expires_in) + except (TypeError, ValueError): + # Mock tokens or unusual shapes: skip the expires_at write + # rather than fail persistence. + pass + _write_json(self._tokens_path(), payload) logger.debug("OAuth tokens saved for %s", self._server_name) # -- client info ------------------------------------------------------- diff --git a/tools/mcp_oauth_manager.py b/tools/mcp_oauth_manager.py index 29cd79b78..7c8a91f3f 100644 --- a/tools/mcp_oauth_manager.py +++ b/tools/mcp_oauth_manager.py @@ -111,6 +111,131 @@ def _make_hermes_provider_class() -> Optional[type]: super().__init__(*args, **kwargs) self._hermes_server_name = server_name + async def _initialize(self) -> None: + """Load stored tokens + client info AND seed token_expiry_time. + + Also eagerly fetches OAuth authorization-server metadata (PRM + + ASM) when we have stored tokens but no cached metadata, so the + SDK's ``_refresh_token`` can build the correct token_endpoint + URL on the preemptive-refresh path. Without this, the SDK + falls back to ``{mcp_server_url}/token`` (wrong for providers + whose AS is a different origin — BetterStack's MCP lives at + ``https://mcp.betterstack.com`` but its token endpoint is at + ``https://betterstack.com/oauth/token``), the refresh 404s, and + we drop through to full browser reauth. + + The SDK's base ``_initialize`` populates ``current_tokens`` but + does NOT call ``update_token_expiry``, so ``token_expiry_time`` + stays ``None`` and ``is_token_valid()`` returns True for any + loaded token regardless of actual age. After a process restart + this ships stale Bearer tokens to the server; some providers + return HTTP 401 (caught by the 401 handler), others return 200 + with an app-level auth error (invisible to the transport layer, + e.g. BetterStack returning "No teams found. Please check your + authentication."). + + Seeding ``token_expiry_time`` from the reloaded token fixes that: + ``is_token_valid()`` correctly reports False for expired tokens, + ``async_auth_flow`` takes the ``can_refresh_token()`` branch, + and the SDK quietly refreshes before the first real request. + + Paired with :class:`HermesTokenStorage` persisting an absolute + ``expires_at`` timestamp (``mcp_oauth.py:set_tokens``) so the + remaining TTL we compute here reflects real wall-clock age. + """ + await super()._initialize() + tokens = self.context.current_tokens + if tokens is not None and tokens.expires_in is not None: + self.context.update_token_expiry(tokens) + + # Pre-flight OAuth AS discovery so ``_refresh_token`` has a + # correct ``token_endpoint`` before the first refresh attempt. + # Only runs when we have tokens on cold-load but no cached + # metadata — i.e. the exact scenario where the SDK's built-in + # 401-branch discovery hasn't had a chance to run yet. + if ( + tokens is not None + and self.context.oauth_metadata is None + ): + try: + await self._prefetch_oauth_metadata() + except Exception as exc: # pragma: no cover — defensive + # Non-fatal: if discovery fails, the SDK's normal 401- + # branch discovery will run on the next request. + logger.debug( + "MCP OAuth '%s': pre-flight metadata discovery " + "failed (non-fatal): %s", + self._hermes_server_name, exc, + ) + + async def _prefetch_oauth_metadata(self) -> None: + """Fetch PRM + ASM from the well-known endpoints, cache on context. + + Mirrors the SDK's 401-branch discovery (oauth2.py ~line 511-551) + but runs synchronously before the first request instead of + inside the httpx auth_flow generator. Uses the SDK's own URL + builders and response handlers so we track whatever the SDK + version we're pinned to expects. + """ + import httpx # local import: httpx is an MCP SDK dependency + from mcp.client.auth.utils import ( + build_oauth_authorization_server_metadata_discovery_urls, + build_protected_resource_metadata_discovery_urls, + create_oauth_metadata_request, + handle_auth_metadata_response, + handle_protected_resource_response, + ) + + server_url = self.context.server_url + async with httpx.AsyncClient(timeout=10.0) as client: + # Step 1: PRM discovery to learn the authorization_server URL. + for url in build_protected_resource_metadata_discovery_urls( + None, server_url + ): + req = create_oauth_metadata_request(url) + try: + resp = await client.send(req) + except httpx.HTTPError as exc: + logger.debug( + "MCP OAuth '%s': PRM discovery to %s failed: %s", + self._hermes_server_name, url, exc, + ) + continue + prm = await handle_protected_resource_response(resp) + if prm: + self.context.protected_resource_metadata = prm + if prm.authorization_servers: + self.context.auth_server_url = str( + prm.authorization_servers[0] + ) + break + + # Step 2: ASM discovery against the auth_server_url (or + # server_url fallback for legacy providers). + for url in build_oauth_authorization_server_metadata_discovery_urls( + self.context.auth_server_url, server_url + ): + req = create_oauth_metadata_request(url) + try: + resp = await client.send(req) + except httpx.HTTPError as exc: + logger.debug( + "MCP OAuth '%s': ASM discovery to %s failed: %s", + self._hermes_server_name, url, exc, + ) + continue + ok, asm = await handle_auth_metadata_response(resp) + if not ok: + break + if asm: + self.context.oauth_metadata = asm + logger.debug( + "MCP OAuth '%s': pre-flight ASM discovered " + "token_endpoint=%s", + self._hermes_server_name, asm.token_endpoint, + ) + break + async def async_auth_flow(self, request): # type: ignore[override] # Pre-flow hook: ask the manager to refresh from disk if needed. # Any failure here is non-fatal — we just log and proceed with