[verified] fix(mcp-oauth): seed token_expiry_time + pre-flight AS discovery on cold-load

PR #11383's consolidation fixed external-refresh reloading and 401 dedup
but left two latent bugs that surfaced on BetterStack and any other OAuth
MCP with a split-origin authorization server:

1. HermesTokenStorage persisted only a relative 'expires_in', which is
   meaningless after a process restart. The MCP SDK's OAuthContext
   does NOT seed token_expiry_time in _initialize, so is_token_valid()
   returned True for any reloaded token regardless of age. Expired
   tokens shipped to servers, and app-level auth failures (e.g.
   BetterStack's 'No teams found. Please check your authentication.')
   were invisible to the transport-layer 401 handler.

2. Even once preemptive refresh did fire, the SDK's _refresh_token
   falls back to {server_url}/token when oauth_metadata isn't cached.
   For providers whose AS is at a different origin (BetterStack:
   mcp.betterstack.com for MCP, betterstack.com/oauth/token for the
   token endpoint), that fallback 404s and drops into full browser
   re-auth on every process restart.

Fix set:

- HermesTokenStorage.set_tokens persists an absolute wall-clock
  expires_at alongside the SDK's OAuthToken JSON (time.time() + TTL
  at write time).
- HermesTokenStorage.get_tokens reconstructs expires_in from
  max(expires_at - now, 0), clamping expired tokens to zero TTL.
  Legacy files without expires_at fall back to file-mtime as a
  best-effort wall-clock proxy, self-healing on the next set_tokens.
- HermesMCPOAuthProvider._initialize calls super(), then
  update_token_expiry on the reloaded tokens so token_expiry_time
  reflects actual remaining TTL. If tokens are loaded but
  oauth_metadata is missing, pre-flight PRM + ASM discovery runs
  via httpx.AsyncClient using the MCP SDK's own URL builders and
  response handlers (build_protected_resource_metadata_discovery_urls,
  handle_auth_metadata_response, etc.) so the SDK sees the correct
  token_endpoint before the first refresh attempt. Pre-flight is
  skipped when there are no stored tokens to keep fresh-install
  paths zero-cost.

Test coverage (tests/tools/test_mcp_oauth_cold_load_expiry.py):
- set_tokens persists absolute expires_at
- set_tokens skips expires_at when token has no expires_in
- get_tokens round-trips expires_at -> remaining expires_in
- expired tokens reload with expires_in=0
- legacy files without expires_at fall back to mtime proxy
- _initialize seeds token_expiry_time from stored tokens
- _initialize flags expired-on-disk tokens as is_token_valid=False
- _initialize pre-flights PRM + ASM discovery with mock transport
- _initialize skips pre-flight when no tokens are stored

Verified against BetterStack MCP:
  hermes mcp test betterstack -> Connected (2508ms), 83 tools
  mcp_betterstack_telemetry_list_teams_tool -> real team data, not
    'No teams found. Please check your authentication.'

Reference: mcp-oauth-token-diagnosis skill, Fix A.
This commit is contained in:
Hermes Agent 2026-04-18 14:57:38 +10:00
parent 3eeab4bc06
commit cc06beaf13
3 changed files with 718 additions and 1 deletions

View file

@ -0,0 +1,546 @@
"""Tests for cold-load token expiry tracking in MCP OAuth.
PR #11383's consolidation fixed external-refresh reloading (mtime disk-watch)
and 401 dedup, but left two underlying latent bugs in place:
1. ``HermesTokenStorage.set_tokens`` persisted only relative ``expires_in``,
which is meaningless after a process restart.
2. The MCP SDK's ``OAuthContext._initialize`` loads ``current_tokens`` from
storage but does NOT call ``update_token_expiry``, so
``token_expiry_time`` stays None. ``is_token_valid()`` then returns True
for any loaded token regardless of actual age, and the SDK's preemptive
refresh branch at ``oauth2.py:491`` is never taken.
Consequence: a token that expired while the process was down ships to the
server with a stale Bearer header. The server's response is provider-specific
some return HTTP 401 (caught by the consolidation's 401 handler, which
surfaces a ``needs_reauth`` error), others return HTTP 200 with an
application-level auth failure in the body (e.g. BetterStack's "No teams
found. Please check your authentication."), which the consolidation cannot
detect.
These tests pin the contract for Fix A:
- ``set_tokens`` persists an absolute ``expires_at`` wall-clock timestamp.
- ``get_tokens`` reconstructs ``expires_in`` from ``expires_at - now`` so
the SDK's ``update_token_expiry`` computes the correct absolute expiry.
- ``HermesMCPOAuthProvider._initialize`` seeds ``context.token_expiry_time``
after loading, so ``is_token_valid()`` reports True only for tokens that
are actually still valid, and the SDK's preemptive refresh fires for
expired tokens with a live refresh_token.
Reference: Claude Code solves this via an ``OAuthTokens.expiresAt`` absolute
timestamp persisted alongside the access_token (``auth.ts:~180``).
"""
from __future__ import annotations
import asyncio
import json
import time
import pytest
pytest.importorskip("mcp.client.auth.oauth2", reason="MCP SDK 1.26.0+ required")
# ---------------------------------------------------------------------------
# HermesTokenStorage — absolute expiry persistence
# ---------------------------------------------------------------------------
class TestSetTokensAbsoluteExpiry:
def test_set_tokens_persists_absolute_expires_at(self, tmp_path, monkeypatch):
"""Tokens round-tripped through disk must encode absolute expiry."""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from mcp.shared.auth import OAuthToken
from tools.mcp_oauth import HermesTokenStorage
storage = HermesTokenStorage("srv")
before = time.time()
asyncio.run(
storage.set_tokens(
OAuthToken(
access_token="a",
token_type="Bearer",
expires_in=3600,
refresh_token="r",
)
)
)
after = time.time()
on_disk = json.loads(
(tmp_path / "mcp-tokens" / "srv.json").read_text()
)
assert "expires_at" in on_disk, (
"Fix A: set_tokens must record an absolute expires_at wall-clock "
"timestamp alongside the SDK's serialized token so cold-loads "
"can compute correct remaining TTL."
)
assert before + 3600 <= on_disk["expires_at"] <= after + 3600
def test_set_tokens_without_expires_in_omits_expires_at(
self, tmp_path, monkeypatch
):
"""Tokens without a TTL must not gain a fabricated expires_at."""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from mcp.shared.auth import OAuthToken
from tools.mcp_oauth import HermesTokenStorage
storage = HermesTokenStorage("srv")
asyncio.run(
storage.set_tokens(
OAuthToken(
access_token="a",
token_type="Bearer",
refresh_token="r",
)
)
)
on_disk = json.loads(
(tmp_path / "mcp-tokens" / "srv.json").read_text()
)
assert "expires_at" not in on_disk
class TestGetTokensReconstructsExpiresIn:
def test_get_tokens_uses_expires_at_for_remaining_ttl(
self, tmp_path, monkeypatch
):
"""Round-trip: expires_in on read must reflect time remaining."""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from mcp.shared.auth import OAuthToken
from tools.mcp_oauth import HermesTokenStorage
storage = HermesTokenStorage("srv")
asyncio.run(
storage.set_tokens(
OAuthToken(
access_token="a",
token_type="Bearer",
expires_in=3600,
refresh_token="r",
)
)
)
# Wait briefly so the remaining TTL is measurably less than 3600.
time.sleep(0.05)
reloaded = asyncio.run(storage.get_tokens())
assert reloaded is not None
assert reloaded.expires_in is not None
# Should be slightly less than 3600 after the 50ms sleep.
assert 3500 < reloaded.expires_in <= 3600
def test_get_tokens_returns_zero_ttl_for_expired_token(
self, tmp_path, monkeypatch
):
"""An already-expired token reloaded from disk must report expires_in=0."""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from tools.mcp_oauth import HermesTokenStorage, _get_token_dir
token_dir = _get_token_dir()
token_dir.mkdir(parents=True, exist_ok=True)
# Write an already-expired token file directly.
(token_dir / "srv.json").write_text(
json.dumps(
{
"access_token": "a",
"token_type": "Bearer",
"expires_in": 3600,
"expires_at": time.time() - 60, # expired 1 min ago
"refresh_token": "r",
}
)
)
storage = HermesTokenStorage("srv")
reloaded = asyncio.run(storage.get_tokens())
assert reloaded is not None
assert reloaded.expires_in == 0, (
"Expired token must reload with expires_in=0 so the SDK's "
"is_token_valid() returns False and preemptive refresh fires."
)
def test_get_tokens_legacy_file_without_expires_at_is_loadable(
self, tmp_path, monkeypatch
):
"""Existing on-disk files (pre-Fix-A) must still load without crashing.
Pre-existing token files have ``expires_in`` but no ``expires_at``.
Fix A falls back to the file's mtime as a best-effort wall-clock
proxy: a file whose (mtime + expires_in) is in the past clamps
expires_in to zero so the SDK refreshes on next request. A fresh
legacy-format file (mtime = now) keeps most of its TTL.
"""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from tools.mcp_oauth import HermesTokenStorage, _get_token_dir
token_dir = _get_token_dir()
token_dir.mkdir(parents=True, exist_ok=True)
# Legacy-shape file (no expires_at). Make it stale by backdating mtime
# well past its nominal expires_in.
legacy_path = token_dir / "srv.json"
legacy_path.write_text(
json.dumps(
{
"access_token": "a",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "r",
}
)
)
stale_time = time.time() - 7200 # 2hr ago, exceeds 3600s TTL
import os
os.utime(legacy_path, (stale_time, stale_time))
storage = HermesTokenStorage("srv")
reloaded = asyncio.run(storage.get_tokens())
assert reloaded is not None
assert reloaded.expires_in == 0, (
"Legacy file whose mtime + expires_in is in the past must report "
"expires_in=0 so the SDK refreshes on next request."
)
# ---------------------------------------------------------------------------
# HermesMCPOAuthProvider._initialize — seed token_expiry_time
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_initialize_seeds_token_expiry_time_from_stored_tokens(
tmp_path, monkeypatch
):
"""Cold-load must populate context.token_expiry_time.
The SDK's base ``_initialize`` loads current_tokens but doesn't seed
token_expiry_time. Our subclass must do it so ``is_token_valid()``
reports correctly and the preemptive-refresh path fires when needed.
"""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
from pydantic import AnyUrl
from tools.mcp_oauth import HermesTokenStorage
from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests
assert _HERMES_PROVIDER_CLS is not None
reset_manager_for_tests()
storage = HermesTokenStorage("srv")
await storage.set_tokens(
OAuthToken(
access_token="a",
token_type="Bearer",
expires_in=7200,
refresh_token="r",
)
)
await storage.set_client_info(
OAuthClientInformationFull(
client_id="test-client",
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
token_endpoint_auth_method="none",
)
)
from mcp.shared.auth import OAuthClientMetadata
metadata = OAuthClientMetadata(
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
client_name="Hermes Agent",
)
provider = _HERMES_PROVIDER_CLS(
server_name="srv",
server_url="https://example.com/mcp",
client_metadata=metadata,
storage=storage,
redirect_handler=_noop_redirect,
callback_handler=_noop_callback,
)
await provider._initialize()
assert provider.context.token_expiry_time is not None, (
"Fix A: _initialize must seed context.token_expiry_time so "
"is_token_valid() correctly reports expiry on cold-load."
)
# Should be ~7200s in the future (fresh write).
assert provider.context.token_expiry_time > time.time() + 7000
assert provider.context.token_expiry_time <= time.time() + 7200 + 5
@pytest.mark.asyncio
async def test_initialize_flags_expired_token_as_invalid(tmp_path, monkeypatch):
"""After _initialize, an expired-on-disk token must report is_token_valid=False.
This is the end-to-end assertion: cold-load an expired token, verify the
SDK's own ``is_token_valid()`` now returns False (the consequence of
seeding token_expiry_time correctly), so the SDK's ``async_auth_flow``
will take the ``can_refresh_token()`` branch on the next request and
silently refresh instead of sending the stale Bearer.
"""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
from pydantic import AnyUrl
from tools.mcp_oauth import HermesTokenStorage, _get_token_dir
from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests
assert _HERMES_PROVIDER_CLS is not None
reset_manager_for_tests()
# Write an already-expired token directly so we control the wall-clock.
token_dir = _get_token_dir()
token_dir.mkdir(parents=True, exist_ok=True)
(token_dir / "srv.json").write_text(
json.dumps(
{
"access_token": "stale",
"token_type": "Bearer",
"expires_in": 3600,
"expires_at": time.time() - 60,
"refresh_token": "fresh",
}
)
)
storage = HermesTokenStorage("srv")
await storage.set_client_info(
OAuthClientInformationFull(
client_id="test-client",
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
token_endpoint_auth_method="none",
)
)
metadata = OAuthClientMetadata(
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
client_name="Hermes Agent",
)
provider = _HERMES_PROVIDER_CLS(
server_name="srv",
server_url="https://example.com/mcp",
client_metadata=metadata,
storage=storage,
redirect_handler=_noop_redirect,
callback_handler=_noop_callback,
)
await provider._initialize()
assert provider.context.is_token_valid() is False, (
"After _initialize with an expired-on-disk token, is_token_valid() "
"must return False so the SDK's async_auth_flow takes the "
"preemptive refresh path."
)
assert provider.context.can_refresh_token() is True, (
"Refresh should remain possible because refresh_token + client_info "
"are both present."
)
async def _noop_redirect(_url: str) -> None:
return None
async def _noop_callback() -> tuple[str, str | None]:
raise AssertionError("callback handler should not be invoked in these tests")
# ---------------------------------------------------------------------------
# Pre-flight OAuth metadata discovery
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_initialize_prefetches_oauth_metadata_when_missing(
tmp_path, monkeypatch
):
"""Cold-load must pre-flight PRM + ASM discovery so ``_refresh_token``
has the correct ``token_endpoint`` before the first refresh attempt.
Without this, the SDK's ``_refresh_token`` falls back to
``{server_url}/token`` which is wrong for providers whose AS is at
a different origin. BetterStack specifically: MCP at
``mcp.betterstack.com`` but token_endpoint at
``betterstack.com/oauth/token``. Without pre-flight the refresh 404s
and we drop into full browser re-auth visible to the user as an
unwanted OAuth browser prompt every time the process restarts.
"""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
import httpx
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthToken,
)
from pydantic import AnyUrl
from tools.mcp_oauth import HermesTokenStorage
from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests
assert _HERMES_PROVIDER_CLS is not None
reset_manager_for_tests()
storage = HermesTokenStorage("srv")
await storage.set_tokens(
OAuthToken(
access_token="a",
token_type="Bearer",
expires_in=3600,
refresh_token="r",
)
)
await storage.set_client_info(
OAuthClientInformationFull(
client_id="test-client",
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
token_endpoint_auth_method="none",
)
)
# Route the AsyncClient used inside _prefetch_oauth_metadata through a
# MockTransport that mimics BetterStack's split-origin discovery:
# PRM at mcp.example.com/.well-known/oauth-protected-resource -> points to auth.example.com
# ASM at auth.example.com/.well-known/oauth-authorization-server -> token_endpoint at auth.example.com/oauth/token
def mock_handler(request: httpx.Request) -> httpx.Response:
url = str(request.url)
if url.endswith("/.well-known/oauth-protected-resource"):
return httpx.Response(
200,
json={
"resource": "https://mcp.example.com",
"authorization_servers": ["https://auth.example.com"],
"scopes_supported": ["read", "write"],
"bearer_methods_supported": ["header"],
},
)
if url.endswith("/.well-known/oauth-authorization-server"):
return httpx.Response(
200,
json={
"issuer": "https://auth.example.com",
"authorization_endpoint": "https://auth.example.com/oauth/authorize",
"token_endpoint": "https://auth.example.com/oauth/token",
"registration_endpoint": "https://auth.example.com/oauth/register",
"response_types_supported": ["code"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"code_challenge_methods_supported": ["S256"],
"token_endpoint_auth_methods_supported": ["none"],
"scopes_supported": ["read", "write"],
},
)
return httpx.Response(404)
transport = httpx.MockTransport(mock_handler)
# Patch the AsyncClient constructor used by _prefetch_oauth_metadata so
# it uses our mock transport instead of the real network.
import httpx as real_httpx
original_async_client = real_httpx.AsyncClient
def patched_async_client(*args, **kwargs):
kwargs["transport"] = transport
return original_async_client(*args, **kwargs)
monkeypatch.setattr(real_httpx, "AsyncClient", patched_async_client)
metadata = OAuthClientMetadata(
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
client_name="Hermes Agent",
)
provider = _HERMES_PROVIDER_CLS(
server_name="srv",
server_url="https://mcp.example.com",
client_metadata=metadata,
storage=storage,
redirect_handler=_noop_redirect,
callback_handler=_noop_callback,
)
await provider._initialize()
assert provider.context.protected_resource_metadata is not None, (
"Pre-flight must cache PRM for the SDK to reference later."
)
assert provider.context.oauth_metadata is not None, (
"Pre-flight must cache ASM so _refresh_token builds the correct "
"token_endpoint URL."
)
assert str(provider.context.oauth_metadata.token_endpoint) == (
"https://auth.example.com/oauth/token"
)
@pytest.mark.asyncio
async def test_initialize_skips_prefetch_when_no_tokens(tmp_path, monkeypatch):
"""Pre-flight must not run when there are no stored tokens yet.
Without this guard, every fresh-install ``_initialize`` would do two
extra network roundtrips that gain nothing (the SDK's 401-branch
discovery will run on the first real request anyway).
"""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
import httpx
from mcp.shared.auth import OAuthClientMetadata
from pydantic import AnyUrl
from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests
from tools.mcp_oauth import HermesTokenStorage
assert _HERMES_PROVIDER_CLS is not None
reset_manager_for_tests()
calls: list[str] = []
def mock_handler(request: httpx.Request) -> httpx.Response:
calls.append(str(request.url))
return httpx.Response(404)
transport = httpx.MockTransport(mock_handler)
import httpx as real_httpx
original = real_httpx.AsyncClient
def patched(*args, **kwargs):
kwargs["transport"] = transport
return original(*args, **kwargs)
monkeypatch.setattr(real_httpx, "AsyncClient", patched)
storage = HermesTokenStorage("srv") # empty — no tokens on disk
metadata = OAuthClientMetadata(
redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
client_name="Hermes Agent",
)
provider = _HERMES_PROVIDER_CLS(
server_name="srv",
server_url="https://mcp.example.com",
client_metadata=metadata,
storage=storage,
redirect_handler=_noop_redirect,
callback_handler=_noop_callback,
)
await provider._initialize()
assert calls == [], (
f"Pre-flight must not fire when no tokens are stored, but got {calls}"
)

View file

@ -40,6 +40,7 @@ import re
import socket
import sys
import threading
import time
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
@ -196,6 +197,35 @@ class HermesTokenStorage:
data = _read_json(self._tokens_path())
if data is None:
return None
# Hermes records an absolute wall-clock ``expires_at`` alongside the
# SDK's serialized token (see ``set_tokens``). On read we rewrite
# ``expires_in`` to the remaining seconds so the SDK's downstream
# ``update_token_expiry`` computes the correct absolute time and
# ``is_token_valid()`` correctly reports False for tokens that
# expired while the process was down.
#
# Legacy token files (pre-Fix-A) have ``expires_in`` but no
# ``expires_at``. We fall back to the file's mtime as a best-effort
# wall-clock proxy for when the token was written: if (mtime +
# expires_in) is in the past, clamp ``expires_in`` to zero so the
# SDK refreshes before the first request. This self-heals one-time
# on the next successful ``set_tokens``, which writes the new
# ``expires_at`` field. The stored ``expires_at`` is stripped before
# model_validate because it's not part of the SDK's OAuthToken schema.
absolute_expiry = data.pop("expires_at", None)
if absolute_expiry is not None:
data["expires_in"] = int(max(absolute_expiry - time.time(), 0))
elif data.get("expires_in") is not None:
try:
file_mtime = self._tokens_path().stat().st_mtime
except OSError:
file_mtime = None
if file_mtime is not None:
try:
implied_expiry = file_mtime + int(data["expires_in"])
data["expires_in"] = int(max(implied_expiry - time.time(), 0))
except (TypeError, ValueError):
pass
try:
return OAuthToken.model_validate(data)
except (ValueError, TypeError, KeyError) as exc:
@ -203,7 +233,23 @@ class HermesTokenStorage:
return None
async def set_tokens(self, tokens: "OAuthToken") -> None:
_write_json(self._tokens_path(), tokens.model_dump(exclude_none=True))
payload = tokens.model_dump(exclude_none=True)
# Persist an absolute ``expires_at`` so a process restart can
# reconstruct the correct remaining TTL. Without this the MCP SDK's
# ``_initialize`` reloads a relative ``expires_in`` which has no
# wall-clock reference, leaving ``context.token_expiry_time=None``
# and ``is_token_valid()`` falsely reporting True. See Fix A in
# ``mcp-oauth-token-diagnosis`` skill + Claude Code's
# ``OAuthTokens.expiresAt`` persistence (auth.ts ~180).
expires_in = payload.get("expires_in")
if expires_in is not None:
try:
payload["expires_at"] = time.time() + int(expires_in)
except (TypeError, ValueError):
# Mock tokens or unusual shapes: skip the expires_at write
# rather than fail persistence.
pass
_write_json(self._tokens_path(), payload)
logger.debug("OAuth tokens saved for %s", self._server_name)
# -- client info -------------------------------------------------------

View file

@ -111,6 +111,131 @@ def _make_hermes_provider_class() -> Optional[type]:
super().__init__(*args, **kwargs)
self._hermes_server_name = server_name
async def _initialize(self) -> None:
"""Load stored tokens + client info AND seed token_expiry_time.
Also eagerly fetches OAuth authorization-server metadata (PRM +
ASM) when we have stored tokens but no cached metadata, so the
SDK's ``_refresh_token`` can build the correct token_endpoint
URL on the preemptive-refresh path. Without this, the SDK
falls back to ``{mcp_server_url}/token`` (wrong for providers
whose AS is a different origin BetterStack's MCP lives at
``https://mcp.betterstack.com`` but its token endpoint is at
``https://betterstack.com/oauth/token``), the refresh 404s, and
we drop through to full browser reauth.
The SDK's base ``_initialize`` populates ``current_tokens`` but
does NOT call ``update_token_expiry``, so ``token_expiry_time``
stays ``None`` and ``is_token_valid()`` returns True for any
loaded token regardless of actual age. After a process restart
this ships stale Bearer tokens to the server; some providers
return HTTP 401 (caught by the 401 handler), others return 200
with an app-level auth error (invisible to the transport layer,
e.g. BetterStack returning "No teams found. Please check your
authentication.").
Seeding ``token_expiry_time`` from the reloaded token fixes that:
``is_token_valid()`` correctly reports False for expired tokens,
``async_auth_flow`` takes the ``can_refresh_token()`` branch,
and the SDK quietly refreshes before the first real request.
Paired with :class:`HermesTokenStorage` persisting an absolute
``expires_at`` timestamp (``mcp_oauth.py:set_tokens``) so the
remaining TTL we compute here reflects real wall-clock age.
"""
await super()._initialize()
tokens = self.context.current_tokens
if tokens is not None and tokens.expires_in is not None:
self.context.update_token_expiry(tokens)
# Pre-flight OAuth AS discovery so ``_refresh_token`` has a
# correct ``token_endpoint`` before the first refresh attempt.
# Only runs when we have tokens on cold-load but no cached
# metadata — i.e. the exact scenario where the SDK's built-in
# 401-branch discovery hasn't had a chance to run yet.
if (
tokens is not None
and self.context.oauth_metadata is None
):
try:
await self._prefetch_oauth_metadata()
except Exception as exc: # pragma: no cover — defensive
# Non-fatal: if discovery fails, the SDK's normal 401-
# branch discovery will run on the next request.
logger.debug(
"MCP OAuth '%s': pre-flight metadata discovery "
"failed (non-fatal): %s",
self._hermes_server_name, exc,
)
async def _prefetch_oauth_metadata(self) -> None:
"""Fetch PRM + ASM from the well-known endpoints, cache on context.
Mirrors the SDK's 401-branch discovery (oauth2.py ~line 511-551)
but runs synchronously before the first request instead of
inside the httpx auth_flow generator. Uses the SDK's own URL
builders and response handlers so we track whatever the SDK
version we're pinned to expects.
"""
import httpx # local import: httpx is an MCP SDK dependency
from mcp.client.auth.utils import (
build_oauth_authorization_server_metadata_discovery_urls,
build_protected_resource_metadata_discovery_urls,
create_oauth_metadata_request,
handle_auth_metadata_response,
handle_protected_resource_response,
)
server_url = self.context.server_url
async with httpx.AsyncClient(timeout=10.0) as client:
# Step 1: PRM discovery to learn the authorization_server URL.
for url in build_protected_resource_metadata_discovery_urls(
None, server_url
):
req = create_oauth_metadata_request(url)
try:
resp = await client.send(req)
except httpx.HTTPError as exc:
logger.debug(
"MCP OAuth '%s': PRM discovery to %s failed: %s",
self._hermes_server_name, url, exc,
)
continue
prm = await handle_protected_resource_response(resp)
if prm:
self.context.protected_resource_metadata = prm
if prm.authorization_servers:
self.context.auth_server_url = str(
prm.authorization_servers[0]
)
break
# Step 2: ASM discovery against the auth_server_url (or
# server_url fallback for legacy providers).
for url in build_oauth_authorization_server_metadata_discovery_urls(
self.context.auth_server_url, server_url
):
req = create_oauth_metadata_request(url)
try:
resp = await client.send(req)
except httpx.HTTPError as exc:
logger.debug(
"MCP OAuth '%s': ASM discovery to %s failed: %s",
self._hermes_server_name, url, exc,
)
continue
ok, asm = await handle_auth_metadata_response(resp)
if not ok:
break
if asm:
self.context.oauth_metadata = asm
logger.debug(
"MCP OAuth '%s': pre-flight ASM discovered "
"token_endpoint=%s",
self._hermes_server_name, asm.token_endpoint,
)
break
async def async_auth_flow(self, request): # type: ignore[override]
# Pre-flow hook: ask the manager to refresh from disk if needed.
# Any failure here is non-fatal — we just log and proceed with