From a3a49324052c262c44954b229b781e9703ae5845 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sun, 19 Apr 2026 16:31:07 -0700 Subject: [PATCH] fix(mcp-oauth): bidirectional auth_flow bridge + absolute expires_at (salvage #12025) (#12717) * [verified] fix(mcp-oauth): bridge httpx auth_flow bidirectional generator HermesMCPOAuthProvider.async_auth_flow wrapped the SDK's auth_flow with 'async for item in super().async_auth_flow(request): yield item', which discards httpx's .asend(response) values and resumes the inner generator with None. This broke every OAuth MCP server on the first HTTP response with 'NoneType' object has no attribute 'status_code' crashing at mcp/client/auth/oauth2.py:505. Replace with a manual bridge that forwards .asend() values into the inner generator, preserving httpx's bidirectional auth_flow contract. Add tests/tools/test_mcp_oauth_bidirectional.py with two regression tests that drive the flow through real .asend() round-trips. These catch the bug at the unit level; prior tests only exercised _initialize() and disk-watching, never the full generator protocol. Verified against BetterStack MCP: Before: 'Connection failed (11564ms): NoneType...' after 3 retries After: 'Connected (2416ms); Tools discovered: 83' Regression from #11383. * [verified] fix(mcp-oauth): seed token_expiry_time + pre-flight AS discovery on cold-load PR #11383's consolidation fixed external-refresh reloading and 401 dedup but left two latent bugs that surfaced on BetterStack and any other OAuth MCP with a split-origin authorization server: 1. HermesTokenStorage persisted only a relative 'expires_in', which is meaningless after a process restart. The MCP SDK's OAuthContext does NOT seed token_expiry_time in _initialize, so is_token_valid() returned True for any reloaded token regardless of age. Expired tokens shipped to servers, and app-level auth failures (e.g. BetterStack's 'No teams found. Please check your authentication.') were invisible to the transport-layer 401 handler. 2. Even once preemptive refresh did fire, the SDK's _refresh_token falls back to {server_url}/token when oauth_metadata isn't cached. For providers whose AS is at a different origin (BetterStack: mcp.betterstack.com for MCP, betterstack.com/oauth/token for the token endpoint), that fallback 404s and drops into full browser re-auth on every process restart. Fix set: - HermesTokenStorage.set_tokens persists an absolute wall-clock expires_at alongside the SDK's OAuthToken JSON (time.time() + TTL at write time). - HermesTokenStorage.get_tokens reconstructs expires_in from max(expires_at - now, 0), clamping expired tokens to zero TTL. Legacy files without expires_at fall back to file-mtime as a best-effort wall-clock proxy, self-healing on the next set_tokens. - HermesMCPOAuthProvider._initialize calls super(), then update_token_expiry on the reloaded tokens so token_expiry_time reflects actual remaining TTL. If tokens are loaded but oauth_metadata is missing, pre-flight PRM + ASM discovery runs via httpx.AsyncClient using the MCP SDK's own URL builders and response handlers (build_protected_resource_metadata_discovery_urls, handle_auth_metadata_response, etc.) so the SDK sees the correct token_endpoint before the first refresh attempt. Pre-flight is skipped when there are no stored tokens to keep fresh-install paths zero-cost. Test coverage (tests/tools/test_mcp_oauth_cold_load_expiry.py): - set_tokens persists absolute expires_at - set_tokens skips expires_at when token has no expires_in - get_tokens round-trips expires_at -> remaining expires_in - expired tokens reload with expires_in=0 - legacy files without expires_at fall back to mtime proxy - _initialize seeds token_expiry_time from stored tokens - _initialize flags expired-on-disk tokens as is_token_valid=False - _initialize pre-flights PRM + ASM discovery with mock transport - _initialize skips pre-flight when no tokens are stored Verified against BetterStack MCP: hermes mcp test betterstack -> Connected (2508ms), 83 tools mcp_betterstack_telemetry_list_teams_tool -> real team data, not 'No teams found. Please check your authentication.' Reference: mcp-oauth-token-diagnosis skill, Fix A. * chore: map hermes@noushq.ai to benbarclay in AUTHOR_MAP Needed for CI attribution check on cherry-picked commits from PR #12025. --------- Co-authored-by: Hermes Agent --- scripts/release.py | 1 + tests/tools/test_mcp_oauth_bidirectional.py | 210 +++++++ .../tools/test_mcp_oauth_cold_load_expiry.py | 546 ++++++++++++++++++ tools/mcp_oauth.py | 48 +- tools/mcp_oauth_manager.py | 150 ++++- 5 files changed, 951 insertions(+), 4 deletions(-) create mode 100644 tests/tools/test_mcp_oauth_bidirectional.py create mode 100644 tests/tools/test_mcp_oauth_cold_load_expiry.py diff --git a/scripts/release.py b/scripts/release.py index 3f7930e77..ca41ef93c 100755 --- a/scripts/release.py +++ b/scripts/release.py @@ -132,6 +132,7 @@ AUTHOR_MAP = { "bryan@intertwinesys.com": "bryanyoung", "christo.mitov@gmail.com": "christomitov", "hermes@nousresearch.com": "NousResearch", + "hermes@noushq.ai": "benbarclay", "chinmingcock@gmail.com": "ChimingLiu", "openclaw@sparklab.ai": "openclaw", "semihcvlk53@gmail.com": "Himess", diff --git a/tests/tools/test_mcp_oauth_bidirectional.py b/tests/tools/test_mcp_oauth_bidirectional.py new file mode 100644 index 000000000..37ca409bb --- /dev/null +++ b/tests/tools/test_mcp_oauth_bidirectional.py @@ -0,0 +1,210 @@ +"""Regression test for the ``HermesMCPOAuthProvider.async_auth_flow`` bidirectional +generator bridge. + +PR #11383 introduced a subclass method that wrapped the SDK's ``auth_flow`` with:: + + async for item in super().async_auth_flow(request): + yield item + +``httpx``'s auth_flow contract is a **bidirectional** async generator — the +driving code (``httpx._client._send_handling_auth``) does:: + + next_request = await auth_flow.asend(response) + +to feed HTTP responses back into the generator. The naive ``async for ...`` +wrapper discards those ``.asend(response)`` values and resumes the inner +generator with ``None``, so the SDK's ``response = yield request`` branch in +``mcp/client/auth/oauth2.py`` sees ``response = None`` and crashes at +``if response.status_code == 401`` with +``AttributeError: 'NoneType' object has no attribute 'status_code'``. + +This broke every OAuth MCP server on the first HTTP response regardless of +status code. The reason nothing caught it in CI: zero existing tests drive +the full ``.asend()`` round-trip — the integration tests in +``test_mcp_oauth_integration.py`` stop at ``_initialize()`` and disk-watching. + +These tests drive the wrapper through a manual ``.asend()`` sequence to prove +the bridge forwards responses correctly into the inner SDK generator. +""" +from __future__ import annotations + +import pytest + + +pytest.importorskip("mcp.client.auth.oauth2", reason="MCP SDK 1.26.0+ required") + + +@pytest.mark.asyncio +async def test_hermes_provider_forwards_asend_values(tmp_path, monkeypatch): + """The wrapper MUST forward ``.asend(response)`` into the inner generator. + + This is the primary regression test. With the broken wrapper, the inner + SDK generator sees ``response = None`` and raises ``AttributeError`` at + ``oauth2.py:505``. With the correct bridge, a 200 response finishes the + flow cleanly (``StopAsyncIteration``). + """ + import httpx + from mcp.shared.auth import OAuthClientMetadata, OAuthToken + from pydantic import AnyUrl + + from tools.mcp_oauth import HermesTokenStorage + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + + assert _HERMES_PROVIDER_CLS is not None, "SDK OAuth types must be available" + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + reset_manager_for_tests() + + # Seed a valid-looking token so the SDK's _initialize loads something and + # can_refresh_token() is True (though we don't exercise refresh here — we + # go straight through the 200 path). + storage = HermesTokenStorage("srv") + await storage.set_tokens( + OAuthToken( + access_token="old_access", + token_type="Bearer", + expires_in=3600, + refresh_token="old_refresh", + ) + ) + # Also seed client_info so the SDK doesn't attempt registration. + from mcp.shared.auth import OAuthClientInformationFull + + await storage.set_client_info( + OAuthClientInformationFull( + client_id="test-client", + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + ) + + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://example.com/mcp", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + req = httpx.Request("POST", "https://example.com/mcp") + flow = provider.async_auth_flow(req) + + # First anext() drives the wrapper + inner generator until the inner + # yields the outbound request (at oauth2.py:503 ``response = yield request``). + outbound = await flow.__anext__() + assert outbound is not None, "wrapper must yield the outbound request" + assert outbound.url.host == "example.com" + + # Simulate httpx returning a 200 response. + fake_response = httpx.Response(200, request=outbound) + + # The broken wrapper would crash here with AttributeError: 'NoneType' + # object has no attribute 'status_code', because the SDK's inner generator + # resumes with response=None and dereferences .status_code at line 505. + # + # The correct wrapper forwards the response, the SDK takes the non-401 + # non-403 exit, and the generator ends cleanly (StopAsyncIteration). + with pytest.raises(StopAsyncIteration): + await flow.asend(fake_response) + + +@pytest.mark.asyncio +async def test_hermes_provider_forwards_401_triggers_refresh(tmp_path, monkeypatch): + """A 401 response MUST flow into the inner generator and trigger the + SDK's 401 recovery branch. + + With the broken wrapper, the inner generator sees ``response = None`` + and the 401 check short-circuits into AttributeError. With the correct + bridge, the 401 is routed into the SDK's ``response.status_code == 401`` + branch which begins discovery (yielding a metadata-discovery request). + """ + import httpx + from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken + from pydantic import AnyUrl + + from tools.mcp_oauth import HermesTokenStorage + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + + assert _HERMES_PROVIDER_CLS is not None + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + reset_manager_for_tests() + + storage = HermesTokenStorage("srv") + await storage.set_tokens( + OAuthToken( + access_token="old_access", + token_type="Bearer", + expires_in=3600, + refresh_token="old_refresh", + ) + ) + await storage.set_client_info( + OAuthClientInformationFull( + client_id="test-client", + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + ) + + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://example.com/mcp", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + req = httpx.Request("POST", "https://example.com/mcp") + flow = provider.async_auth_flow(req) + + # Drive to the first yield (outbound MCP request). + outbound = await flow.__anext__() + + # Reply with a 401 including a minimal WWW-Authenticate so the SDK's + # 401 branch can parse resource metadata from it. We just need something + # the SDK accepts before it tries to yield the metadata-discovery request. + fake_401 = httpx.Response( + 401, + request=outbound, + headers={"www-authenticate": 'Bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource"'}, + ) + + # The correct bridge forwards the 401 into the SDK; the SDK then yields + # its NEXT request (a metadata-discovery GET). We assert we get a request + # back — any request. The broken bridge would have crashed with + # AttributeError before we ever reach this point. + next_request = await flow.asend(fake_401) + assert isinstance(next_request, httpx.Request), ( + "wrapper must forward .asend() so the SDK's 401 branch can yield the " + "next request in the discovery flow" + ) + + # Clean up the generator — we don't need to complete the full dance. + await flow.aclose() + + +async def _noop_redirect(_url: str) -> None: + """Redirect handler that does nothing (won't be invoked in these tests).""" + return None + + +async def _noop_callback() -> tuple[str, str | None]: + """Callback handler that won't be invoked in these tests.""" + raise AssertionError( + "callback handler should not be invoked in bidirectional-generator tests" + ) diff --git a/tests/tools/test_mcp_oauth_cold_load_expiry.py b/tests/tools/test_mcp_oauth_cold_load_expiry.py new file mode 100644 index 000000000..a9fb19106 --- /dev/null +++ b/tests/tools/test_mcp_oauth_cold_load_expiry.py @@ -0,0 +1,546 @@ +"""Tests for cold-load token expiry tracking in MCP OAuth. + +PR #11383's consolidation fixed external-refresh reloading (mtime disk-watch) +and 401 dedup, but left two underlying latent bugs in place: + +1. ``HermesTokenStorage.set_tokens`` persisted only relative ``expires_in``, + which is meaningless after a process restart. +2. The MCP SDK's ``OAuthContext._initialize`` loads ``current_tokens`` from + storage but does NOT call ``update_token_expiry``, so + ``token_expiry_time`` stays None. ``is_token_valid()`` then returns True + for any loaded token regardless of actual age, and the SDK's preemptive + refresh branch at ``oauth2.py:491`` is never taken. + +Consequence: a token that expired while the process was down ships to the +server with a stale Bearer header. The server's response is provider-specific +— some return HTTP 401 (caught by the consolidation's 401 handler, which +surfaces a ``needs_reauth`` error), others return HTTP 200 with an +application-level auth failure in the body (e.g. BetterStack's "No teams +found. Please check your authentication."), which the consolidation cannot +detect. + +These tests pin the contract for Fix A: +- ``set_tokens`` persists an absolute ``expires_at`` wall-clock timestamp. +- ``get_tokens`` reconstructs ``expires_in`` from ``expires_at - now`` so + the SDK's ``update_token_expiry`` computes the correct absolute expiry. +- ``HermesMCPOAuthProvider._initialize`` seeds ``context.token_expiry_time`` + after loading, so ``is_token_valid()`` reports True only for tokens that + are actually still valid, and the SDK's preemptive refresh fires for + expired tokens with a live refresh_token. + +Reference: Claude Code solves this via an ``OAuthTokens.expiresAt`` absolute +timestamp persisted alongside the access_token (``auth.ts:~180``). +""" +from __future__ import annotations + +import asyncio +import json +import time + +import pytest + + +pytest.importorskip("mcp.client.auth.oauth2", reason="MCP SDK 1.26.0+ required") + + +# --------------------------------------------------------------------------- +# HermesTokenStorage — absolute expiry persistence +# --------------------------------------------------------------------------- + + +class TestSetTokensAbsoluteExpiry: + def test_set_tokens_persists_absolute_expires_at(self, tmp_path, monkeypatch): + """Tokens round-tripped through disk must encode absolute expiry.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthToken + + from tools.mcp_oauth import HermesTokenStorage + + storage = HermesTokenStorage("srv") + before = time.time() + asyncio.run( + storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + expires_in=3600, + refresh_token="r", + ) + ) + ) + after = time.time() + + on_disk = json.loads( + (tmp_path / "mcp-tokens" / "srv.json").read_text() + ) + assert "expires_at" in on_disk, ( + "Fix A: set_tokens must record an absolute expires_at wall-clock " + "timestamp alongside the SDK's serialized token so cold-loads " + "can compute correct remaining TTL." + ) + assert before + 3600 <= on_disk["expires_at"] <= after + 3600 + + def test_set_tokens_without_expires_in_omits_expires_at( + self, tmp_path, monkeypatch + ): + """Tokens without a TTL must not gain a fabricated expires_at.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthToken + + from tools.mcp_oauth import HermesTokenStorage + + storage = HermesTokenStorage("srv") + asyncio.run( + storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + refresh_token="r", + ) + ) + ) + + on_disk = json.loads( + (tmp_path / "mcp-tokens" / "srv.json").read_text() + ) + assert "expires_at" not in on_disk + + +class TestGetTokensReconstructsExpiresIn: + def test_get_tokens_uses_expires_at_for_remaining_ttl( + self, tmp_path, monkeypatch + ): + """Round-trip: expires_in on read must reflect time remaining.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthToken + + from tools.mcp_oauth import HermesTokenStorage + + storage = HermesTokenStorage("srv") + asyncio.run( + storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + expires_in=3600, + refresh_token="r", + ) + ) + ) + + # Wait briefly so the remaining TTL is measurably less than 3600. + time.sleep(0.05) + + reloaded = asyncio.run(storage.get_tokens()) + assert reloaded is not None + assert reloaded.expires_in is not None + # Should be slightly less than 3600 after the 50ms sleep. + assert 3500 < reloaded.expires_in <= 3600 + + def test_get_tokens_returns_zero_ttl_for_expired_token( + self, tmp_path, monkeypatch + ): + """An already-expired token reloaded from disk must report expires_in=0.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth import HermesTokenStorage, _get_token_dir + + token_dir = _get_token_dir() + token_dir.mkdir(parents=True, exist_ok=True) + # Write an already-expired token file directly. + (token_dir / "srv.json").write_text( + json.dumps( + { + "access_token": "a", + "token_type": "Bearer", + "expires_in": 3600, + "expires_at": time.time() - 60, # expired 1 min ago + "refresh_token": "r", + } + ) + ) + + storage = HermesTokenStorage("srv") + reloaded = asyncio.run(storage.get_tokens()) + assert reloaded is not None + assert reloaded.expires_in == 0, ( + "Expired token must reload with expires_in=0 so the SDK's " + "is_token_valid() returns False and preemptive refresh fires." + ) + + def test_get_tokens_legacy_file_without_expires_at_is_loadable( + self, tmp_path, monkeypatch + ): + """Existing on-disk files (pre-Fix-A) must still load without crashing. + + Pre-existing token files have ``expires_in`` but no ``expires_at``. + Fix A falls back to the file's mtime as a best-effort wall-clock + proxy: a file whose (mtime + expires_in) is in the past clamps + expires_in to zero so the SDK refreshes on next request. A fresh + legacy-format file (mtime = now) keeps most of its TTL. + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth import HermesTokenStorage, _get_token_dir + + token_dir = _get_token_dir() + token_dir.mkdir(parents=True, exist_ok=True) + # Legacy-shape file (no expires_at). Make it stale by backdating mtime + # well past its nominal expires_in. + legacy_path = token_dir / "srv.json" + legacy_path.write_text( + json.dumps( + { + "access_token": "a", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "r", + } + ) + ) + stale_time = time.time() - 7200 # 2hr ago, exceeds 3600s TTL + import os + + os.utime(legacy_path, (stale_time, stale_time)) + + storage = HermesTokenStorage("srv") + reloaded = asyncio.run(storage.get_tokens()) + assert reloaded is not None + assert reloaded.expires_in == 0, ( + "Legacy file whose mtime + expires_in is in the past must report " + "expires_in=0 so the SDK refreshes on next request." + ) + + +# --------------------------------------------------------------------------- +# HermesMCPOAuthProvider._initialize — seed token_expiry_time +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_initialize_seeds_token_expiry_time_from_stored_tokens( + tmp_path, monkeypatch +): + """Cold-load must populate context.token_expiry_time. + + The SDK's base ``_initialize`` loads current_tokens but doesn't seed + token_expiry_time. Our subclass must do it so ``is_token_valid()`` + reports correctly and the preemptive-refresh path fires when needed. + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + from pydantic import AnyUrl + + from tools.mcp_oauth import HermesTokenStorage + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + + assert _HERMES_PROVIDER_CLS is not None + reset_manager_for_tests() + + storage = HermesTokenStorage("srv") + await storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + expires_in=7200, + refresh_token="r", + ) + ) + await storage.set_client_info( + OAuthClientInformationFull( + client_id="test-client", + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + ) + + from mcp.shared.auth import OAuthClientMetadata + + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://example.com/mcp", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + await provider._initialize() + + assert provider.context.token_expiry_time is not None, ( + "Fix A: _initialize must seed context.token_expiry_time so " + "is_token_valid() correctly reports expiry on cold-load." + ) + # Should be ~7200s in the future (fresh write). + assert provider.context.token_expiry_time > time.time() + 7000 + assert provider.context.token_expiry_time <= time.time() + 7200 + 5 + + +@pytest.mark.asyncio +async def test_initialize_flags_expired_token_as_invalid(tmp_path, monkeypatch): + """After _initialize, an expired-on-disk token must report is_token_valid=False. + + This is the end-to-end assertion: cold-load an expired token, verify the + SDK's own ``is_token_valid()`` now returns False (the consequence of + seeding token_expiry_time correctly), so the SDK's ``async_auth_flow`` + will take the ``can_refresh_token()`` branch on the next request and + silently refresh instead of sending the stale Bearer. + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata + from pydantic import AnyUrl + + from tools.mcp_oauth import HermesTokenStorage, _get_token_dir + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + + assert _HERMES_PROVIDER_CLS is not None + reset_manager_for_tests() + + # Write an already-expired token directly so we control the wall-clock. + token_dir = _get_token_dir() + token_dir.mkdir(parents=True, exist_ok=True) + (token_dir / "srv.json").write_text( + json.dumps( + { + "access_token": "stale", + "token_type": "Bearer", + "expires_in": 3600, + "expires_at": time.time() - 60, + "refresh_token": "fresh", + } + ) + ) + + storage = HermesTokenStorage("srv") + await storage.set_client_info( + OAuthClientInformationFull( + client_id="test-client", + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + ) + + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://example.com/mcp", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + await provider._initialize() + + assert provider.context.is_token_valid() is False, ( + "After _initialize with an expired-on-disk token, is_token_valid() " + "must return False so the SDK's async_auth_flow takes the " + "preemptive refresh path." + ) + assert provider.context.can_refresh_token() is True, ( + "Refresh should remain possible because refresh_token + client_info " + "are both present." + ) + + +async def _noop_redirect(_url: str) -> None: + return None + + +async def _noop_callback() -> tuple[str, str | None]: + raise AssertionError("callback handler should not be invoked in these tests") + + +# --------------------------------------------------------------------------- +# Pre-flight OAuth metadata discovery +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_initialize_prefetches_oauth_metadata_when_missing( + tmp_path, monkeypatch +): + """Cold-load must pre-flight PRM + ASM discovery so ``_refresh_token`` + has the correct ``token_endpoint`` before the first refresh attempt. + + Without this, the SDK's ``_refresh_token`` falls back to + ``{server_url}/token`` which is wrong for providers whose AS is at + a different origin. BetterStack specifically: MCP at + ``mcp.betterstack.com`` but token_endpoint at + ``betterstack.com/oauth/token``. Without pre-flight the refresh 404s + and we drop into full browser re-auth — visible to the user as an + unwanted OAuth browser prompt every time the process restarts. + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + import httpx + from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthToken, + ) + from pydantic import AnyUrl + + from tools.mcp_oauth import HermesTokenStorage + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + + assert _HERMES_PROVIDER_CLS is not None + reset_manager_for_tests() + + storage = HermesTokenStorage("srv") + await storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + expires_in=3600, + refresh_token="r", + ) + ) + await storage.set_client_info( + OAuthClientInformationFull( + client_id="test-client", + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + ) + + # Route the AsyncClient used inside _prefetch_oauth_metadata through a + # MockTransport that mimics BetterStack's split-origin discovery: + # PRM at mcp.example.com/.well-known/oauth-protected-resource -> points to auth.example.com + # ASM at auth.example.com/.well-known/oauth-authorization-server -> token_endpoint at auth.example.com/oauth/token + def mock_handler(request: httpx.Request) -> httpx.Response: + url = str(request.url) + if url.endswith("/.well-known/oauth-protected-resource"): + return httpx.Response( + 200, + json={ + "resource": "https://mcp.example.com", + "authorization_servers": ["https://auth.example.com"], + "scopes_supported": ["read", "write"], + "bearer_methods_supported": ["header"], + }, + ) + if url.endswith("/.well-known/oauth-authorization-server"): + return httpx.Response( + 200, + json={ + "issuer": "https://auth.example.com", + "authorization_endpoint": "https://auth.example.com/oauth/authorize", + "token_endpoint": "https://auth.example.com/oauth/token", + "registration_endpoint": "https://auth.example.com/oauth/register", + "response_types_supported": ["code"], + "grant_types_supported": ["authorization_code", "refresh_token"], + "code_challenge_methods_supported": ["S256"], + "token_endpoint_auth_methods_supported": ["none"], + "scopes_supported": ["read", "write"], + }, + ) + return httpx.Response(404) + + transport = httpx.MockTransport(mock_handler) + + # Patch the AsyncClient constructor used by _prefetch_oauth_metadata so + # it uses our mock transport instead of the real network. + import httpx as real_httpx + + original_async_client = real_httpx.AsyncClient + + def patched_async_client(*args, **kwargs): + kwargs["transport"] = transport + return original_async_client(*args, **kwargs) + + monkeypatch.setattr(real_httpx, "AsyncClient", patched_async_client) + + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://mcp.example.com", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + await provider._initialize() + + assert provider.context.protected_resource_metadata is not None, ( + "Pre-flight must cache PRM for the SDK to reference later." + ) + assert provider.context.oauth_metadata is not None, ( + "Pre-flight must cache ASM so _refresh_token builds the correct " + "token_endpoint URL." + ) + assert str(provider.context.oauth_metadata.token_endpoint) == ( + "https://auth.example.com/oauth/token" + ) + + +@pytest.mark.asyncio +async def test_initialize_skips_prefetch_when_no_tokens(tmp_path, monkeypatch): + """Pre-flight must not run when there are no stored tokens yet. + + Without this guard, every fresh-install ``_initialize`` would do two + extra network roundtrips that gain nothing (the SDK's 401-branch + discovery will run on the first real request anyway). + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + import httpx + from mcp.shared.auth import OAuthClientMetadata + from pydantic import AnyUrl + + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + from tools.mcp_oauth import HermesTokenStorage + + assert _HERMES_PROVIDER_CLS is not None + reset_manager_for_tests() + + calls: list[str] = [] + + def mock_handler(request: httpx.Request) -> httpx.Response: + calls.append(str(request.url)) + return httpx.Response(404) + + transport = httpx.MockTransport(mock_handler) + import httpx as real_httpx + + original = real_httpx.AsyncClient + + def patched(*args, **kwargs): + kwargs["transport"] = transport + return original(*args, **kwargs) + + monkeypatch.setattr(real_httpx, "AsyncClient", patched) + + storage = HermesTokenStorage("srv") # empty — no tokens on disk + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://mcp.example.com", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + await provider._initialize() + + assert calls == [], ( + f"Pre-flight must not fire when no tokens are stored, but got {calls}" + ) diff --git a/tools/mcp_oauth.py b/tools/mcp_oauth.py index 6e1d7f5fb..7910c3cdc 100644 --- a/tools/mcp_oauth.py +++ b/tools/mcp_oauth.py @@ -40,6 +40,7 @@ import re import socket import sys import threading +import time import webbrowser from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path @@ -196,6 +197,35 @@ class HermesTokenStorage: data = _read_json(self._tokens_path()) if data is None: return None + # Hermes records an absolute wall-clock ``expires_at`` alongside the + # SDK's serialized token (see ``set_tokens``). On read we rewrite + # ``expires_in`` to the remaining seconds so the SDK's downstream + # ``update_token_expiry`` computes the correct absolute time and + # ``is_token_valid()`` correctly reports False for tokens that + # expired while the process was down. + # + # Legacy token files (pre-Fix-A) have ``expires_in`` but no + # ``expires_at``. We fall back to the file's mtime as a best-effort + # wall-clock proxy for when the token was written: if (mtime + + # expires_in) is in the past, clamp ``expires_in`` to zero so the + # SDK refreshes before the first request. This self-heals one-time + # on the next successful ``set_tokens``, which writes the new + # ``expires_at`` field. The stored ``expires_at`` is stripped before + # model_validate because it's not part of the SDK's OAuthToken schema. + absolute_expiry = data.pop("expires_at", None) + if absolute_expiry is not None: + data["expires_in"] = int(max(absolute_expiry - time.time(), 0)) + elif data.get("expires_in") is not None: + try: + file_mtime = self._tokens_path().stat().st_mtime + except OSError: + file_mtime = None + if file_mtime is not None: + try: + implied_expiry = file_mtime + int(data["expires_in"]) + data["expires_in"] = int(max(implied_expiry - time.time(), 0)) + except (TypeError, ValueError): + pass try: return OAuthToken.model_validate(data) except (ValueError, TypeError, KeyError) as exc: @@ -203,7 +233,23 @@ class HermesTokenStorage: return None async def set_tokens(self, tokens: "OAuthToken") -> None: - _write_json(self._tokens_path(), tokens.model_dump(exclude_none=True)) + payload = tokens.model_dump(exclude_none=True) + # Persist an absolute ``expires_at`` so a process restart can + # reconstruct the correct remaining TTL. Without this the MCP SDK's + # ``_initialize`` reloads a relative ``expires_in`` which has no + # wall-clock reference, leaving ``context.token_expiry_time=None`` + # and ``is_token_valid()`` falsely reporting True. See Fix A in + # ``mcp-oauth-token-diagnosis`` skill + Claude Code's + # ``OAuthTokens.expiresAt`` persistence (auth.ts ~180). + expires_in = payload.get("expires_in") + if expires_in is not None: + try: + payload["expires_at"] = time.time() + int(expires_in) + except (TypeError, ValueError): + # Mock tokens or unusual shapes: skip the expires_at write + # rather than fail persistence. + pass + _write_json(self._tokens_path(), payload) logger.debug("OAuth tokens saved for %s", self._server_name) # -- client info ------------------------------------------------------- diff --git a/tools/mcp_oauth_manager.py b/tools/mcp_oauth_manager.py index d3760e3b8..7c8a91f3f 100644 --- a/tools/mcp_oauth_manager.py +++ b/tools/mcp_oauth_manager.py @@ -111,6 +111,131 @@ def _make_hermes_provider_class() -> Optional[type]: super().__init__(*args, **kwargs) self._hermes_server_name = server_name + async def _initialize(self) -> None: + """Load stored tokens + client info AND seed token_expiry_time. + + Also eagerly fetches OAuth authorization-server metadata (PRM + + ASM) when we have stored tokens but no cached metadata, so the + SDK's ``_refresh_token`` can build the correct token_endpoint + URL on the preemptive-refresh path. Without this, the SDK + falls back to ``{mcp_server_url}/token`` (wrong for providers + whose AS is a different origin — BetterStack's MCP lives at + ``https://mcp.betterstack.com`` but its token endpoint is at + ``https://betterstack.com/oauth/token``), the refresh 404s, and + we drop through to full browser reauth. + + The SDK's base ``_initialize`` populates ``current_tokens`` but + does NOT call ``update_token_expiry``, so ``token_expiry_time`` + stays ``None`` and ``is_token_valid()`` returns True for any + loaded token regardless of actual age. After a process restart + this ships stale Bearer tokens to the server; some providers + return HTTP 401 (caught by the 401 handler), others return 200 + with an app-level auth error (invisible to the transport layer, + e.g. BetterStack returning "No teams found. Please check your + authentication."). + + Seeding ``token_expiry_time`` from the reloaded token fixes that: + ``is_token_valid()`` correctly reports False for expired tokens, + ``async_auth_flow`` takes the ``can_refresh_token()`` branch, + and the SDK quietly refreshes before the first real request. + + Paired with :class:`HermesTokenStorage` persisting an absolute + ``expires_at`` timestamp (``mcp_oauth.py:set_tokens``) so the + remaining TTL we compute here reflects real wall-clock age. + """ + await super()._initialize() + tokens = self.context.current_tokens + if tokens is not None and tokens.expires_in is not None: + self.context.update_token_expiry(tokens) + + # Pre-flight OAuth AS discovery so ``_refresh_token`` has a + # correct ``token_endpoint`` before the first refresh attempt. + # Only runs when we have tokens on cold-load but no cached + # metadata — i.e. the exact scenario where the SDK's built-in + # 401-branch discovery hasn't had a chance to run yet. + if ( + tokens is not None + and self.context.oauth_metadata is None + ): + try: + await self._prefetch_oauth_metadata() + except Exception as exc: # pragma: no cover — defensive + # Non-fatal: if discovery fails, the SDK's normal 401- + # branch discovery will run on the next request. + logger.debug( + "MCP OAuth '%s': pre-flight metadata discovery " + "failed (non-fatal): %s", + self._hermes_server_name, exc, + ) + + async def _prefetch_oauth_metadata(self) -> None: + """Fetch PRM + ASM from the well-known endpoints, cache on context. + + Mirrors the SDK's 401-branch discovery (oauth2.py ~line 511-551) + but runs synchronously before the first request instead of + inside the httpx auth_flow generator. Uses the SDK's own URL + builders and response handlers so we track whatever the SDK + version we're pinned to expects. + """ + import httpx # local import: httpx is an MCP SDK dependency + from mcp.client.auth.utils import ( + build_oauth_authorization_server_metadata_discovery_urls, + build_protected_resource_metadata_discovery_urls, + create_oauth_metadata_request, + handle_auth_metadata_response, + handle_protected_resource_response, + ) + + server_url = self.context.server_url + async with httpx.AsyncClient(timeout=10.0) as client: + # Step 1: PRM discovery to learn the authorization_server URL. + for url in build_protected_resource_metadata_discovery_urls( + None, server_url + ): + req = create_oauth_metadata_request(url) + try: + resp = await client.send(req) + except httpx.HTTPError as exc: + logger.debug( + "MCP OAuth '%s': PRM discovery to %s failed: %s", + self._hermes_server_name, url, exc, + ) + continue + prm = await handle_protected_resource_response(resp) + if prm: + self.context.protected_resource_metadata = prm + if prm.authorization_servers: + self.context.auth_server_url = str( + prm.authorization_servers[0] + ) + break + + # Step 2: ASM discovery against the auth_server_url (or + # server_url fallback for legacy providers). + for url in build_oauth_authorization_server_metadata_discovery_urls( + self.context.auth_server_url, server_url + ): + req = create_oauth_metadata_request(url) + try: + resp = await client.send(req) + except httpx.HTTPError as exc: + logger.debug( + "MCP OAuth '%s': ASM discovery to %s failed: %s", + self._hermes_server_name, url, exc, + ) + continue + ok, asm = await handle_auth_metadata_response(resp) + if not ok: + break + if asm: + self.context.oauth_metadata = asm + logger.debug( + "MCP OAuth '%s': pre-flight ASM discovered " + "token_endpoint=%s", + self._hermes_server_name, asm.token_endpoint, + ) + break + async def async_auth_flow(self, request): # type: ignore[override] # Pre-flow hook: ask the manager to refresh from disk if needed. # Any failure here is non-fatal — we just log and proceed with @@ -125,9 +250,28 @@ def _make_hermes_provider_class() -> Optional[type]: self._hermes_server_name, exc, ) - # Delegate to the SDK's auth flow - async for item in super().async_auth_flow(request): - yield item + # Manually bridge the bidirectional generator protocol. httpx's + # auth_flow driver (httpx._client._send_handling_auth) calls + # ``auth_flow.asend(response)`` to feed HTTP responses back into + # the generator. A naive wrapper using ``async for item in inner: + # yield item`` DISCARDS those .asend(response) values and resumes + # the inner generator with None, so the SDK's + # ``response = yield request`` branch in + # mcp/client/auth/oauth2.py sees response=None and crashes at + # ``if response.status_code == 401`` with AttributeError. + # + # The bridge below forwards each .asend() value into the inner + # generator via inner.asend(incoming), preserving the bidirectional + # contract. Regression from PR #11383 caught by + # tests/tools/test_mcp_oauth_bidirectional.py. + inner = super().async_auth_flow(request) + try: + outgoing = await inner.__anext__() + while True: + incoming = yield outgoing + outgoing = await inner.asend(incoming) + except StopAsyncIteration: + return return HermesMCPOAuthProvider