hermes-agent/tests/tools/test_mcp_oauth_cold_load_expiry.py
Teknium a3a4932405
fix(mcp-oauth): bidirectional auth_flow bridge + absolute expires_at (salvage #12025) (#12717)
* [verified] fix(mcp-oauth): bridge httpx auth_flow bidirectional generator

HermesMCPOAuthProvider.async_auth_flow wrapped the SDK's auth_flow with
'async for item in super().async_auth_flow(request): yield item', which
discards the values httpx passes in via .asend(response) and resumes the
inner generator with None. Every OAuth MCP server therefore failed on the
first HTTP response with "'NoneType' object has no attribute
'status_code'", raised at mcp/client/auth/oauth2.py:505.

Replace with a manual bridge that forwards .asend() values into the
inner generator, preserving httpx's bidirectional auth_flow contract.
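
A minimal sketch of the difference, assuming a toy inner generator in place of the SDK's auth flow. The names inner_flow, naive_wrapper, bridged_wrapper, and drive are hypothetical and exist only for illustration; strings stand in for httpx requests and responses.

```python
import asyncio
from typing import AsyncGenerator


async def inner_flow(request: str) -> AsyncGenerator[str, str]:
    """Hypothetical stand-in for the SDK flow: yield a request, receive its response."""
    response = yield request
    if response == "401":
        # Retry with credentials; only reachable if the response was forwarded.
        yield request + " +token"


async def naive_wrapper(request: str) -> AsyncGenerator[str, str]:
    # 'async for' resumes inner_flow with None, so values sent by the caller are lost.
    async for item in inner_flow(request):
        yield item


async def bridged_wrapper(request: str) -> AsyncGenerator[str, str]:
    inner = inner_flow(request)
    try:
        outgoing = await inner.__anext__()
        while True:
            incoming = yield outgoing  # value the caller passed to .asend()
            outgoing = await inner.asend(incoming)  # forward it to the inner generator
    except StopAsyncIteration:
        return


async def drive(wrapper, responses: list[str]) -> list[str]:
    """Play the caller's role: collect each yielded request, send back a response."""
    flow = wrapper("GET /mcp")
    sent = [await flow.__anext__()]
    try:
        for response in responses:
            sent.append(await flow.asend(response))
    except StopAsyncIteration:
        pass
    return sent


naive_result = asyncio.run(drive(naive_wrapper, ["401", "200"]))
bridged_result = asyncio.run(drive(bridged_wrapper, ["401", "200"]))
```

In this toy model the naive wrapper never issues the retry because the inner generator sees None instead of "401"; the bridged wrapper does. In the real flow the lost value surfaced as the NoneType crash described above rather than as a silently skipped retry.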

Add tests/tools/test_mcp_oauth_bidirectional.py with two regression
tests that drive the flow through real .asend() round-trips. These
catch the bug at the unit level; prior tests only exercised
_initialize() and disk-watching, never the full generator protocol.

Verified against BetterStack MCP:
  Before: 'Connection failed (11564ms): NoneType...' after 3 retries
  After:  'Connected (2416ms); Tools discovered: 83'

Regression from #11383.

* [verified] fix(mcp-oauth): seed token_expiry_time + pre-flight AS discovery on cold-load

PR #11383's consolidation fixed external-refresh reloading and 401 dedup
but left two latent bugs that surfaced on BetterStack and any other OAuth
MCP with a split-origin authorization server:

1. HermesTokenStorage persisted only a relative 'expires_in', which is
   meaningless after a process restart. The MCP SDK's OAuthContext
   does NOT seed token_expiry_time in _initialize, so is_token_valid()
   returned True for any reloaded token regardless of age. Expired
   tokens shipped to servers, and app-level auth failures (e.g.
   BetterStack's 'No teams found. Please check your authentication.')
   were invisible to the transport-layer 401 handler.

2. Even once preemptive refresh did fire, the SDK's _refresh_token
   falls back to {server_url}/token when oauth_metadata isn't cached.
   For providers whose AS is at a different origin (BetterStack:
   mcp.betterstack.com for MCP, betterstack.com/oauth/token for the
   token endpoint), that fallback 404s and drops into full browser
   re-auth on every process restart.
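
The second bug can be modelled roughly as follows. This is a sketch of the fallback behaviour as described above, not the SDK's code: resolve_token_endpoint is a hypothetical helper, and the example.com hosts mirror the ones used in the tests rather than BetterStack's real endpoints.

```python
from __future__ import annotations

from urllib.parse import urljoin, urlsplit


def resolve_token_endpoint(
    server_url: str, cached_token_endpoint: str | None = None
) -> str:
    """Prefer a discovered token endpoint; otherwise guess {origin}/token."""
    if cached_token_endpoint:
        return cached_token_endpoint
    parts = urlsplit(server_url)
    origin = f"{parts.scheme}://{parts.netloc}"
    # Assumption: the fallback is built from the MCP server's own origin.
    return urljoin(origin, "/token")


# Without cached metadata the guess lands on the MCP host, which is the wrong
# origin when the authorization server lives elsewhere.
guessed = resolve_token_endpoint("https://mcp.example.com/mcp")
# With metadata cached by pre-flight discovery the real endpoint is used.
discovered = resolve_token_endpoint(
    "https://mcp.example.com/mcp", "https://auth.example.com/oauth/token"
)
```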

Fix set:

- HermesTokenStorage.set_tokens persists an absolute wall-clock
  expires_at alongside the SDK's OAuthToken JSON (time.time() + TTL
  at write time).
- HermesTokenStorage.get_tokens reconstructs expires_in from
  max(expires_at - now, 0), clamping expired tokens to zero TTL.
  Legacy files without expires_at fall back to file-mtime as a
  best-effort wall-clock proxy, self-healing on the next set_tokens.
- HermesMCPOAuthProvider._initialize calls super(), then
  update_token_expiry on the reloaded tokens so token_expiry_time
  reflects actual remaining TTL. If tokens are loaded but
  oauth_metadata is missing, pre-flight PRM + ASM discovery runs
  via httpx.AsyncClient using the MCP SDK's own URL builders and
  response handlers (build_protected_resource_metadata_discovery_urls,
  handle_auth_metadata_response, etc.) so the SDK sees the correct
  token_endpoint before the first refresh attempt. Pre-flight is
  skipped when there are no stored tokens to keep fresh-install
  paths zero-cost.
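
The storage half of the fix set can be sketched as below. This is an illustration under stated assumptions, not the HermesTokenStorage implementation: save_token and load_token are hypothetical helpers working on plain dicts, and the injectable now parameter exists only to make the arithmetic easy to follow.

```python
from __future__ import annotations

import json
import tempfile
import time
from pathlib import Path


def save_token(path: Path, token: dict, now: float | None = None) -> None:
    """Write the token, adding an absolute expires_at when a TTL is known."""
    now = time.time() if now is None else now
    payload = dict(token)
    if payload.get("expires_in") is not None:
        payload["expires_at"] = now + payload["expires_in"]
    path.write_text(json.dumps(payload))


def load_token(path: Path, now: float | None = None) -> dict:
    """Read the token and rebuild expires_in as the remaining TTL."""
    now = time.time() if now is None else now
    payload = json.loads(path.read_text())
    ttl = payload.get("expires_in")
    if ttl is None:
        return payload  # no TTL recorded, nothing to reconstruct
    expires_at = payload.pop("expires_at", None)
    if expires_at is None:
        # Legacy file: treat the file's mtime as the approximate write time.
        expires_at = path.stat().st_mtime + ttl
    payload["expires_in"] = max(int(expires_at - now), 0)
    return payload


with tempfile.TemporaryDirectory() as tmp:
    token_path = Path(tmp) / "srv.json"
    save_token(token_path, {"access_token": "a", "expires_in": 3600}, now=1_000.0)
    fresh = load_token(token_path, now=1_600.0)    # 600 s elapsed since the write
    expired = load_token(token_path, now=5_000.0)  # past expires_at, clamped to 0
```

Clamping to zero rather than dropping the token keeps the refresh_token available, so an expired access token can still be refreshed on the next request.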

Test coverage (tests/tools/test_mcp_oauth_cold_load_expiry.py):
- set_tokens persists absolute expires_at
- set_tokens skips expires_at when token has no expires_in
- get_tokens round-trips expires_at -> remaining expires_in
- expired tokens reload with expires_in=0
- legacy files without expires_at fall back to mtime proxy
- _initialize seeds token_expiry_time from stored tokens
- _initialize flags expired-on-disk tokens as is_token_valid=False
- _initialize pre-flights PRM + ASM discovery with mock transport
- _initialize skips pre-flight when no tokens are stored

Verified against BetterStack MCP:
  hermes mcp test betterstack -> Connected (2508ms), 83 tools
  mcp_betterstack_telemetry_list_teams_tool -> real team data, not
    'No teams found. Please check your authentication.'

Reference: mcp-oauth-token-diagnosis skill, Fix A.

* chore: map hermes@noushq.ai to benbarclay in AUTHOR_MAP

Needed for CI attribution check on cherry-picked commits from PR #12025.

---------

Co-authored-by: Hermes Agent <hermes@noushq.ai>
2026-04-19 16:31:07 -07:00


"""Tests for cold-load token expiry tracking in MCP OAuth.

PR #11383's consolidation fixed external-refresh reloading (mtime disk-watch)
and 401 dedup, but left two underlying latent bugs in place:

1. ``HermesTokenStorage.set_tokens`` persisted only relative ``expires_in``,
   which is meaningless after a process restart.
2. The MCP SDK's ``OAuthContext._initialize`` loads ``current_tokens`` from
   storage but does NOT call ``update_token_expiry``, so
   ``token_expiry_time`` stays None. ``is_token_valid()`` then returns True
   for any loaded token regardless of actual age, and the SDK's preemptive
   refresh branch at ``oauth2.py:491`` is never taken.

Consequence: a token that expired while the process was down ships to the
server with a stale Bearer header. The server's response is provider-specific:
some return HTTP 401 (caught by the consolidation's 401 handler, which
surfaces a ``needs_reauth`` error), others return HTTP 200 with an
application-level auth failure in the body (e.g. BetterStack's "No teams
found. Please check your authentication."), which the consolidation cannot
detect.

These tests pin the contract for Fix A:

- ``set_tokens`` persists an absolute ``expires_at`` wall-clock timestamp.
- ``get_tokens`` reconstructs ``expires_in`` from ``expires_at - now`` so
  the SDK's ``update_token_expiry`` computes the correct absolute expiry.
- ``HermesMCPOAuthProvider._initialize`` seeds ``context.token_expiry_time``
  after loading, so ``is_token_valid()`` reports True only for tokens that
  are actually still valid, and the SDK's preemptive refresh fires for
  expired tokens with a live refresh_token.

Reference: Claude Code solves this via an ``OAuthTokens.expiresAt`` absolute
timestamp persisted alongside the access_token (``auth.ts:~180``).
"""
from __future__ import annotations

import asyncio
import json
import time

import pytest

pytest.importorskip("mcp.client.auth.oauth2", reason="MCP SDK 1.26.0+ required")


# ---------------------------------------------------------------------------
# HermesTokenStorage: absolute expiry persistence
# ---------------------------------------------------------------------------


class TestSetTokensAbsoluteExpiry:
    def test_set_tokens_persists_absolute_expires_at(self, tmp_path, monkeypatch):
        """Tokens round-tripped through disk must encode absolute expiry."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        from mcp.shared.auth import OAuthToken
        from tools.mcp_oauth import HermesTokenStorage

        storage = HermesTokenStorage("srv")
        before = time.time()
        asyncio.run(
            storage.set_tokens(
                OAuthToken(
                    access_token="a",
                    token_type="Bearer",
                    expires_in=3600,
                    refresh_token="r",
                )
            )
        )
        after = time.time()

        on_disk = json.loads(
            (tmp_path / "mcp-tokens" / "srv.json").read_text()
        )
        assert "expires_at" in on_disk, (
            "Fix A: set_tokens must record an absolute expires_at wall-clock "
            "timestamp alongside the SDK's serialized token so cold-loads "
            "can compute correct remaining TTL."
        )
        assert before + 3600 <= on_disk["expires_at"] <= after + 3600

    def test_set_tokens_without_expires_in_omits_expires_at(
        self, tmp_path, monkeypatch
    ):
        """Tokens without a TTL must not gain a fabricated expires_at."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        from mcp.shared.auth import OAuthToken
        from tools.mcp_oauth import HermesTokenStorage

        storage = HermesTokenStorage("srv")
        asyncio.run(
            storage.set_tokens(
                OAuthToken(
                    access_token="a",
                    token_type="Bearer",
                    refresh_token="r",
                )
            )
        )

        on_disk = json.loads(
            (tmp_path / "mcp-tokens" / "srv.json").read_text()
        )
        assert "expires_at" not in on_disk


class TestGetTokensReconstructsExpiresIn:
    def test_get_tokens_uses_expires_at_for_remaining_ttl(
        self, tmp_path, monkeypatch
    ):
        """Round-trip: expires_in on read must reflect time remaining."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        from mcp.shared.auth import OAuthToken
        from tools.mcp_oauth import HermesTokenStorage

        storage = HermesTokenStorage("srv")
        asyncio.run(
            storage.set_tokens(
                OAuthToken(
                    access_token="a",
                    token_type="Bearer",
                    expires_in=3600,
                    refresh_token="r",
                )
            )
        )

        # Wait briefly so the remaining TTL is measurably less than 3600.
        time.sleep(0.05)

        reloaded = asyncio.run(storage.get_tokens())
        assert reloaded is not None
        assert reloaded.expires_in is not None
        # Should be slightly less than 3600 after the 50ms sleep.
        assert 3500 < reloaded.expires_in <= 3600

    def test_get_tokens_returns_zero_ttl_for_expired_token(
        self, tmp_path, monkeypatch
    ):
        """An already-expired token reloaded from disk must report expires_in=0."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        from tools.mcp_oauth import HermesTokenStorage, _get_token_dir

        token_dir = _get_token_dir()
        token_dir.mkdir(parents=True, exist_ok=True)
        # Write an already-expired token file directly.
        (token_dir / "srv.json").write_text(
            json.dumps(
                {
                    "access_token": "a",
                    "token_type": "Bearer",
                    "expires_in": 3600,
                    "expires_at": time.time() - 60,  # expired 1 min ago
                    "refresh_token": "r",
                }
            )
        )

        storage = HermesTokenStorage("srv")
        reloaded = asyncio.run(storage.get_tokens())
        assert reloaded is not None
        assert reloaded.expires_in == 0, (
            "Expired token must reload with expires_in=0 so the SDK's "
            "is_token_valid() returns False and preemptive refresh fires."
        )

    def test_get_tokens_legacy_file_without_expires_at_is_loadable(
        self, tmp_path, monkeypatch
    ):
        """Existing on-disk files (pre-Fix-A) must still load without crashing.

        Pre-existing token files have ``expires_in`` but no ``expires_at``.
        Fix A falls back to the file's mtime as a best-effort wall-clock
        proxy: a file whose (mtime + expires_in) is in the past clamps
        expires_in to zero so the SDK refreshes on next request. A fresh
        legacy-format file (mtime = now) keeps most of its TTL.
        """
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        from tools.mcp_oauth import HermesTokenStorage, _get_token_dir

        token_dir = _get_token_dir()
        token_dir.mkdir(parents=True, exist_ok=True)
        # Legacy-shape file (no expires_at). Make it stale by backdating mtime
        # well past its nominal expires_in.
        legacy_path = token_dir / "srv.json"
        legacy_path.write_text(
            json.dumps(
                {
                    "access_token": "a",
                    "token_type": "Bearer",
                    "expires_in": 3600,
                    "refresh_token": "r",
                }
            )
        )
        stale_time = time.time() - 7200  # 2hr ago, exceeds 3600s TTL
        import os

        os.utime(legacy_path, (stale_time, stale_time))

        storage = HermesTokenStorage("srv")
        reloaded = asyncio.run(storage.get_tokens())
        assert reloaded is not None
        assert reloaded.expires_in == 0, (
            "Legacy file whose mtime + expires_in is in the past must report "
            "expires_in=0 so the SDK refreshes on next request."
        )


# ---------------------------------------------------------------------------
# HermesMCPOAuthProvider._initialize: seed token_expiry_time
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_initialize_seeds_token_expiry_time_from_stored_tokens(
    tmp_path, monkeypatch
):
    """Cold-load must populate context.token_expiry_time.

    The SDK's base ``_initialize`` loads current_tokens but doesn't seed
    token_expiry_time. Our subclass must do it so ``is_token_valid()``
    reports correctly and the preemptive-refresh path fires when needed.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
    from pydantic import AnyUrl
    from tools.mcp_oauth import HermesTokenStorage
    from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests

    assert _HERMES_PROVIDER_CLS is not None
    reset_manager_for_tests()

    storage = HermesTokenStorage("srv")
    await storage.set_tokens(
        OAuthToken(
            access_token="a",
            token_type="Bearer",
            expires_in=7200,
            refresh_token="r",
        )
    )
    await storage.set_client_info(
        OAuthClientInformationFull(
            client_id="test-client",
            redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            token_endpoint_auth_method="none",
        )
    )

    from mcp.shared.auth import OAuthClientMetadata

    metadata = OAuthClientMetadata(
        redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
        client_name="Hermes Agent",
    )
    provider = _HERMES_PROVIDER_CLS(
        server_name="srv",
        server_url="https://example.com/mcp",
        client_metadata=metadata,
        storage=storage,
        redirect_handler=_noop_redirect,
        callback_handler=_noop_callback,
    )
    await provider._initialize()

    assert provider.context.token_expiry_time is not None, (
        "Fix A: _initialize must seed context.token_expiry_time so "
        "is_token_valid() correctly reports expiry on cold-load."
    )
    # Should be ~7200s in the future (fresh write).
    assert provider.context.token_expiry_time > time.time() + 7000
    assert provider.context.token_expiry_time <= time.time() + 7200 + 5


@pytest.mark.asyncio
async def test_initialize_flags_expired_token_as_invalid(tmp_path, monkeypatch):
    """After _initialize, an expired-on-disk token must report is_token_valid=False.

    This is the end-to-end assertion: cold-load an expired token, verify the
    SDK's own ``is_token_valid()`` now returns False (the consequence of
    seeding token_expiry_time correctly), so the SDK's ``async_auth_flow``
    will take the ``can_refresh_token()`` branch on the next request and
    silently refresh instead of sending the stale Bearer.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
    from pydantic import AnyUrl
    from tools.mcp_oauth import HermesTokenStorage, _get_token_dir
    from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests

    assert _HERMES_PROVIDER_CLS is not None
    reset_manager_for_tests()

    # Write an already-expired token directly so we control the wall-clock.
    token_dir = _get_token_dir()
    token_dir.mkdir(parents=True, exist_ok=True)
    (token_dir / "srv.json").write_text(
        json.dumps(
            {
                "access_token": "stale",
                "token_type": "Bearer",
                "expires_in": 3600,
                "expires_at": time.time() - 60,
                "refresh_token": "fresh",
            }
        )
    )

    storage = HermesTokenStorage("srv")
    await storage.set_client_info(
        OAuthClientInformationFull(
            client_id="test-client",
            redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            token_endpoint_auth_method="none",
        )
    )

    metadata = OAuthClientMetadata(
        redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
        client_name="Hermes Agent",
    )
    provider = _HERMES_PROVIDER_CLS(
        server_name="srv",
        server_url="https://example.com/mcp",
        client_metadata=metadata,
        storage=storage,
        redirect_handler=_noop_redirect,
        callback_handler=_noop_callback,
    )
    await provider._initialize()

    assert provider.context.is_token_valid() is False, (
        "After _initialize with an expired-on-disk token, is_token_valid() "
        "must return False so the SDK's async_auth_flow takes the "
        "preemptive refresh path."
    )
    assert provider.context.can_refresh_token() is True, (
        "Refresh should remain possible because refresh_token + client_info "
        "are both present."
    )


async def _noop_redirect(_url: str) -> None:
    return None


async def _noop_callback() -> tuple[str, str | None]:
    raise AssertionError("callback handler should not be invoked in these tests")


# ---------------------------------------------------------------------------
# Pre-flight OAuth metadata discovery
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_initialize_prefetches_oauth_metadata_when_missing(
    tmp_path, monkeypatch
):
    """Cold-load must pre-flight PRM + ASM discovery so ``_refresh_token``
    has the correct ``token_endpoint`` before the first refresh attempt.

    Without this, the SDK's ``_refresh_token`` falls back to
    ``{server_url}/token`` which is wrong for providers whose AS is at
    a different origin. BetterStack specifically: MCP at
    ``mcp.betterstack.com`` but token_endpoint at
    ``betterstack.com/oauth/token``. Without pre-flight the refresh 404s
    and we drop into full browser re-auth, visible to the user as an
    unwanted OAuth browser prompt every time the process restarts.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    import httpx
    from mcp.shared.auth import (
        OAuthClientInformationFull,
        OAuthClientMetadata,
        OAuthToken,
    )
    from pydantic import AnyUrl
    from tools.mcp_oauth import HermesTokenStorage
    from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests

    assert _HERMES_PROVIDER_CLS is not None
    reset_manager_for_tests()

    storage = HermesTokenStorage("srv")
    await storage.set_tokens(
        OAuthToken(
            access_token="a",
            token_type="Bearer",
            expires_in=3600,
            refresh_token="r",
        )
    )
    await storage.set_client_info(
        OAuthClientInformationFull(
            client_id="test-client",
            redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            token_endpoint_auth_method="none",
        )
    )

    # Route the AsyncClient used inside _prefetch_oauth_metadata through a
    # MockTransport that mimics BetterStack's split-origin discovery:
    #   PRM at mcp.example.com/.well-known/oauth-protected-resource -> points to auth.example.com
    #   ASM at auth.example.com/.well-known/oauth-authorization-server -> token_endpoint at auth.example.com/oauth/token
    def mock_handler(request: httpx.Request) -> httpx.Response:
        url = str(request.url)
        if url.endswith("/.well-known/oauth-protected-resource"):
            return httpx.Response(
                200,
                json={
                    "resource": "https://mcp.example.com",
                    "authorization_servers": ["https://auth.example.com"],
                    "scopes_supported": ["read", "write"],
                    "bearer_methods_supported": ["header"],
                },
            )
        if url.endswith("/.well-known/oauth-authorization-server"):
            return httpx.Response(
                200,
                json={
                    "issuer": "https://auth.example.com",
                    "authorization_endpoint": "https://auth.example.com/oauth/authorize",
                    "token_endpoint": "https://auth.example.com/oauth/token",
                    "registration_endpoint": "https://auth.example.com/oauth/register",
                    "response_types_supported": ["code"],
                    "grant_types_supported": ["authorization_code", "refresh_token"],
                    "code_challenge_methods_supported": ["S256"],
                    "token_endpoint_auth_methods_supported": ["none"],
                    "scopes_supported": ["read", "write"],
                },
            )
        return httpx.Response(404)

    transport = httpx.MockTransport(mock_handler)

    # Patch the AsyncClient constructor used by _prefetch_oauth_metadata so
    # it uses our mock transport instead of the real network.
    import httpx as real_httpx

    original_async_client = real_httpx.AsyncClient

    def patched_async_client(*args, **kwargs):
        kwargs["transport"] = transport
        return original_async_client(*args, **kwargs)

    monkeypatch.setattr(real_httpx, "AsyncClient", patched_async_client)

    metadata = OAuthClientMetadata(
        redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
        client_name="Hermes Agent",
    )
    provider = _HERMES_PROVIDER_CLS(
        server_name="srv",
        server_url="https://mcp.example.com",
        client_metadata=metadata,
        storage=storage,
        redirect_handler=_noop_redirect,
        callback_handler=_noop_callback,
    )
    await provider._initialize()

    assert provider.context.protected_resource_metadata is not None, (
        "Pre-flight must cache PRM for the SDK to reference later."
    )
    assert provider.context.oauth_metadata is not None, (
        "Pre-flight must cache ASM so _refresh_token builds the correct "
        "token_endpoint URL."
    )
    assert str(provider.context.oauth_metadata.token_endpoint) == (
        "https://auth.example.com/oauth/token"
    )


@pytest.mark.asyncio
async def test_initialize_skips_prefetch_when_no_tokens(tmp_path, monkeypatch):
    """Pre-flight must not run when there are no stored tokens yet.

    Without this guard, every fresh-install ``_initialize`` would do two
    extra network roundtrips that gain nothing (the SDK's 401-branch
    discovery will run on the first real request anyway).
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    import httpx
    from mcp.shared.auth import OAuthClientMetadata
    from pydantic import AnyUrl
    from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS, reset_manager_for_tests
    from tools.mcp_oauth import HermesTokenStorage

    assert _HERMES_PROVIDER_CLS is not None
    reset_manager_for_tests()

    calls: list[str] = []

    def mock_handler(request: httpx.Request) -> httpx.Response:
        calls.append(str(request.url))
        return httpx.Response(404)

    transport = httpx.MockTransport(mock_handler)

    import httpx as real_httpx

    original = real_httpx.AsyncClient

    def patched(*args, **kwargs):
        kwargs["transport"] = transport
        return original(*args, **kwargs)

    monkeypatch.setattr(real_httpx, "AsyncClient", patched)

    storage = HermesTokenStorage("srv")  # empty: no tokens on disk
    metadata = OAuthClientMetadata(
        redirect_uris=[AnyUrl("http://127.0.0.1:12345/callback")],
        client_name="Hermes Agent",
    )
    provider = _HERMES_PROVIDER_CLS(
        server_name="srv",
        server_url="https://mcp.example.com",
        client_metadata=metadata,
        storage=storage,
        redirect_handler=_noop_redirect,
        callback_handler=_noop_callback,
    )
    await provider._initialize()

    assert calls == [], (
        f"Pre-flight must not fire when no tokens are stored, but got {calls}"
    )