[verified] fix(mcp-oauth): seed token_expiry_time + pre-flight AS discovery on cold-load

PR #11383's consolidation fixed external-refresh reloading and 401 dedup
but left two latent bugs that surfaced on BetterStack and on any other
OAuth MCP server with a split-origin authorization server:

1. HermesTokenStorage persisted only a relative 'expires_in', which is
   meaningless after a process restart. The MCP SDK's OAuthContext
   does NOT seed token_expiry_time in _initialize, so is_token_valid()
   returned True for any reloaded token regardless of age. Expired
   tokens shipped to servers, and app-level auth failures (e.g.
   BetterStack's 'No teams found. Please check your authentication.')
   were invisible to the transport-layer 401 handler.

2. Even once preemptive refresh does fire, the SDK's _refresh_token
   falls back to {server_url}/token when oauth_metadata isn't cached.
   For providers whose AS is at a different origin (BetterStack:
   mcp.betterstack.com for MCP, betterstack.com/oauth/token for the
   token endpoint), that fallback 404s and drops into full browser
   re-auth on every process restart.
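
   The split-origin failure is easy to see in a toy sketch (hypothetical
   URLs mirroring the BetterStack shape; the SDK's real fallback logic
   differs in detail):

   ```python
   # Naive fallback: when authorization-server metadata isn't cached,
   # the token endpoint is guessed from the MCP server's own origin.
   server_url = "https://mcp.betterstack.com"
   fallback = server_url.rstrip("/") + "/token"

   # What RFC 8414 / RFC 9728 discovery actually yields when the AS
   # lives on a different origin:
   discovered = "https://betterstack.com/oauth/token"

   # The guess and the real endpoint disagree, so the refresh POST 404s.
   print(fallback == discovered)  # False
   ```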

Fix set:

- HermesTokenStorage.set_tokens persists an absolute wall-clock
  expires_at alongside the SDK's OAuthToken JSON (time.time() + TTL
  at write time).
- HermesTokenStorage.get_tokens reconstructs expires_in from
  max(expires_at - now, 0), clamping expired tokens to zero TTL.
  Legacy files without expires_at fall back to file-mtime as a
  best-effort wall-clock proxy, self-healing on the next set_tokens.
- HermesMCPOAuthProvider._initialize calls super(), then
  update_token_expiry on the reloaded tokens so token_expiry_time
  reflects actual remaining TTL. If tokens are loaded but
  oauth_metadata is missing, pre-flight PRM + ASM discovery runs
  via httpx.AsyncClient using the MCP SDK's own URL builders and
  response handlers (build_protected_resource_metadata_discovery_urls,
  handle_auth_metadata_response, etc.) so the SDK sees the correct
  token_endpoint before the first refresh attempt. Pre-flight is
  skipped when there are no stored tokens to keep fresh-install
  paths zero-cost.

Test coverage (tests/tools/test_mcp_oauth_cold_load_expiry.py):
- set_tokens persists absolute expires_at
- set_tokens skips expires_at when token has no expires_in
- get_tokens round-trips expires_at -> remaining expires_in
- expired tokens reload with expires_in=0
- legacy files without expires_at fall back to mtime proxy
- _initialize seeds token_expiry_time from stored tokens
- _initialize flags expired-on-disk tokens as is_token_valid=False
- _initialize pre-flights PRM + ASM discovery with mock transport
- _initialize skips pre-flight when no tokens are stored

Verified against BetterStack MCP:
  hermes mcp test betterstack -> Connected (2508ms), 83 tools
  mcp_betterstack_telemetry_list_teams_tool -> real team data, not
    'No teams found. Please check your authentication.'

Reference: mcp-oauth-token-diagnosis skill, Fix A.
Hermes Agent 2026-04-18 14:57:38 +10:00
parent 3eeab4bc06
commit cc06beaf13
3 changed files with 718 additions and 1 deletions


@@ -0,0 +1,546 @@
"""Tests for cold-load token expiry tracking in MCP OAuth.
PR #11383's consolidation fixed external-refresh reloading (mtime disk-watch)
and 401 dedup, but left two underlying latent bugs in place:
1. ``HermesTokenStorage.set_tokens`` persisted only relative ``expires_in``,
which is meaningless after a process restart.
2. The MCP SDK's ``OAuthContext._initialize`` loads ``current_tokens`` from
storage but does NOT call ``update_token_expiry``, so
``token_expiry_time`` stays None. ``is_token_valid()`` then returns True
for any loaded token regardless of actual age, and the SDK's preemptive
refresh branch at ``oauth2.py:491`` is never taken.
Consequence: a token that expired while the process was down ships to the
server with a stale Bearer header. The server's response is provider-specific
some return HTTP 401 (caught by the consolidation's 401 handler, which
surfaces a ``needs_reauth`` error), others return HTTP 200 with an
application-level auth failure in the body (e.g. BetterStack's "No teams
found. Please check your authentication."), which the consolidation cannot
detect.
These tests pin the contract for Fix A:
- ``set_tokens`` persists an absolute ``expires_at`` wall-clock timestamp.
- ``get_tokens`` reconstructs ``expires_in`` from ``expires_at - now`` so
the SDK's ``update_token_expiry`` computes the correct absolute expiry.
- ``HermesMCPOAuthProvider._initialize`` seeds ``context.token_expiry_time``
after loading, so ``is_token_valid()`` reports True only for tokens that
are actually still valid, and the SDK's preemptive refresh fires for
expired tokens with a live refresh_token.
Reference: Claude Code solves this via an ``OAuthTokens.expiresAt`` absolute
timestamp persisted alongside the access_token (``auth.ts:~180``).
"""
from __future__ import annotations

import asyncio
import json
import time

import pytest

pytest.importorskip("mcp.client.auth.oauth2", reason="MCP SDK 1.26.0+ required")
# ---------------------------------------------------------------------------
# HermesTokenStorage — absolute expiry persistence
# ---------------------------------------------------------------------------


class TestSetTokensAbsoluteExpiry:
    def test_set_tokens_persists_absolute_expires_at(self, tmp_path, monkeypatch):
        """Tokens round-tripped through disk must encode absolute expiry."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        from mcp.shared.auth import OAuthToken
        from tools.mcp_oauth import HermesTokenStorage

        storage = HermesTokenStorage("srv")
        before = time.time()
        asyncio.run(
            storage.set_tokens(
                OAuthToken(
                    access_token="a",
                    token_type="Bearer",
                    expires_in=3600,
                    refresh_token="r",
                )
            )
        )
        after = time.time()
        on_disk = json.loads(
            (tmp_path / "mcp-tokens" / "srv.json").read_text()
        )
        assert "expires_at" in on_disk, (
            "Fix A: set_tokens must record an absolute expires_at wall-clock "
            "timestamp alongside the SDK's serialized token so cold-loads "
            "can compute correct remaining TTL."
        )
        assert before + 3600 <= on_disk["expires_at"] <= after + 3600

    def test_set_tokens_without_expires_in_omits_expires_at(
        self, tmp_path, monkeypatch
    ):
        """Tokens without a TTL must not gain a fabricated expires_at."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        from mcp.shared.auth import OAuthToken
        from tools.mcp_oauth import HermesTokenStorage

        storage = HermesTokenStorage("srv")
        asyncio.run(
            storage.set_tokens(
                OAuthToken(
                    access_token="a",
                    token_type="Bearer",
                    refresh_token="r",
                )
            )
        )
        on_disk = json.loads(
            (tmp_path / "mcp-tokens" / "srv.json").read_text()
        )
        assert "expires_at" not in on_disk
class TestGetTokensReconstructsExpiresIn:
    def test_get_tokens_uses_expires_at_for_remaining_ttl(
        self, tmp_path, monkeypatch
    ):
        """Round-trip: expires_in on read must reflect time remaining."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        from mcp.shared.auth import OAuthToken
        from tools.mcp_oauth import HermesTokenStorage

        storage = HermesTokenStorage("srv")
        asyncio.run(
            storage.set_tokens(
                OAuthToken(
                    access_token="a",
                    token_type="Bearer",
                    expires_in=3600,
                    refresh_token="r",
                )
            )
        )
        # Wait briefly so the remaining TTL is measurably less than 3600.
        time.sleep(0.05)
        reloaded = asyncio.run(storage.get_tokens())
        assert reloaded is not None
        assert reloaded.expires_in is not None
        # Should be slightly less than 3600 after the 50ms sleep.
        assert 3500 < reloaded.expires_in <= 3600

    def test_get_tokens_returns_zero_ttl_for_expired_token(
        self, tmp_path, monkeypatch
    ):
        """An already-expired token reloaded from disk must report expires_in=0."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        from tools.mcp_oauth import HermesTokenStorage, _get_token_dir

        token_dir = _get_token_dir()
        token_dir.mkdir(parents=True, exist_ok=True)
        # Write an already-expired token file directly.
        (token_dir / "srv.json").write_text(
            json.dumps(
                {
                    "access_token": "a",
                    "token_type": "Bearer",
                    "expires_in": 3600,
                    "expires_at": time.time() - 60,  # expired 1 min ago
                    "refresh_token": "r",
                }
            )
        )
        storage = HermesTokenStorage("srv")
        reloaded = asyncio.run(storage.get_tokens())
        assert reloaded is not None
        assert reloaded.expires_in == 0, (
            "Expired token must reload with expires_in=0 so the SDK's "
            "is_token_valid() returns False and preemptive refresh fires."
        )

    def test_get_tokens_legacy_file_without_expires_at_is_loadable(
        self, tmp_path, monkeypatch
    ):
        """Existing on-disk files (pre-Fix-A) must still load without crashing.

        Pre-existing token files have ``expires_in`` but no ``expires_at``.
        Fix A falls back to the file's mtime as a best-effort wall-clock
        proxy: a file whose (mtime + expires_in) is in the past clamps
        expires_in to zero so the SDK refreshes on next request. A fresh
        legacy-format file (mtime = now) keeps most of its TTL.
        """
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        from tools.mcp_oauth import HermesTokenStorage, _get_token_dir

        token_dir = _get_token_dir()
        token_dir.mkdir(parents=True, exist_ok=True)
        # Legacy-shape file (no expires_at). Make it stale by backdating mtime
        # well past its nominal expires_in.
        legacy_path = token_dir / "srv.json"
        legacy_path.write_text(
            json.dumps(
                {
                    "access_token": "a",
                    "token_type": "Bearer",
                    "expires_in": 3600,
                    "refresh_token": "r",
                }
            )
        )
        stale_time = time.time() - 7200  # 2hr ago, exceeds 3600s TTL
        import os

        os.utime(legacy_path, (stale_time, stale_time))
        storage = HermesTokenStorage("srv")
        reloaded = asyncio.run(storage.get_tokens())
        assert reloaded is not None
        assert reloaded.expires_in == 0, (
            "Legacy file whose mtime + expires_in is in the past must report "
            "expires_in=0 so the SDK refreshes on next request."
        )
# ---------------------------------------------------------------------------
# HermesMCPOAuthProvider._initialize — seed token_expiry_time
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_initialize_seeds_token_expiry_time_from_stored_tokens(
    tmp_path, monkeypatch
):
    """Cold-load must populate context.token_expiry_time.

    The SDK's base ``_initialize`` loads current_tokens but doesn't seed
    token_expiry_time. Our subclass must do it so ``is_token_valid()``
    reports correctly and the preemptive-refresh path fires when needed.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
    from pydantic import AnyUrl
    from tools.mcp_oauth import HermesTokenStorage
    from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests

    assert _HERMES_PROVIDER_CLS is not None
    reset_manager_for_tests()
    storage = HermesTokenStorage("srv")
    await storage.set_tokens(
        OAuthToken(
            access_token="a",
            token_type="Bearer",
            expires_in=7200,
            refresh_token="r",
        )
    )
    await storage.set_client_info(
        OAuthClientInformationFull(
            client_id="test-client",
            redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            token_endpoint_auth_method="none",
        )
    )
    from mcp.shared.auth import OAuthClientMetadata

    metadata = OAuthClientMetadata(
        redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
        client_name="Hermes Agent",
    )
    provider = _HERMES_PROVIDER_CLS(
        server_name="srv",
        server_url="https://example.com/mcp",
        client_metadata=metadata,
        storage=storage,
        redirect_handler=_noop_redirect,
        callback_handler=_noop_callback,
    )
    await provider._initialize()
    assert provider.context.token_expiry_time is not None, (
        "Fix A: _initialize must seed context.token_expiry_time so "
        "is_token_valid() correctly reports expiry on cold-load."
    )
    # Should be ~7200s in the future (fresh write).
    assert provider.context.token_expiry_time > time.time() + 7000
    assert provider.context.token_expiry_time <= time.time() + 7200 + 5
@pytest.mark.asyncio
async def test_initialize_flags_expired_token_as_invalid(tmp_path, monkeypatch):
    """After _initialize, an expired-on-disk token must report is_token_valid=False.

    This is the end-to-end assertion: cold-load an expired token, verify the
    SDK's own ``is_token_valid()`` now returns False (the consequence of
    seeding token_expiry_time correctly), so the SDK's ``async_auth_flow``
    will take the ``can_refresh_token()`` branch on the next request and
    silently refresh instead of sending the stale Bearer.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
    from pydantic import AnyUrl
    from tools.mcp_oauth import HermesTokenStorage, _get_token_dir
    from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests

    assert _HERMES_PROVIDER_CLS is not None
    reset_manager_for_tests()
    # Write an already-expired token directly so we control the wall-clock.
    token_dir = _get_token_dir()
    token_dir.mkdir(parents=True, exist_ok=True)
    (token_dir / "srv.json").write_text(
        json.dumps(
            {
                "access_token": "stale",
                "token_type": "Bearer",
                "expires_in": 3600,
                "expires_at": time.time() - 60,
                "refresh_token": "fresh",
            }
        )
    )
    storage = HermesTokenStorage("srv")
    await storage.set_client_info(
        OAuthClientInformationFull(
            client_id="test-client",
            redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            token_endpoint_auth_method="none",
        )
    )
    metadata = OAuthClientMetadata(
        redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
        client_name="Hermes Agent",
    )
    provider = _HERMES_PROVIDER_CLS(
        server_name="srv",
        server_url="https://example.com/mcp",
        client_metadata=metadata,
        storage=storage,
        redirect_handler=_noop_redirect,
        callback_handler=_noop_callback,
    )
    await provider._initialize()
    assert provider.context.is_token_valid() is False, (
        "After _initialize with an expired-on-disk token, is_token_valid() "
        "must return False so the SDK's async_auth_flow takes the "
        "preemptive refresh path."
    )
    assert provider.context.can_refresh_token() is True, (
        "Refresh should remain possible because refresh_token + client_info "
        "are both present."
    )
async def _noop_redirect(_url: str) -> None:
    return None


async def _noop_callback() -> tuple[str, str | None]:
    raise AssertionError("callback handler should not be invoked in these tests")
# ---------------------------------------------------------------------------
# Pre-flight OAuth metadata discovery
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_initialize_prefetches_oauth_metadata_when_missing(
    tmp_path, monkeypatch
):
    """Cold-load must pre-flight PRM + ASM discovery so ``_refresh_token``
    has the correct ``token_endpoint`` before the first refresh attempt.

    Without this, the SDK's ``_refresh_token`` falls back to
    ``{server_url}/token`` which is wrong for providers whose AS is at
    a different origin. BetterStack specifically: MCP at
    ``mcp.betterstack.com`` but token_endpoint at
    ``betterstack.com/oauth/token``. Without pre-flight the refresh 404s
    and we drop into full browser re-auth visible to the user as an
    unwanted OAuth browser prompt every time the process restarts.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    import httpx
    from mcp.shared.auth import (
        OAuthClientInformationFull,
        OAuthClientMetadata,
        OAuthToken,
    )
    from pydantic import AnyUrl
    from tools.mcp_oauth import HermesTokenStorage
    from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests

    assert _HERMES_PROVIDER_CLS is not None
    reset_manager_for_tests()
    storage = HermesTokenStorage("srv")
    await storage.set_tokens(
        OAuthToken(
            access_token="a",
            token_type="Bearer",
            expires_in=3600,
            refresh_token="r",
        )
    )
    await storage.set_client_info(
        OAuthClientInformationFull(
            client_id="test-client",
            redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            token_endpoint_auth_method="none",
        )
    )

    # Route the AsyncClient used inside _prefetch_oauth_metadata through a
    # MockTransport that mimics BetterStack's split-origin discovery:
    #   PRM at mcp.example.com/.well-known/oauth-protected-resource
    #     -> points to auth.example.com
    #   ASM at auth.example.com/.well-known/oauth-authorization-server
    #     -> token_endpoint at auth.example.com/oauth/token
    def mock_handler(request: httpx.Request) -> httpx.Response:
        url = str(request.url)
        if url.endswith("/.well-known/oauth-protected-resource"):
            return httpx.Response(
                200,
                json={
                    "resource": "https://mcp.example.com",
                    "authorization_servers": ["https://auth.example.com"],
                    "scopes_supported": ["read", "write"],
                    "bearer_methods_supported": ["header"],
                },
            )
        if url.endswith("/.well-known/oauth-authorization-server"):
            return httpx.Response(
                200,
                json={
                    "issuer": "https://auth.example.com",
                    "authorization_endpoint": "https://auth.example.com/oauth/authorize",
                    "token_endpoint": "https://auth.example.com/oauth/token",
                    "registration_endpoint": "https://auth.example.com/oauth/register",
                    "response_types_supported": ["code"],
                    "grant_types_supported": ["authorization_code", "refresh_token"],
                    "code_challenge_methods_supported": ["S256"],
                    "token_endpoint_auth_methods_supported": ["none"],
                    "scopes_supported": ["read", "write"],
                },
            )
        return httpx.Response(404)

    transport = httpx.MockTransport(mock_handler)
    # Patch the AsyncClient constructor used by _prefetch_oauth_metadata so
    # it uses our mock transport instead of the real network.
    import httpx as real_httpx

    original_async_client = real_httpx.AsyncClient

    def patched_async_client(*args, **kwargs):
        kwargs["transport"] = transport
        return original_async_client(*args, **kwargs)

    monkeypatch.setattr(real_httpx, "AsyncClient", patched_async_client)
    metadata = OAuthClientMetadata(
        redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
        client_name="Hermes Agent",
    )
    provider = _HERMES_PROVIDER_CLS(
        server_name="srv",
        server_url="https://mcp.example.com",
        client_metadata=metadata,
        storage=storage,
        redirect_handler=_noop_redirect,
        callback_handler=_noop_callback,
    )
    await provider._initialize()
    assert provider.context.protected_resource_metadata is not None, (
        "Pre-flight must cache PRM for the SDK to reference later."
    )
    assert provider.context.oauth_metadata is not None, (
        "Pre-flight must cache ASM so _refresh_token builds the correct "
        "token_endpoint URL."
    )
    assert str(provider.context.oauth_metadata.token_endpoint) == (
        "https://auth.example.com/oauth/token"
    )
@pytest.mark.asyncio
async def test_initialize_skips_prefetch_when_no_tokens(tmp_path, monkeypatch):
    """Pre-flight must not run when there are no stored tokens yet.

    Without this guard, every fresh-install ``_initialize`` would do two
    extra network roundtrips that gain nothing (the SDK's 401-branch
    discovery will run on the first real request anyway).
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    import httpx
    from mcp.shared.auth import OAuthClientMetadata
    from pydantic import AnyUrl
    from tools.mcp_oauth import HermesTokenStorage
    from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests

    assert _HERMES_PROVIDER_CLS is not None
    reset_manager_for_tests()
    calls: list[str] = []

    def mock_handler(request: httpx.Request) -> httpx.Response:
        calls.append(str(request.url))
        return httpx.Response(404)

    transport = httpx.MockTransport(mock_handler)
    import httpx as real_httpx

    original = real_httpx.AsyncClient

    def patched(*args, **kwargs):
        kwargs["transport"] = transport
        return original(*args, **kwargs)

    monkeypatch.setattr(real_httpx, "AsyncClient", patched)
    storage = HermesTokenStorage("srv")  # empty — no tokens on disk
    metadata = OAuthClientMetadata(
        redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
        client_name="Hermes Agent",
    )
    provider = _HERMES_PROVIDER_CLS(
        server_name="srv",
        server_url="https://mcp.example.com",
        client_metadata=metadata,
        storage=storage,
        redirect_handler=_noop_redirect,
        callback_handler=_noop_callback,
    )
    await provider._initialize()
    assert calls == [], (
        f"Pre-flight must not fire when no tokens are stored, but got {calls}"
    )


@@ -40,6 +40,7 @@ import re
import socket
import sys
import threading
import time
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
@@ -196,6 +197,35 @@ class HermesTokenStorage:
        data = _read_json(self._tokens_path())
        if data is None:
            return None
        # Hermes records an absolute wall-clock ``expires_at`` alongside the
        # SDK's serialized token (see ``set_tokens``). On read we rewrite
        # ``expires_in`` to the remaining seconds so the SDK's downstream
        # ``update_token_expiry`` computes the correct absolute time and
        # ``is_token_valid()`` correctly reports False for tokens that
        # expired while the process was down.
        #
        # Legacy token files (pre-Fix-A) have ``expires_in`` but no
        # ``expires_at``. We fall back to the file's mtime as a best-effort
        # wall-clock proxy for when the token was written: if (mtime +
        # expires_in) is in the past, clamp ``expires_in`` to zero so the
        # SDK refreshes before the first request. This self-heals one-time
        # on the next successful ``set_tokens``, which writes the new
        # ``expires_at`` field. The stored ``expires_at`` is stripped before
        # model_validate because it's not part of the SDK's OAuthToken schema.
        absolute_expiry = data.pop("expires_at", None)
        if absolute_expiry is not None:
            data["expires_in"] = int(max(absolute_expiry - time.time(), 0))
        elif data.get("expires_in") is not None:
            try:
                file_mtime = self._tokens_path().stat().st_mtime
            except OSError:
                file_mtime = None
            if file_mtime is not None:
                try:
                    implied_expiry = file_mtime + int(data["expires_in"])
                    data["expires_in"] = int(max(implied_expiry - time.time(), 0))
                except (TypeError, ValueError):
                    pass
        try:
            return OAuthToken.model_validate(data)
        except (ValueError, TypeError, KeyError) as exc:
@@ -203,7 +233,23 @@ class HermesTokenStorage:
            return None
    async def set_tokens(self, tokens: "OAuthToken") -> None:
        payload = tokens.model_dump(exclude_none=True)
        # Persist an absolute ``expires_at`` so a process restart can
        # reconstruct the correct remaining TTL. Without this the MCP SDK's
        # ``_initialize`` reloads a relative ``expires_in`` which has no
        # wall-clock reference, leaving ``context.token_expiry_time=None``
        # and ``is_token_valid()`` falsely reporting True. See Fix A in
        # ``mcp-oauth-token-diagnosis`` skill + Claude Code's
        # ``OAuthTokens.expiresAt`` persistence (auth.ts ~180).
        expires_in = payload.get("expires_in")
        if expires_in is not None:
            try:
                payload["expires_at"] = time.time() + int(expires_in)
            except (TypeError, ValueError):
                # Mock tokens or unusual shapes: skip the expires_at write
                # rather than fail persistence.
                pass
        _write_json(self._tokens_path(), payload)
        logger.debug("OAuth tokens saved for %s", self._server_name)

    # -- client info -------------------------------------------------------


@@ -111,6 +111,131 @@ def _make_hermes_provider_class() -> Optional[type]:
            super().__init__(*args, **kwargs)
            self._hermes_server_name = server_name
        async def _initialize(self) -> None:
            """Load stored tokens + client info AND seed token_expiry_time.

            Also eagerly fetches OAuth authorization-server metadata (PRM +
            ASM) when we have stored tokens but no cached metadata, so the
            SDK's ``_refresh_token`` can build the correct token_endpoint
            URL on the preemptive-refresh path. Without this, the SDK
            falls back to ``{mcp_server_url}/token`` (wrong for providers
            whose AS is a different origin: BetterStack's MCP lives at
            ``https://mcp.betterstack.com`` but its token endpoint is at
            ``https://betterstack.com/oauth/token``), the refresh 404s, and
            we drop through to full browser reauth.

            The SDK's base ``_initialize`` populates ``current_tokens`` but
            does NOT call ``update_token_expiry``, so ``token_expiry_time``
            stays ``None`` and ``is_token_valid()`` returns True for any
            loaded token regardless of actual age. After a process restart
            this ships stale Bearer tokens to the server; some providers
            return HTTP 401 (caught by the 401 handler), others return 200
            with an app-level auth error (invisible to the transport layer,
            e.g. BetterStack returning "No teams found. Please check your
            authentication.").

            Seeding ``token_expiry_time`` from the reloaded token fixes that:
            ``is_token_valid()`` correctly reports False for expired tokens,
            ``async_auth_flow`` takes the ``can_refresh_token()`` branch,
            and the SDK quietly refreshes before the first real request.

            Paired with :class:`HermesTokenStorage` persisting an absolute
            ``expires_at`` timestamp (``mcp_oauth.py:set_tokens``) so the
            remaining TTL we compute here reflects real wall-clock age.
            """
            await super()._initialize()
            tokens = self.context.current_tokens
            if tokens is not None and tokens.expires_in is not None:
                self.context.update_token_expiry(tokens)
            # Pre-flight OAuth AS discovery so ``_refresh_token`` has a
            # correct ``token_endpoint`` before the first refresh attempt.
            # Only runs when we have tokens on cold-load but no cached
            # metadata — i.e. the exact scenario where the SDK's built-in
            # 401-branch discovery hasn't had a chance to run yet.
            if (
                tokens is not None
                and self.context.oauth_metadata is None
            ):
                try:
                    await self._prefetch_oauth_metadata()
                except Exception as exc:  # pragma: no cover — defensive
                    # Non-fatal: if discovery fails, the SDK's normal 401-
                    # branch discovery will run on the next request.
                    logger.debug(
                        "MCP OAuth '%s': pre-flight metadata discovery "
                        "failed (non-fatal): %s",
                        self._hermes_server_name, exc,
                    )
        async def _prefetch_oauth_metadata(self) -> None:
            """Fetch PRM + ASM from the well-known endpoints, cache on context.

            Mirrors the SDK's 401-branch discovery (oauth2.py ~line 511-551)
            but runs up-front before the first request instead of inside
            the httpx auth_flow generator. Uses the SDK's own URL builders
            and response handlers so we track whatever the SDK version
            we're pinned to expects.
            """
            import httpx  # local import: httpx is an MCP SDK dependency
            from mcp.client.auth.utils import (
                build_oauth_authorization_server_metadata_discovery_urls,
                build_protected_resource_metadata_discovery_urls,
                create_oauth_metadata_request,
                handle_auth_metadata_response,
                handle_protected_resource_response,
            )

            server_url = self.context.server_url
            async with httpx.AsyncClient(timeout=10.0) as client:
                # Step 1: PRM discovery to learn the authorization_server URL.
                for url in build_protected_resource_metadata_discovery_urls(
                    None, server_url
                ):
                    req = create_oauth_metadata_request(url)
                    try:
                        resp = await client.send(req)
                    except httpx.HTTPError as exc:
                        logger.debug(
                            "MCP OAuth '%s': PRM discovery to %s failed: %s",
                            self._hermes_server_name, url, exc,
                        )
                        continue
                    prm = await handle_protected_resource_response(resp)
                    if prm:
                        self.context.protected_resource_metadata = prm
                        if prm.authorization_servers:
                            self.context.auth_server_url = str(
                                prm.authorization_servers[0]
                            )
                        break
                # Step 2: ASM discovery against the auth_server_url (or
                # server_url fallback for legacy providers).
                for url in build_oauth_authorization_server_metadata_discovery_urls(
                    self.context.auth_server_url, server_url
                ):
                    req = create_oauth_metadata_request(url)
                    try:
                        resp = await client.send(req)
                    except httpx.HTTPError as exc:
                        logger.debug(
                            "MCP OAuth '%s': ASM discovery to %s failed: %s",
                            self._hermes_server_name, url, exc,
                        )
                        continue
                    ok, asm = await handle_auth_metadata_response(resp)
                    if not ok:
                        break
                    if asm:
                        self.context.oauth_metadata = asm
                        logger.debug(
                            "MCP OAuth '%s': pre-flight ASM discovered "
                            "token_endpoint=%s",
                            self._hermes_server_name, asm.token_endpoint,
                        )
                        break
        async def async_auth_flow(self, request):  # type: ignore[override]
            # Pre-flow hook: ask the manager to refresh from disk if needed.
            # Any failure here is non-fatal — we just log and proceed with