hermes-agent/tools/mcp_oauth_manager.py
Teknium a3a4932405
fix(mcp-oauth): bidirectional auth_flow bridge + absolute expires_at (salvage #12025) (#12717)
* [verified] fix(mcp-oauth): bridge httpx auth_flow bidirectional generator

HermesMCPOAuthProvider.async_auth_flow wrapped the SDK's auth_flow with
'async for item in super().async_auth_flow(request): yield item', which
discards the values httpx passes via .asend(response) and resumes the
inner generator with None. Every OAuth MCP server therefore failed on the
first HTTP response: the SDK crashed at mcp/client/auth/oauth2.py:505
with "'NoneType' object has no attribute 'status_code'".

Replace with a manual bridge that forwards .asend() values into the
inner generator, preserving httpx's bidirectional auth_flow contract.
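
The failure mode and the fix can be reproduced without httpx or the MCP
SDK. The sketch below is illustrative only: inner_flow is a hypothetical
stand-in for the SDK's auth_flow, and plain strings stand in for the
request and response objects.

```python
import asyncio


async def inner_flow(request, seen):
    """Toy inner generator: yields a request, expects a response back."""
    response = yield request  # the driver resumes us via .asend(response)
    seen.append(response)


async def naive_wrapper(request, seen):
    # The value sent into this wrapper's yield is dropped, so the inner
    # generator is resumed by __anext__(), i.e. with None.
    async for item in inner_flow(request, seen):
        yield item


async def bridged_wrapper(request, seen):
    # Forward every sent value into the inner generator explicitly.
    inner = inner_flow(request, seen)
    try:
        outgoing = await inner.__anext__()
        while True:
            incoming = yield outgoing
            outgoing = await inner.asend(incoming)
    except StopAsyncIteration:
        return


async def drive(wrapper):
    """Drive a wrapper the way an auth driver would: prime, then asend."""
    seen = []
    flow = wrapper("request", seen)
    await flow.__anext__()
    try:
        await flow.asend("response")
    except StopAsyncIteration:
        pass
    return seen


def demo():
    return asyncio.run(drive(naive_wrapper)), asyncio.run(drive(bridged_wrapper))
```

With the naive wrapper the inner generator observes None; with the
bridge it observes the value that was sent.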

Add tests/tools/test_mcp_oauth_bidirectional.py with two regression
tests that drive the flow through real .asend() round-trips. These
catch the bug at the unit level; prior tests only exercised
_initialize() and disk-watching, never the full generator protocol.

Verified against BetterStack MCP:
  Before: 'Connection failed (11564ms): NoneType...' after 3 retries
  After:  'Connected (2416ms); Tools discovered: 83'

Regression from #11383.

* [verified] fix(mcp-oauth): seed token_expiry_time + pre-flight AS discovery on cold-load

PR #11383's consolidation fixed external-refresh reloading and 401 dedup
but left two latent bugs that surfaced on BetterStack and any other OAuth
MCP with a split-origin authorization server:

1. HermesTokenStorage persisted only a relative 'expires_in', which is
   meaningless after a process restart. The MCP SDK's OAuthContext
   does NOT seed token_expiry_time in _initialize, so is_token_valid()
   returned True for any reloaded token regardless of age. Expired
   tokens shipped to servers, and app-level auth failures (e.g.
   BetterStack's 'No teams found. Please check your authentication.')
   were invisible to the transport-layer 401 handler.

2. Even once preemptive refresh did fire, the SDK's _refresh_token
   falls back to {server_url}/token when oauth_metadata isn't cached.
   For providers whose AS is at a different origin (BetterStack:
   mcp.betterstack.com for MCP, betterstack.com/oauth/token for the
   token endpoint), that fallback 404s and drops into full browser
   re-auth on every process restart.
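
A toy model of the two behaviours described above, for illustration
only. ToyContext is hypothetical and much simpler than the SDK's
OAuthContext; it only assumes that an unset expiry is treated as "still
valid" and that the refresh URL falls back to a path on the MCP server's
own origin when no metadata is cached.

```python
import time
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyContext:
    """Hypothetical, simplified stand-in for the SDK's OAuth context."""

    server_url: str
    token_expiry_time: Optional[float] = None  # absolute epoch seconds
    token_endpoint: Optional[str] = None  # from discovered AS metadata

    def is_token_valid(self) -> bool:
        # An unset expiry is treated as "never expires".
        if self.token_expiry_time is None:
            return True
        return time.time() <= self.token_expiry_time

    def refresh_url(self) -> str:
        # Without cached metadata, fall back to the MCP server's origin.
        if self.token_endpoint:
            return self.token_endpoint
        return f"{self.server_url.rstrip('/')}/token"


def cold_load_demo():
    cold = ToyContext("https://mcp.betterstack.com")
    seeded = ToyContext(
        "https://mcp.betterstack.com",
        token_expiry_time=time.time() - 60,  # expired a minute ago
        token_endpoint="https://betterstack.com/oauth/token",
    )
    return (
        (cold.is_token_valid(), cold.refresh_url()),
        (seeded.is_token_valid(), seeded.refresh_url()),
    )
```

The cold context reports the token as valid and would refresh against
the wrong origin; the seeded one reports it as expired and targets the
discovered token endpoint.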

Fix set:

- HermesTokenStorage.set_tokens persists an absolute wall-clock
  expires_at alongside the SDK's OAuthToken JSON (time.time() + TTL
  at write time).
- HermesTokenStorage.get_tokens reconstructs expires_in from
  max(expires_at - now, 0), clamping expired tokens to zero TTL.
  Legacy files without expires_at fall back to file-mtime as a
  best-effort wall-clock proxy, self-healing on the next set_tokens.
- HermesMCPOAuthProvider._initialize calls super(), then
  update_token_expiry on the reloaded tokens so token_expiry_time
  reflects actual remaining TTL. If tokens are loaded but
  oauth_metadata is missing, pre-flight PRM + ASM discovery runs
  via httpx.AsyncClient using the MCP SDK's own URL builders and
  response handlers (build_protected_resource_metadata_discovery_urls,
  handle_auth_metadata_response, etc.) so the SDK sees the correct
  token_endpoint before the first refresh attempt. Pre-flight is
  skipped when there are no stored tokens to keep fresh-install
  paths zero-cost.
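
A minimal sketch of the expires_at round-trip, under stated assumptions:
dump_tokens and load_tokens are hypothetical helpers operating on plain
dicts, not the HermesTokenStorage API, and the legacy branch assumes the
mtime proxy means "treat the file's mtime as the write time".

```python
import time
from typing import Optional


def dump_tokens(token: dict, now: Optional[float] = None) -> dict:
    """Return the payload to persist, adding an absolute expires_at."""
    now = time.time() if now is None else now
    payload = dict(token)
    if token.get("expires_in") is not None:
        payload["expires_at"] = now + token["expires_in"]
    return payload


def load_tokens(payload: dict, file_mtime: float, now: Optional[float] = None) -> dict:
    """Rebuild a token whose expires_in is the remaining TTL at load time."""
    now = time.time() if now is None else now
    token = dict(payload)
    expires_at = token.pop("expires_at", None)
    if token.get("expires_in") is None:
        return token  # no TTL was ever recorded; nothing to reconstruct
    if expires_at is None:
        # Legacy file (assumption): use mtime as the write-time proxy.
        expires_at = file_mtime + token["expires_in"]
    # Clamp expired tokens to a zero TTL.
    token["expires_in"] = max(int(expires_at - now), 0)
    return token
```

For example, a token written at t=1000 with expires_in=3600 reloads with
expires_in=3000 at t=1600 and with expires_in=0 at t=10000.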

Test coverage (tests/tools/test_mcp_oauth_cold_load_expiry.py):
- set_tokens persists absolute expires_at
- set_tokens skips expires_at when token has no expires_in
- get_tokens round-trips expires_at -> remaining expires_in
- expired tokens reload with expires_in=0
- legacy files without expires_at fall back to mtime proxy
- _initialize seeds token_expiry_time from stored tokens
- _initialize flags expired-on-disk tokens as is_token_valid=False
- _initialize pre-flights PRM + ASM discovery with mock transport
- _initialize skips pre-flight when no tokens are stored

Verified against BetterStack MCP:
  hermes mcp test betterstack -> Connected (2508ms), 83 tools
  mcp_betterstack_telemetry_list_teams_tool -> real team data, not
    'No teams found. Please check your authentication.'

Reference: mcp-oauth-token-diagnosis skill, Fix A.

* chore: map hermes@noushq.ai to benbarclay in AUTHOR_MAP

Needed for CI attribution check on cherry-picked commits from PR #12025.

---------

Co-authored-by: Hermes Agent <hermes@noushq.ai>
2026-04-19 16:31:07 -07:00


#!/usr/bin/env python3
"""Central manager for per-server MCP OAuth state.

One instance shared across the process. Holds per-server OAuth provider
instances and coordinates:

- **Cross-process token reload** via mtime-based disk watch. When an external
  process (e.g. a user cron job) refreshes tokens on disk, the next auth flow
  picks them up without requiring a process restart.
- **401 deduplication** via in-flight futures. When N concurrent tool calls
  all hit 401 with the same access_token, only one recovery attempt fires;
  the rest await the same result.
- **Reconnect signalling** for long-lived MCP sessions. The manager itself
  does not drive reconnection (the `MCPServerTask` in `mcp_tool.py` does),
  but the manager is the single source of truth that decides when reconnect
  is warranted.

Replaces what used to be scattered across eight call sites in `mcp_oauth.py`,
`mcp_tool.py`, and `hermes_cli/mcp_config.py`. This module is the ONLY place
that instantiates the MCP SDK's `OAuthClientProvider`; all other code paths
go through `get_manager()`.

Design reference:

- Claude Code's ``invalidateOAuthCacheIfDiskChanged``
  (``claude-code/src/utils/auth.ts:1320``, CC-1096 / GH#24317). Identical
  external-refresh staleness bug class.
- Codex's ``refresh_oauth_if_needed`` / ``persist_if_needed``
  (``codex-rs/rmcp-client/src/rmcp_client.rs:805``). We lean on the MCP SDK's
  lazy refresh rather than calling refresh before every op, because one
  ``stat()`` per tool call is cheaper than an ``await`` + potential refresh
  round-trip, and the SDK's in-memory expiry path is already correct.
"""

from __future__ import annotations

import asyncio
import logging
import threading
from dataclasses import dataclass, field
from typing import Any, Optional

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Per-server entry
# ---------------------------------------------------------------------------


@dataclass
class _ProviderEntry:
    """Per-server OAuth state tracked by the manager.

    Fields:
        server_url: The MCP server URL used to build the provider. Tracked
            so we can discard a cached provider if the URL changes.
        oauth_config: Optional dict from ``mcp_servers.<name>.oauth``.
        provider: The ``httpx.Auth``-compatible provider wrapping the MCP
            SDK. None until first use.
        last_mtime_ns: Last-seen ``st_mtime_ns`` of the on-disk tokens file.
            Zero if never read. Used by :meth:`MCPOAuthManager.invalidate_if_disk_changed`
            to detect external refreshes.
        lock: Serialises concurrent access to this entry's state. Bound to
            whichever asyncio loop first awaits it (the MCP event loop).
        pending_401: In-flight 401-handler futures keyed by the failed
            access_token, for deduplicating thundering-herd 401s. Mirrors
            Claude Code's ``pending401Handlers`` map.
    """

    server_url: str
    oauth_config: Optional[dict]
    provider: Optional[Any] = None
    last_mtime_ns: int = 0
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    pending_401: dict[str, "asyncio.Future[bool]"] = field(default_factory=dict)


# ---------------------------------------------------------------------------
# HermesMCPOAuthProvider: OAuthClientProvider subclass with disk-watch
# ---------------------------------------------------------------------------


def _make_hermes_provider_class() -> Optional[type]:
    """Lazy-import the SDK base class and return our subclass.

    Wrapped in a function so this module imports cleanly even when the
    MCP SDK's OAuth module is unavailable (e.g. older mcp versions).
    """
    try:
        from mcp.client.auth.oauth2 import OAuthClientProvider
    except ImportError:  # pragma: no cover - SDK required in CI
        return None

    class HermesMCPOAuthProvider(OAuthClientProvider):
        """OAuthClientProvider with pre-flow disk-mtime reload.

        Before every ``async_auth_flow`` invocation, asks the manager to
        check whether the tokens file on disk has been modified externally.
        If so, the manager resets ``_initialized`` so the next flow
        re-reads from storage.

        This makes external-process refreshes (cron, another CLI instance)
        visible to the running MCP session without requiring a restart.

        Reference: Claude Code's ``invalidateOAuthCacheIfDiskChanged``
        (``src/utils/auth.ts:1320``, CC-1096 / GH#24317).
        """

        def __init__(self, *args: Any, server_name: str = "", **kwargs: Any):
            super().__init__(*args, **kwargs)
            self._hermes_server_name = server_name

        async def _initialize(self) -> None:
            """Load stored tokens + client info AND seed token_expiry_time.

            Also eagerly fetches OAuth authorization-server metadata (PRM +
            ASM) when we have stored tokens but no cached metadata, so the
            SDK's ``_refresh_token`` can build the correct token_endpoint
            URL on the preemptive-refresh path. Without this, the SDK
            falls back to ``{mcp_server_url}/token`` (wrong for providers
            whose AS is a different origin: BetterStack's MCP lives at
            ``https://mcp.betterstack.com`` but its token endpoint is at
            ``https://betterstack.com/oauth/token``), the refresh 404s, and
            we drop through to full browser reauth.

            The SDK's base ``_initialize`` populates ``current_tokens`` but
            does NOT call ``update_token_expiry``, so ``token_expiry_time``
            stays ``None`` and ``is_token_valid()`` returns True for any
            loaded token regardless of actual age. After a process restart
            this ships stale Bearer tokens to the server; some providers
            return HTTP 401 (caught by the 401 handler), others return 200
            with an app-level auth error (invisible to the transport layer,
            e.g. BetterStack returning "No teams found. Please check your
            authentication.").

            Seeding ``token_expiry_time`` from the reloaded token fixes that:
            ``is_token_valid()`` correctly reports False for expired tokens,
            ``async_auth_flow`` takes the ``can_refresh_token()`` branch,
            and the SDK quietly refreshes before the first real request.

            Paired with :class:`HermesTokenStorage` persisting an absolute
            ``expires_at`` timestamp (``mcp_oauth.py:set_tokens``) so the
            remaining TTL we compute here reflects real wall-clock age.
            """
            await super()._initialize()
            tokens = self.context.current_tokens
            if tokens is not None and tokens.expires_in is not None:
                self.context.update_token_expiry(tokens)

            # Pre-flight OAuth AS discovery so ``_refresh_token`` has a
            # correct ``token_endpoint`` before the first refresh attempt.
            # Only runs when we have tokens on cold-load but no cached
            # metadata, i.e. the exact scenario where the SDK's built-in
            # 401-branch discovery hasn't had a chance to run yet.
            if (
                tokens is not None
                and self.context.oauth_metadata is None
            ):
                try:
                    await self._prefetch_oauth_metadata()
                except Exception as exc:  # pragma: no cover - defensive
                    # Non-fatal: if discovery fails, the SDK's normal 401-
                    # branch discovery will run on the next request.
                    logger.debug(
                        "MCP OAuth '%s': pre-flight metadata discovery "
                        "failed (non-fatal): %s",
                        self._hermes_server_name, exc,
                    )

        async def _prefetch_oauth_metadata(self) -> None:
            """Fetch PRM + ASM from the well-known endpoints, cache on context.

            Mirrors the SDK's 401-branch discovery (oauth2.py ~line 511-551)
            but runs synchronously before the first request instead of
            inside the httpx auth_flow generator. Uses the SDK's own URL
            builders and response handlers so we track whatever the SDK
            version we're pinned to expects.
            """
            import httpx  # local import: httpx is an MCP SDK dependency
            from mcp.client.auth.utils import (
                build_oauth_authorization_server_metadata_discovery_urls,
                build_protected_resource_metadata_discovery_urls,
                create_oauth_metadata_request,
                handle_auth_metadata_response,
                handle_protected_resource_response,
            )

            server_url = self.context.server_url
            async with httpx.AsyncClient(timeout=10.0) as client:
                # Step 1: PRM discovery to learn the authorization_server URL.
                for url in build_protected_resource_metadata_discovery_urls(
                    None, server_url
                ):
                    req = create_oauth_metadata_request(url)
                    try:
                        resp = await client.send(req)
                    except httpx.HTTPError as exc:
                        logger.debug(
                            "MCP OAuth '%s': PRM discovery to %s failed: %s",
                            self._hermes_server_name, url, exc,
                        )
                        continue
                    prm = await handle_protected_resource_response(resp)
                    if prm:
                        self.context.protected_resource_metadata = prm
                        if prm.authorization_servers:
                            self.context.auth_server_url = str(
                                prm.authorization_servers[0]
                            )
                        break

                # Step 2: ASM discovery against the auth_server_url (or
                # server_url fallback for legacy providers).
                for url in build_oauth_authorization_server_metadata_discovery_urls(
                    self.context.auth_server_url, server_url
                ):
                    req = create_oauth_metadata_request(url)
                    try:
                        resp = await client.send(req)
                    except httpx.HTTPError as exc:
                        logger.debug(
                            "MCP OAuth '%s': ASM discovery to %s failed: %s",
                            self._hermes_server_name, url, exc,
                        )
                        continue
                    ok, asm = await handle_auth_metadata_response(resp)
                    if not ok:
                        break
                    if asm:
                        self.context.oauth_metadata = asm
                        logger.debug(
                            "MCP OAuth '%s': pre-flight ASM discovered "
                            "token_endpoint=%s",
                            self._hermes_server_name, asm.token_endpoint,
                        )
                        break

        async def async_auth_flow(self, request):  # type: ignore[override]
            # Pre-flow hook: ask the manager to refresh from disk if needed.
            # Any failure here is non-fatal: we just log and proceed with
            # whatever state the SDK already has.
            try:
                await get_manager().invalidate_if_disk_changed(
                    self._hermes_server_name
                )
            except Exception as exc:  # pragma: no cover - defensive
                logger.debug(
                    "MCP OAuth '%s': pre-flow disk-watch failed (non-fatal): %s",
                    self._hermes_server_name, exc,
                )

            # Manually bridge the bidirectional generator protocol. httpx's
            # auth_flow driver (httpx._client._send_handling_auth) calls
            # ``auth_flow.asend(response)`` to feed HTTP responses back into
            # the generator. A naive wrapper using ``async for item in inner:
            # yield item`` DISCARDS those .asend(response) values and resumes
            # the inner generator with None, so the SDK's
            # ``response = yield request`` branch in
            # mcp/client/auth/oauth2.py sees response=None and crashes at
            # ``if response.status_code == 401`` with AttributeError.
            #
            # The bridge below forwards each .asend() value into the inner
            # generator via inner.asend(incoming), preserving the bidirectional
            # contract. Regression from PR #11383 caught by
            # tests/tools/test_mcp_oauth_bidirectional.py.
            inner = super().async_auth_flow(request)
            try:
                outgoing = await inner.__anext__()
                while True:
                    incoming = yield outgoing
                    outgoing = await inner.asend(incoming)
            except StopAsyncIteration:
                return

    return HermesMCPOAuthProvider


# Cached at import time. Tested and used by :class:`MCPOAuthManager`.
_HERMES_PROVIDER_CLS: Optional[type] = _make_hermes_provider_class()


# ---------------------------------------------------------------------------
# Manager
# ---------------------------------------------------------------------------


class MCPOAuthManager:
    """Single source of truth for per-server MCP OAuth state.

    Thread-safe: the ``_entries`` dict is guarded by ``_entries_lock`` for
    get-or-create semantics. Per-entry state is guarded by the entry's own
    ``asyncio.Lock`` (used from the MCP event loop thread).
    """

    def __init__(self) -> None:
        self._entries: dict[str, _ProviderEntry] = {}
        self._entries_lock = threading.Lock()

    # -- Provider construction / caching -------------------------------------

    def get_or_build_provider(
        self,
        server_name: str,
        server_url: str,
        oauth_config: Optional[dict],
    ) -> Optional[Any]:
        """Return a cached OAuth provider for ``server_name`` or build one.

        Idempotent: repeat calls with the same name return the same instance.
        If ``server_url`` changes for a given name, the cached entry is
        discarded and a fresh provider is built.

        Returns None if the MCP SDK's OAuth support is unavailable.
        """
        with self._entries_lock:
            entry = self._entries.get(server_name)
            if entry is not None and entry.server_url != server_url:
                logger.info(
                    "MCP OAuth '%s': URL changed from %s to %s, discarding cache",
                    server_name, entry.server_url, server_url,
                )
                entry = None

            if entry is None:
                entry = _ProviderEntry(
                    server_url=server_url,
                    oauth_config=oauth_config,
                )
                self._entries[server_name] = entry

            if entry.provider is None:
                entry.provider = self._build_provider(server_name, entry)

            return entry.provider

    def _build_provider(
        self,
        server_name: str,
        entry: _ProviderEntry,
    ) -> Optional[Any]:
        """Build the underlying OAuth provider.

        Constructs :class:`HermesMCPOAuthProvider` directly using the helpers
        extracted from ``tools.mcp_oauth``. The subclass injects a pre-flow
        disk-watch hook so external token refreshes (cron, other CLI
        instances) are visible to running MCP sessions.

        Returns None if the MCP SDK's OAuth support is unavailable.
        """
        if _HERMES_PROVIDER_CLS is None:
            logger.warning(
                "MCP OAuth '%s': SDK auth module unavailable", server_name,
            )
            return None

        # Local imports avoid circular deps at module import time.
        from tools.mcp_oauth import (
            HermesTokenStorage,
            _OAUTH_AVAILABLE,
            _build_client_metadata,
            _configure_callback_port,
            _is_interactive,
            _maybe_preregister_client,
            _parse_base_url,
            _redirect_handler,
            _wait_for_callback,
        )

        if not _OAUTH_AVAILABLE:
            return None

        cfg = dict(entry.oauth_config or {})
        storage = HermesTokenStorage(server_name)

        if not _is_interactive() and not storage.has_cached_tokens():
            logger.warning(
                "MCP OAuth for '%s': non-interactive environment and no "
                "cached tokens found. Run interactively first to complete "
                "initial authorization.",
                server_name,
            )

        _configure_callback_port(cfg)
        client_metadata = _build_client_metadata(cfg)
        _maybe_preregister_client(storage, cfg, client_metadata)

        return _HERMES_PROVIDER_CLS(
            server_name=server_name,
            server_url=_parse_base_url(entry.server_url),
            client_metadata=client_metadata,
            storage=storage,
            redirect_handler=_redirect_handler,
            callback_handler=_wait_for_callback,
            timeout=float(cfg.get("timeout", 300)),
        )

    def remove(self, server_name: str) -> None:
        """Evict the provider from cache AND delete tokens from disk.

        Called by ``hermes mcp remove <name>`` and (indirectly) by
        ``hermes mcp login <name>`` during forced re-auth.
        """
        with self._entries_lock:
            self._entries.pop(server_name, None)

        from tools.mcp_oauth import remove_oauth_tokens

        remove_oauth_tokens(server_name)
        logger.info(
            "MCP OAuth '%s': evicted from cache and removed from disk",
            server_name,
        )

    # -- Disk watch ----------------------------------------------------------

    async def invalidate_if_disk_changed(self, server_name: str) -> bool:
        """If the tokens file on disk has a newer mtime than last-seen, force
        the MCP SDK provider to reload its in-memory state.

        Returns True if the cache was invalidated (mtime differed). This is
        the core fix for the external-refresh workflow: a cron job writes
        fresh tokens to disk, and on the next tool call the running MCP
        session picks them up without a restart.
        """
        from tools.mcp_oauth import _get_token_dir, _safe_filename

        entry = self._entries.get(server_name)
        if entry is None or entry.provider is None:
            return False

        async with entry.lock:
            tokens_path = _get_token_dir() / f"{_safe_filename(server_name)}.json"
            try:
                mtime_ns = tokens_path.stat().st_mtime_ns
            except (FileNotFoundError, OSError):
                return False

            if mtime_ns != entry.last_mtime_ns:
                old = entry.last_mtime_ns
                entry.last_mtime_ns = mtime_ns
                # Force the SDK's OAuthClientProvider to reload from storage
                # on its next auth flow. `_initialized` is private API but
                # stable across the MCP SDK versions we pin (>=1.26.0).
                if hasattr(entry.provider, "_initialized"):
                    entry.provider._initialized = False  # noqa: SLF001
                logger.info(
                    "MCP OAuth '%s': tokens file changed (mtime %d -> %d), "
                    "forcing reload",
                    server_name, old, mtime_ns,
                )
                return True
            return False

    # -- 401 handler (dedup'd) -----------------------------------------------

    async def handle_401(
        self,
        server_name: str,
        failed_access_token: Optional[str] = None,
    ) -> bool:
        """Handle a 401 from a tool call, deduplicated across concurrent callers.

        Returns:
            True if a (possibly new) access token is now available; caller
            should trigger a reconnect and retry the operation.
            False if no recovery path exists; caller should surface a
            ``needs_reauth`` error to the model so it stops hallucinating
            manual refresh attempts.

        Thundering-herd protection: if N concurrent tool calls hit 401 with
        the same ``failed_access_token``, only one recovery attempt fires.
        Others await the same future.
        """
        entry = self._entries.get(server_name)
        if entry is None or entry.provider is None:
            return False

        key = failed_access_token or "<unknown>"
        loop = asyncio.get_running_loop()

        async with entry.lock:
            pending = entry.pending_401.get(key)
            if pending is None:
                pending = loop.create_future()
                entry.pending_401[key] = pending

                async def _do_handle() -> None:
                    try:
                        # Step 1: Did disk change? Picks up external refresh.
                        disk_changed = await self.invalidate_if_disk_changed(
                            server_name
                        )
                        if disk_changed:
                            if not pending.done():
                                pending.set_result(True)
                            return

                        # Step 2: No disk change. If the SDK can refresh
                        # in-place, let the caller retry. The SDK's httpx.Auth
                        # flow will issue the refresh on the next request.
                        provider = entry.provider
                        ctx = getattr(provider, "context", None)
                        can_refresh = False
                        if ctx is not None:
                            can_refresh_fn = getattr(ctx, "can_refresh_token", None)
                            if callable(can_refresh_fn):
                                try:
                                    can_refresh = bool(can_refresh_fn())
                                except Exception:
                                    can_refresh = False
                        if not pending.done():
                            pending.set_result(can_refresh)
                    except Exception as exc:  # pragma: no cover - defensive
                        logger.warning(
                            "MCP OAuth '%s': 401 handler failed: %s",
                            server_name, exc,
                        )
                        if not pending.done():
                            pending.set_result(False)
                    finally:
                        entry.pending_401.pop(key, None)

                asyncio.create_task(_do_handle())

        try:
            return await pending
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning(
                "MCP OAuth '%s': awaiting 401 handler failed: %s",
                server_name, exc,
            )
            return False


# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------


_MANAGER: Optional[MCPOAuthManager] = None
_MANAGER_LOCK = threading.Lock()


def get_manager() -> MCPOAuthManager:
    """Return the process-wide :class:`MCPOAuthManager` singleton."""
    global _MANAGER
    with _MANAGER_LOCK:
        if _MANAGER is None:
            _MANAGER = MCPOAuthManager()
        return _MANAGER


def reset_manager_for_tests() -> None:
    """Test-only helper: drop the singleton so fixtures start clean."""
    global _MANAGER
    with _MANAGER_LOCK:
        _MANAGER = None