diff --git a/scripts/release.py b/scripts/release.py index 3f7930e77..ca41ef93c 100755 --- a/scripts/release.py +++ b/scripts/release.py @@ -132,6 +132,7 @@ AUTHOR_MAP = { "bryan@intertwinesys.com": "bryanyoung", "christo.mitov@gmail.com": "christomitov", "hermes@nousresearch.com": "NousResearch", + "hermes@noushq.ai": "benbarclay", "chinmingcock@gmail.com": "ChimingLiu", "openclaw@sparklab.ai": "openclaw", "semihcvlk53@gmail.com": "Himess", diff --git a/tests/tools/test_mcp_oauth_bidirectional.py b/tests/tools/test_mcp_oauth_bidirectional.py new file mode 100644 index 000000000..37ca409bb --- /dev/null +++ b/tests/tools/test_mcp_oauth_bidirectional.py @@ -0,0 +1,210 @@ +"""Regression test for the ``HermesMCPOAuthProvider.async_auth_flow`` bidirectional +generator bridge. + +PR #11383 introduced a subclass method that wrapped the SDK's ``auth_flow`` with:: + + async for item in super().async_auth_flow(request): + yield item + +``httpx``'s auth_flow contract is a **bidirectional** async generator — the +driving code (``httpx._client._send_handling_auth``) does:: + + next_request = await auth_flow.asend(response) + +to feed HTTP responses back into the generator. The naive ``async for ...`` +wrapper discards those ``.asend(response)`` values and resumes the inner +generator with ``None``, so the SDK's ``response = yield request`` branch in +``mcp/client/auth/oauth2.py`` sees ``response = None`` and crashes at +``if response.status_code == 401`` with +``AttributeError: 'NoneType' object has no attribute 'status_code'``. + +This broke every OAuth MCP server on the first HTTP response regardless of +status code. The reason nothing caught it in CI: zero existing tests drive +the full ``.asend()`` round-trip — the integration tests in +``test_mcp_oauth_integration.py`` stop at ``_initialize()`` and disk-watching. + +These tests drive the wrapper through a manual ``.asend()`` sequence to prove +the bridge forwards responses correctly into the inner SDK generator. +""" +from __future__ import annotations + +import pytest + + +pytest.importorskip("mcp.client.auth.oauth2", reason="MCP SDK 1.26.0+ required") + + +@pytest.mark.asyncio +async def test_hermes_provider_forwards_asend_values(tmp_path, monkeypatch): + """The wrapper MUST forward ``.asend(response)`` into the inner generator. + + This is the primary regression test. With the broken wrapper, the inner + SDK generator sees ``response = None`` and raises ``AttributeError`` at + ``oauth2.py:505``. With the correct bridge, a 200 response finishes the + flow cleanly (``StopAsyncIteration``). + """ + import httpx + from mcp.shared.auth import OAuthClientMetadata, OAuthToken + from pydantic import AnyUrl + + from tools.mcp_oauth import HermesTokenStorage + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + + assert _HERMES_PROVIDER_CLS is not None, "SDK OAuth types must be available" + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + reset_manager_for_tests() + + # Seed a valid-looking token so the SDK's _initialize loads something and + # can_refresh_token() is True (though we don't exercise refresh here — we + # go straight through the 200 path). + storage = HermesTokenStorage("srv") + await storage.set_tokens( + OAuthToken( + access_token="old_access", + token_type="Bearer", + expires_in=3600, + refresh_token="old_refresh", + ) + ) + # Also seed client_info so the SDK doesn't attempt registration. + from mcp.shared.auth import OAuthClientInformationFull + + await storage.set_client_info( + OAuthClientInformationFull( + client_id="test-client", + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + ) + + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://example.com/mcp", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + req = httpx.Request("POST", "https://example.com/mcp") + flow = provider.async_auth_flow(req) + + # First anext() drives the wrapper + inner generator until the inner + # yields the outbound request (at oauth2.py:503 ``response = yield request``). + outbound = await flow.__anext__() + assert outbound is not None, "wrapper must yield the outbound request" + assert outbound.url.host == "example.com" + + # Simulate httpx returning a 200 response. + fake_response = httpx.Response(200, request=outbound) + + # The broken wrapper would crash here with AttributeError: 'NoneType' + # object has no attribute 'status_code', because the SDK's inner generator + # resumes with response=None and dereferences .status_code at line 505. + # + # The correct wrapper forwards the response, the SDK takes the non-401 + # non-403 exit, and the generator ends cleanly (StopAsyncIteration). + with pytest.raises(StopAsyncIteration): + await flow.asend(fake_response) + + +@pytest.mark.asyncio +async def test_hermes_provider_forwards_401_triggers_refresh(tmp_path, monkeypatch): + """A 401 response MUST flow into the inner generator and trigger the + SDK's 401 recovery branch. + + With the broken wrapper, the inner generator sees ``response = None`` + and the 401 check short-circuits into AttributeError. With the correct + bridge, the 401 is routed into the SDK's ``response.status_code == 401`` + branch which begins discovery (yielding a metadata-discovery request). + """ + import httpx + from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken + from pydantic import AnyUrl + + from tools.mcp_oauth import HermesTokenStorage + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + + assert _HERMES_PROVIDER_CLS is not None + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + reset_manager_for_tests() + + storage = HermesTokenStorage("srv") + await storage.set_tokens( + OAuthToken( + access_token="old_access", + token_type="Bearer", + expires_in=3600, + refresh_token="old_refresh", + ) + ) + await storage.set_client_info( + OAuthClientInformationFull( + client_id="test-client", + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + ) + + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://example.com/mcp", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + req = httpx.Request("POST", "https://example.com/mcp") + flow = provider.async_auth_flow(req) + + # Drive to the first yield (outbound MCP request). + outbound = await flow.__anext__() + + # Reply with a 401 including a minimal WWW-Authenticate so the SDK's + # 401 branch can parse resource metadata from it. We just need something + # the SDK accepts before it tries to yield the metadata-discovery request. + fake_401 = httpx.Response( + 401, + request=outbound, + headers={"www-authenticate": 'Bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource"'}, + ) + + # The correct bridge forwards the 401 into the SDK; the SDK then yields + # its NEXT request (a metadata-discovery GET). We assert we get a request + # back — any request. The broken bridge would have crashed with + # AttributeError before we ever reach this point. + next_request = await flow.asend(fake_401) + assert isinstance(next_request, httpx.Request), ( + "wrapper must forward .asend() so the SDK's 401 branch can yield the " + "next request in the discovery flow" + ) + + # Clean up the generator — we don't need to complete the full dance. + await flow.aclose() + + +async def _noop_redirect(_url: str) -> None: + """Redirect handler that does nothing (won't be invoked in these tests).""" + return None + + +async def _noop_callback() -> tuple[str, str | None]: + """Callback handler that won't be invoked in these tests.""" + raise AssertionError( + "callback handler should not be invoked in bidirectional-generator tests" + ) diff --git a/tests/tools/test_mcp_oauth_cold_load_expiry.py b/tests/tools/test_mcp_oauth_cold_load_expiry.py new file mode 100644 index 000000000..a9fb19106 --- /dev/null +++ b/tests/tools/test_mcp_oauth_cold_load_expiry.py @@ -0,0 +1,546 @@ +"""Tests for cold-load token expiry tracking in MCP OAuth. + +PR #11383's consolidation fixed external-refresh reloading (mtime disk-watch) +and 401 dedup, but left two underlying latent bugs in place: + +1. ``HermesTokenStorage.set_tokens`` persisted only relative ``expires_in``, + which is meaningless after a process restart. +2. The MCP SDK's ``OAuthContext._initialize`` loads ``current_tokens`` from + storage but does NOT call ``update_token_expiry``, so + ``token_expiry_time`` stays None. ``is_token_valid()`` then returns True + for any loaded token regardless of actual age, and the SDK's preemptive + refresh branch at ``oauth2.py:491`` is never taken. + +Consequence: a token that expired while the process was down ships to the +server with a stale Bearer header. The server's response is provider-specific +— some return HTTP 401 (caught by the consolidation's 401 handler, which +surfaces a ``needs_reauth`` error), others return HTTP 200 with an +application-level auth failure in the body (e.g. BetterStack's "No teams +found. Please check your authentication."), which the consolidation cannot +detect. + +These tests pin the contract for Fix A: +- ``set_tokens`` persists an absolute ``expires_at`` wall-clock timestamp. +- ``get_tokens`` reconstructs ``expires_in`` from ``expires_at - now`` so + the SDK's ``update_token_expiry`` computes the correct absolute expiry. +- ``HermesMCPOAuthProvider._initialize`` seeds ``context.token_expiry_time`` + after loading, so ``is_token_valid()`` reports True only for tokens that + are actually still valid, and the SDK's preemptive refresh fires for + expired tokens with a live refresh_token. + +Reference: Claude Code solves this via an ``OAuthTokens.expiresAt`` absolute +timestamp persisted alongside the access_token (``auth.ts:~180``). +""" +from __future__ import annotations + +import asyncio +import json +import time + +import pytest + + +pytest.importorskip("mcp.client.auth.oauth2", reason="MCP SDK 1.26.0+ required") + + +# --------------------------------------------------------------------------- +# HermesTokenStorage — absolute expiry persistence +# --------------------------------------------------------------------------- + + +class TestSetTokensAbsoluteExpiry: + def test_set_tokens_persists_absolute_expires_at(self, tmp_path, monkeypatch): + """Tokens round-tripped through disk must encode absolute expiry.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthToken + + from tools.mcp_oauth import HermesTokenStorage + + storage = HermesTokenStorage("srv") + before = time.time() + asyncio.run( + storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + expires_in=3600, + refresh_token="r", + ) + ) + ) + after = time.time() + + on_disk = json.loads( + (tmp_path / "mcp-tokens" / "srv.json").read_text() + ) + assert "expires_at" in on_disk, ( + "Fix A: set_tokens must record an absolute expires_at wall-clock " + "timestamp alongside the SDK's serialized token so cold-loads " + "can compute correct remaining TTL." + ) + assert before + 3600 <= on_disk["expires_at"] <= after + 3600 + + def test_set_tokens_without_expires_in_omits_expires_at( + self, tmp_path, monkeypatch + ): + """Tokens without a TTL must not gain a fabricated expires_at.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthToken + + from tools.mcp_oauth import HermesTokenStorage + + storage = HermesTokenStorage("srv") + asyncio.run( + storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + refresh_token="r", + ) + ) + ) + + on_disk = json.loads( + (tmp_path / "mcp-tokens" / "srv.json").read_text() + ) + assert "expires_at" not in on_disk + + +class TestGetTokensReconstructsExpiresIn: + def test_get_tokens_uses_expires_at_for_remaining_ttl( + self, tmp_path, monkeypatch + ): + """Round-trip: expires_in on read must reflect time remaining.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthToken + + from tools.mcp_oauth import HermesTokenStorage + + storage = HermesTokenStorage("srv") + asyncio.run( + storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + expires_in=3600, + refresh_token="r", + ) + ) + ) + + # Wait briefly so the remaining TTL is measurably less than 3600. + time.sleep(0.05) + + reloaded = asyncio.run(storage.get_tokens()) + assert reloaded is not None + assert reloaded.expires_in is not None + # Should be slightly less than 3600 after the 50ms sleep. + assert 3500 < reloaded.expires_in <= 3600 + + def test_get_tokens_returns_zero_ttl_for_expired_token( + self, tmp_path, monkeypatch + ): + """An already-expired token reloaded from disk must report expires_in=0.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth import HermesTokenStorage, _get_token_dir + + token_dir = _get_token_dir() + token_dir.mkdir(parents=True, exist_ok=True) + # Write an already-expired token file directly. + (token_dir / "srv.json").write_text( + json.dumps( + { + "access_token": "a", + "token_type": "Bearer", + "expires_in": 3600, + "expires_at": time.time() - 60, # expired 1 min ago + "refresh_token": "r", + } + ) + ) + + storage = HermesTokenStorage("srv") + reloaded = asyncio.run(storage.get_tokens()) + assert reloaded is not None + assert reloaded.expires_in == 0, ( + "Expired token must reload with expires_in=0 so the SDK's " + "is_token_valid() returns False and preemptive refresh fires." + ) + + def test_get_tokens_legacy_file_without_expires_at_is_loadable( + self, tmp_path, monkeypatch + ): + """Existing on-disk files (pre-Fix-A) must still load without crashing. + + Pre-existing token files have ``expires_in`` but no ``expires_at``. + Fix A falls back to the file's mtime as a best-effort wall-clock + proxy: a file whose (mtime + expires_in) is in the past clamps + expires_in to zero so the SDK refreshes on next request. A fresh + legacy-format file (mtime = now) keeps most of its TTL. + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth import HermesTokenStorage, _get_token_dir + + token_dir = _get_token_dir() + token_dir.mkdir(parents=True, exist_ok=True) + # Legacy-shape file (no expires_at). Make it stale by backdating mtime + # well past its nominal expires_in. + legacy_path = token_dir / "srv.json" + legacy_path.write_text( + json.dumps( + { + "access_token": "a", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "r", + } + ) + ) + stale_time = time.time() - 7200 # 2hr ago, exceeds 3600s TTL + import os + + os.utime(legacy_path, (stale_time, stale_time)) + + storage = HermesTokenStorage("srv") + reloaded = asyncio.run(storage.get_tokens()) + assert reloaded is not None + assert reloaded.expires_in == 0, ( + "Legacy file whose mtime + expires_in is in the past must report " + "expires_in=0 so the SDK refreshes on next request." + ) + + +# --------------------------------------------------------------------------- +# HermesMCPOAuthProvider._initialize — seed token_expiry_time +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_initialize_seeds_token_expiry_time_from_stored_tokens( + tmp_path, monkeypatch +): + """Cold-load must populate context.token_expiry_time. + + The SDK's base ``_initialize`` loads current_tokens but doesn't seed + token_expiry_time. Our subclass must do it so ``is_token_valid()`` + reports correctly and the preemptive-refresh path fires when needed. + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + from pydantic import AnyUrl + + from tools.mcp_oauth import HermesTokenStorage + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + + assert _HERMES_PROVIDER_CLS is not None + reset_manager_for_tests() + + storage = HermesTokenStorage("srv") + await storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + expires_in=7200, + refresh_token="r", + ) + ) + await storage.set_client_info( + OAuthClientInformationFull( + client_id="test-client", + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + ) + + from mcp.shared.auth import OAuthClientMetadata + + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://example.com/mcp", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + await provider._initialize() + + assert provider.context.token_expiry_time is not None, ( + "Fix A: _initialize must seed context.token_expiry_time so " + "is_token_valid() correctly reports expiry on cold-load." + ) + # Should be ~7200s in the future (fresh write). + assert provider.context.token_expiry_time > time.time() + 7000 + assert provider.context.token_expiry_time <= time.time() + 7200 + 5 + + +@pytest.mark.asyncio +async def test_initialize_flags_expired_token_as_invalid(tmp_path, monkeypatch): + """After _initialize, an expired-on-disk token must report is_token_valid=False. + + This is the end-to-end assertion: cold-load an expired token, verify the + SDK's own ``is_token_valid()`` now returns False (the consequence of + seeding token_expiry_time correctly), so the SDK's ``async_auth_flow`` + will take the ``can_refresh_token()`` branch on the next request and + silently refresh instead of sending the stale Bearer. + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata + from pydantic import AnyUrl + + from tools.mcp_oauth import HermesTokenStorage, _get_token_dir + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + + assert _HERMES_PROVIDER_CLS is not None + reset_manager_for_tests() + + # Write an already-expired token directly so we control the wall-clock. + token_dir = _get_token_dir() + token_dir.mkdir(parents=True, exist_ok=True) + (token_dir / "srv.json").write_text( + json.dumps( + { + "access_token": "stale", + "token_type": "Bearer", + "expires_in": 3600, + "expires_at": time.time() - 60, + "refresh_token": "fresh", + } + ) + ) + + storage = HermesTokenStorage("srv") + await storage.set_client_info( + OAuthClientInformationFull( + client_id="test-client", + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + ) + + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://example.com/mcp", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + await provider._initialize() + + assert provider.context.is_token_valid() is False, ( + "After _initialize with an expired-on-disk token, is_token_valid() " + "must return False so the SDK's async_auth_flow takes the " + "preemptive refresh path." + ) + assert provider.context.can_refresh_token() is True, ( + "Refresh should remain possible because refresh_token + client_info " + "are both present." + ) + + +async def _noop_redirect(_url: str) -> None: + return None + + +async def _noop_callback() -> tuple[str, str | None]: + raise AssertionError("callback handler should not be invoked in these tests") + + +# --------------------------------------------------------------------------- +# Pre-flight OAuth metadata discovery +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_initialize_prefetches_oauth_metadata_when_missing( + tmp_path, monkeypatch +): + """Cold-load must pre-flight PRM + ASM discovery so ``_refresh_token`` + has the correct ``token_endpoint`` before the first refresh attempt. + + Without this, the SDK's ``_refresh_token`` falls back to + ``{server_url}/token`` which is wrong for providers whose AS is at + a different origin. BetterStack specifically: MCP at + ``mcp.betterstack.com`` but token_endpoint at + ``betterstack.com/oauth/token``. Without pre-flight the refresh 404s + and we drop into full browser re-auth — visible to the user as an + unwanted OAuth browser prompt every time the process restarts. + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + import httpx + from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthToken, + ) + from pydantic import AnyUrl + + from tools.mcp_oauth import HermesTokenStorage + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + + assert _HERMES_PROVIDER_CLS is not None + reset_manager_for_tests() + + storage = HermesTokenStorage("srv") + await storage.set_tokens( + OAuthToken( + access_token="a", + token_type="Bearer", + expires_in=3600, + refresh_token="r", + ) + ) + await storage.set_client_info( + OAuthClientInformationFull( + client_id="test-client", + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="none", + ) + ) + + # Route the AsyncClient used inside _prefetch_oauth_metadata through a + # MockTransport that mimics BetterStack's split-origin discovery: + # PRM at mcp.example.com/.well-known/oauth-protected-resource -> points to auth.example.com + # ASM at auth.example.com/.well-known/oauth-authorization-server -> token_endpoint at auth.example.com/oauth/token + def mock_handler(request: httpx.Request) -> httpx.Response: + url = str(request.url) + if url.endswith("/.well-known/oauth-protected-resource"): + return httpx.Response( + 200, + json={ + "resource": "https://mcp.example.com", + "authorization_servers": ["https://auth.example.com"], + "scopes_supported": ["read", "write"], + "bearer_methods_supported": ["header"], + }, + ) + if url.endswith("/.well-known/oauth-authorization-server"): + return httpx.Response( + 200, + json={ + "issuer": "https://auth.example.com", + "authorization_endpoint": "https://auth.example.com/oauth/authorize", + "token_endpoint": "https://auth.example.com/oauth/token", + "registration_endpoint": "https://auth.example.com/oauth/register", + "response_types_supported": ["code"], + "grant_types_supported": ["authorization_code", "refresh_token"], + "code_challenge_methods_supported": ["S256"], + "token_endpoint_auth_methods_supported": ["none"], + "scopes_supported": ["read", "write"], + }, + ) + return httpx.Response(404) + + transport = httpx.MockTransport(mock_handler) + + # Patch the AsyncClient constructor used by _prefetch_oauth_metadata so + # it uses our mock transport instead of the real network. + import httpx as real_httpx + + original_async_client = real_httpx.AsyncClient + + def patched_async_client(*args, **kwargs): + kwargs["transport"] = transport + return original_async_client(*args, **kwargs) + + monkeypatch.setattr(real_httpx, "AsyncClient", patched_async_client) + + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://mcp.example.com", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + await provider._initialize() + + assert provider.context.protected_resource_metadata is not None, ( + "Pre-flight must cache PRM for the SDK to reference later." + ) + assert provider.context.oauth_metadata is not None, ( + "Pre-flight must cache ASM so _refresh_token builds the correct " + "token_endpoint URL." + ) + assert str(provider.context.oauth_metadata.token_endpoint) == ( + "https://auth.example.com/oauth/token" + ) + + +@pytest.mark.asyncio +async def test_initialize_skips_prefetch_when_no_tokens(tmp_path, monkeypatch): + """Pre-flight must not run when there are no stored tokens yet. + + Without this guard, every fresh-install ``_initialize`` would do two + extra network roundtrips that gain nothing (the SDK's 401-branch + discovery will run on the first real request anyway). + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + import httpx + from mcp.shared.auth import OAuthClientMetadata + from pydantic import AnyUrl + + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests + from tools.mcp_oauth import HermesTokenStorage + + assert _HERMES_PROVIDER_CLS is not None + reset_manager_for_tests() + + calls: list[str] = [] + + def mock_handler(request: httpx.Request) -> httpx.Response: + calls.append(str(request.url)) + return httpx.Response(404) + + transport = httpx.MockTransport(mock_handler) + import httpx as real_httpx + + original = real_httpx.AsyncClient + + def patched(*args, **kwargs): + kwargs["transport"] = transport + return original(*args, **kwargs) + + monkeypatch.setattr(real_httpx, "AsyncClient", patched) + + storage = HermesTokenStorage("srv") # empty — no tokens on disk + metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")], + client_name="Hermes Agent", + ) + provider = _HERMES_PROVIDER_CLS( + server_name="srv", + server_url="https://mcp.example.com", + client_metadata=metadata, + storage=storage, + redirect_handler=_noop_redirect, + callback_handler=_noop_callback, + ) + + await provider._initialize() + + assert calls == [], ( + f"Pre-flight must not fire when no tokens are stored, but got {calls}" + ) diff --git a/tools/mcp_oauth.py b/tools/mcp_oauth.py index 6e1d7f5fb..7910c3cdc 100644 --- a/tools/mcp_oauth.py +++ b/tools/mcp_oauth.py @@ -40,6 +40,7 @@ import re import socket import sys import threading +import time import webbrowser from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path @@ -196,6 +197,35 @@ class HermesTokenStorage: data = _read_json(self._tokens_path()) if data is None: return None + # Hermes records an absolute wall-clock ``expires_at`` alongside the + # SDK's serialized token (see ``set_tokens``). On read we rewrite + # ``expires_in`` to the remaining seconds so the SDK's downstream + # ``update_token_expiry`` computes the correct absolute time and + # ``is_token_valid()`` correctly reports False for tokens that + # expired while the process was down. + # + # Legacy token files (pre-Fix-A) have ``expires_in`` but no + # ``expires_at``. We fall back to the file's mtime as a best-effort + # wall-clock proxy for when the token was written: if (mtime + + # expires_in) is in the past, clamp ``expires_in`` to zero so the + # SDK refreshes before the first request. This self-heals one-time + # on the next successful ``set_tokens``, which writes the new + # ``expires_at`` field. The stored ``expires_at`` is stripped before + # model_validate because it's not part of the SDK's OAuthToken schema. + absolute_expiry = data.pop("expires_at", None) + if absolute_expiry is not None: + data["expires_in"] = int(max(absolute_expiry - time.time(), 0)) + elif data.get("expires_in") is not None: + try: + file_mtime = self._tokens_path().stat().st_mtime + except OSError: + file_mtime = None + if file_mtime is not None: + try: + implied_expiry = file_mtime + int(data["expires_in"]) + data["expires_in"] = int(max(implied_expiry - time.time(), 0)) + except (TypeError, ValueError): + pass try: return OAuthToken.model_validate(data) except (ValueError, TypeError, KeyError) as exc: @@ -203,7 +233,23 @@ class HermesTokenStorage: return None async def set_tokens(self, tokens: "OAuthToken") -> None: - _write_json(self._tokens_path(), tokens.model_dump(exclude_none=True)) + payload = tokens.model_dump(exclude_none=True) + # Persist an absolute ``expires_at`` so a process restart can + # reconstruct the correct remaining TTL. Without this the MCP SDK's + # ``_initialize`` reloads a relative ``expires_in`` which has no + # wall-clock reference, leaving ``context.token_expiry_time=None`` + # and ``is_token_valid()`` falsely reporting True. See Fix A in + # ``mcp-oauth-token-diagnosis`` skill + Claude Code's + # ``OAuthTokens.expiresAt`` persistence (auth.ts ~180). + expires_in = payload.get("expires_in") + if expires_in is not None: + try: + payload["expires_at"] = time.time() + int(expires_in) + except (TypeError, ValueError): + # Mock tokens or unusual shapes: skip the expires_at write + # rather than fail persistence. + pass + _write_json(self._tokens_path(), payload) logger.debug("OAuth tokens saved for %s", self._server_name) # -- client info ------------------------------------------------------- diff --git a/tools/mcp_oauth_manager.py b/tools/mcp_oauth_manager.py index d3760e3b8..7c8a91f3f 100644 --- a/tools/mcp_oauth_manager.py +++ b/tools/mcp_oauth_manager.py @@ -111,6 +111,131 @@ def _make_hermes_provider_class() -> Optional[type]: super().__init__(*args, **kwargs) self._hermes_server_name = server_name + async def _initialize(self) -> None: + """Load stored tokens + client info AND seed token_expiry_time. + + Also eagerly fetches OAuth authorization-server metadata (PRM + + ASM) when we have stored tokens but no cached metadata, so the + SDK's ``_refresh_token`` can build the correct token_endpoint + URL on the preemptive-refresh path. Without this, the SDK + falls back to ``{mcp_server_url}/token`` (wrong for providers + whose AS is a different origin — BetterStack's MCP lives at + ``https://mcp.betterstack.com`` but its token endpoint is at + ``https://betterstack.com/oauth/token``), the refresh 404s, and + we drop through to full browser reauth. + + The SDK's base ``_initialize`` populates ``current_tokens`` but + does NOT call ``update_token_expiry``, so ``token_expiry_time`` + stays ``None`` and ``is_token_valid()`` returns True for any + loaded token regardless of actual age. After a process restart + this ships stale Bearer tokens to the server; some providers + return HTTP 401 (caught by the 401 handler), others return 200 + with an app-level auth error (invisible to the transport layer, + e.g. BetterStack returning "No teams found. Please check your + authentication."). + + Seeding ``token_expiry_time`` from the reloaded token fixes that: + ``is_token_valid()`` correctly reports False for expired tokens, + ``async_auth_flow`` takes the ``can_refresh_token()`` branch, + and the SDK quietly refreshes before the first real request. + + Paired with :class:`HermesTokenStorage` persisting an absolute + ``expires_at`` timestamp (``mcp_oauth.py:set_tokens``) so the + remaining TTL we compute here reflects real wall-clock age. + """ + await super()._initialize() + tokens = self.context.current_tokens + if tokens is not None and tokens.expires_in is not None: + self.context.update_token_expiry(tokens) + + # Pre-flight OAuth AS discovery so ``_refresh_token`` has a + # correct ``token_endpoint`` before the first refresh attempt. + # Only runs when we have tokens on cold-load but no cached + # metadata — i.e. the exact scenario where the SDK's built-in + # 401-branch discovery hasn't had a chance to run yet. + if ( + tokens is not None + and self.context.oauth_metadata is None + ): + try: + await self._prefetch_oauth_metadata() + except Exception as exc: # pragma: no cover — defensive + # Non-fatal: if discovery fails, the SDK's normal 401- + # branch discovery will run on the next request. + logger.debug( + "MCP OAuth '%s': pre-flight metadata discovery " + "failed (non-fatal): %s", + self._hermes_server_name, exc, + ) + + async def _prefetch_oauth_metadata(self) -> None: + """Fetch PRM + ASM from the well-known endpoints, cache on context. + + Mirrors the SDK's 401-branch discovery (oauth2.py ~line 511-551) + but runs synchronously before the first request instead of + inside the httpx auth_flow generator. Uses the SDK's own URL + builders and response handlers so we track whatever the SDK + version we're pinned to expects. + """ + import httpx # local import: httpx is an MCP SDK dependency + from mcp.client.auth.utils import ( + build_oauth_authorization_server_metadata_discovery_urls, + build_protected_resource_metadata_discovery_urls, + create_oauth_metadata_request, + handle_auth_metadata_response, + handle_protected_resource_response, + ) + + server_url = self.context.server_url + async with httpx.AsyncClient(timeout=10.0) as client: + # Step 1: PRM discovery to learn the authorization_server URL. + for url in build_protected_resource_metadata_discovery_urls( + None, server_url + ): + req = create_oauth_metadata_request(url) + try: + resp = await client.send(req) + except httpx.HTTPError as exc: + logger.debug( + "MCP OAuth '%s': PRM discovery to %s failed: %s", + self._hermes_server_name, url, exc, + ) + continue + prm = await handle_protected_resource_response(resp) + if prm: + self.context.protected_resource_metadata = prm + if prm.authorization_servers: + self.context.auth_server_url = str( + prm.authorization_servers[0] + ) + break + + # Step 2: ASM discovery against the auth_server_url (or + # server_url fallback for legacy providers). + for url in build_oauth_authorization_server_metadata_discovery_urls( + self.context.auth_server_url, server_url + ): + req = create_oauth_metadata_request(url) + try: + resp = await client.send(req) + except httpx.HTTPError as exc: + logger.debug( + "MCP OAuth '%s': ASM discovery to %s failed: %s", + self._hermes_server_name, url, exc, + ) + continue + ok, asm = await handle_auth_metadata_response(resp) + if not ok: + break + if asm: + self.context.oauth_metadata = asm + logger.debug( + "MCP OAuth '%s': pre-flight ASM discovered " + "token_endpoint=%s", + self._hermes_server_name, asm.token_endpoint, + ) + break + async def async_auth_flow(self, request): # type: ignore[override] # Pre-flow hook: ask the manager to refresh from disk if needed. # Any failure here is non-fatal — we just log and proceed with @@ -125,9 +250,28 @@ def _make_hermes_provider_class() -> Optional[type]: self._hermes_server_name, exc, ) - # Delegate to the SDK's auth flow - async for item in super().async_auth_flow(request): - yield item + # Manually bridge the bidirectional generator protocol. httpx's + # auth_flow driver (httpx._client._send_handling_auth) calls + # ``auth_flow.asend(response)`` to feed HTTP responses back into + # the generator. A naive wrapper using ``async for item in inner: + # yield item`` DISCARDS those .asend(response) values and resumes + # the inner generator with None, so the SDK's + # ``response = yield request`` branch in + # mcp/client/auth/oauth2.py sees response=None and crashes at + # ``if response.status_code == 401`` with AttributeError. + # + # The bridge below forwards each .asend() value into the inner + # generator via inner.asend(incoming), preserving the bidirectional + # contract. Regression from PR #11383 caught by + # tests/tools/test_mcp_oauth_bidirectional.py. + inner = super().async_auth_flow(request) + try: + outgoing = await inner.__anext__() + while True: + incoming = yield outgoing + outgoing = await inner.asend(incoming) + except StopAsyncIteration: + return return HermesMCPOAuthProvider