mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
* [verified] fix(mcp-oauth): bridge httpx auth_flow bidirectional generator

  HermesMCPOAuthProvider.async_auth_flow wrapped the SDK's auth_flow with 'async for item in super().async_auth_flow(request): yield item', which discards httpx's .asend(response) values and resumes the inner generator with None. This broke every OAuth MCP server on the first HTTP response with 'NoneType' object has no attribute 'status_code' crashing at mcp/client/auth/oauth2.py:505.

  Replace with a manual bridge that forwards .asend() values into the inner generator, preserving httpx's bidirectional auth_flow contract.

  Add tests/tools/test_mcp_oauth_bidirectional.py with two regression tests that drive the flow through real .asend() round-trips. These catch the bug at the unit level; prior tests only exercised _initialize() and disk-watching, never the full generator protocol.

  Verified against BetterStack MCP:
    Before: 'Connection failed (11564ms): NoneType...' after 3 retries
    After: 'Connected (2416ms); Tools discovered: 83'

  Regression from #11383.

* [verified] fix(mcp-oauth): seed token_expiry_time + pre-flight AS discovery on cold-load

  PR #11383's consolidation fixed external-refresh reloading and 401 dedup but left two latent bugs that surfaced on BetterStack and any other OAuth MCP with a split-origin authorization server:

  1. HermesTokenStorage persisted only a relative 'expires_in', which is meaningless after a process restart. The MCP SDK's OAuthContext does NOT seed token_expiry_time in _initialize, so is_token_valid() returned True for any reloaded token regardless of age. Expired tokens shipped to servers, and app-level auth failures (e.g. BetterStack's 'No teams found. Please check your authentication.') were invisible to the transport-layer 401 handler.

  2. Even once preemptive refresh did fire, the SDK's _refresh_token falls back to {server_url}/token when oauth_metadata isn't cached. For providers whose AS is at a different origin (BetterStack: mcp.betterstack.com for MCP, betterstack.com/oauth/token for the token endpoint), that fallback 404s and drops into full browser re-auth on every process restart.

  Fix set:
  - HermesTokenStorage.set_tokens persists an absolute wall-clock expires_at alongside the SDK's OAuthToken JSON (time.time() + TTL at write time).
  - HermesTokenStorage.get_tokens reconstructs expires_in from max(expires_at - now, 0), clamping expired tokens to zero TTL. Legacy files without expires_at fall back to file-mtime as a best-effort wall-clock proxy, self-healing on the next set_tokens.
  - HermesMCPOAuthProvider._initialize calls super(), then update_token_expiry on the reloaded tokens so token_expiry_time reflects actual remaining TTL. If tokens are loaded but oauth_metadata is missing, pre-flight PRM + ASM discovery runs via httpx.AsyncClient using the MCP SDK's own URL builders and response handlers (build_protected_resource_metadata_discovery_urls, handle_auth_metadata_response, etc.) so the SDK sees the correct token_endpoint before the first refresh attempt. Pre-flight is skipped when there are no stored tokens to keep fresh-install paths zero-cost.

  Test coverage (tests/tools/test_mcp_oauth_cold_load_expiry.py):
  - set_tokens persists absolute expires_at
  - set_tokens skips expires_at when token has no expires_in
  - get_tokens round-trips expires_at -> remaining expires_in
  - expired tokens reload with expires_in=0
  - legacy files without expires_at fall back to mtime proxy
  - _initialize seeds token_expiry_time from stored tokens
  - _initialize flags expired-on-disk tokens as is_token_valid=False
  - _initialize pre-flights PRM + ASM discovery with mock transport
  - _initialize skips pre-flight when no tokens are stored

  Verified against BetterStack MCP:
    hermes mcp test betterstack -> Connected (2508ms), 83 tools
    mcp_betterstack_telemetry_list_teams_tool -> real team data, not 'No teams found. Please check your authentication.'

  Reference: mcp-oauth-token-diagnosis skill, Fix A.

* chore: map hermes@noushq.ai to benbarclay in AUTHOR_MAP

  Needed for CI attribution check on cherry-picked commits from PR #12025.

---------

Co-authored-by: Hermes Agent <hermes@noushq.ai>
557 lines
23 KiB
Python
#!/usr/bin/env python3
"""Central manager for per-server MCP OAuth state.

One instance shared across the process. Holds per-server OAuth provider
instances and coordinates:

- **Cross-process token reload** via mtime-based disk watch. When an external
  process (e.g. a user cron job) refreshes tokens on disk, the next auth flow
  picks them up without requiring a process restart.
- **401 deduplication** via in-flight futures. When N concurrent tool calls
  all hit 401 with the same access_token, only one recovery attempt fires;
  the rest await the same result.
- **Reconnect signalling** for long-lived MCP sessions. The manager itself
  does not drive reconnection — the `MCPServerTask` in `mcp_tool.py` does —
  but the manager is the single source of truth that decides when reconnect
  is warranted.

Replaces what used to be scattered across eight call sites in `mcp_oauth.py`,
`mcp_tool.py`, and `hermes_cli/mcp_config.py`. This module is the ONLY place
that instantiates the MCP SDK's `OAuthClientProvider` — all other code paths
go through `get_manager()`.

Design reference:

- Claude Code's ``invalidateOAuthCacheIfDiskChanged``
  (``claude-code/src/utils/auth.ts:1320``, CC-1096 / GH#24317). Identical
  external-refresh staleness bug class.
- Codex's ``refresh_oauth_if_needed`` / ``persist_if_needed``
  (``codex-rs/rmcp-client/src/rmcp_client.rs:805``). We lean on the MCP SDK's
  lazy refresh rather than calling refresh before every op, because one
  ``stat()`` per tool call is cheaper than an ``await`` + potential refresh
  round-trip, and the SDK's in-memory expiry path is already correct.
"""

from __future__ import annotations

import asyncio
import logging
import threading
from dataclasses import dataclass, field
from typing import Any, Optional

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Per-server entry
# ---------------------------------------------------------------------------


@dataclass
class _ProviderEntry:
    """Per-server OAuth state tracked by the manager.

    Fields:
        server_url: The MCP server URL used to build the provider. Tracked
            so we can discard a cached provider if the URL changes.
        oauth_config: Optional dict from ``mcp_servers.<name>.oauth``.
        provider: The ``httpx.Auth``-compatible provider wrapping the MCP
            SDK. None until first use.
        last_mtime_ns: Last-seen ``st_mtime_ns`` of the on-disk tokens file.
            Zero if never read. Used by :meth:`MCPOAuthManager.invalidate_if_disk_changed`
            to detect external refreshes.
        lock: Serialises concurrent access to this entry's state. Bound to
            whichever asyncio loop first awaits it (the MCP event loop).
        pending_401: In-flight 401-handler futures keyed by the failed
            access_token, for deduplicating thundering-herd 401s. Mirrors
            Claude Code's ``pending401Handlers`` map.
    """

    server_url: str
    oauth_config: Optional[dict]
    provider: Optional[Any] = None
    last_mtime_ns: int = 0
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    pending_401: dict[str, "asyncio.Future[bool]"] = field(default_factory=dict)
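

# Illustrative sketch, not used by the manager: the change test behind
# ``last_mtime_ns``, reduced to a standalone helper. ``_SketchMtimeWatcher``
# is a hypothetical name. Like ``MCPOAuthManager.invalidate_if_disk_changed``,
# it reports a change whenever ``st_mtime_ns`` differs from the last value
# seen (so the first successful read counts as a change), and it treats a
# missing or unreadable file as "no change".
class _SketchMtimeWatcher:
    def __init__(self, path: str) -> None:
        self.path = path
        self.last_mtime_ns = 0  # zero means "never read", as in _ProviderEntry

    def changed(self) -> bool:
        """Return True if the file's mtime differs from the last one seen."""
        import os

        try:
            mtime_ns = os.stat(self.path).st_mtime_ns
        except OSError:  # includes FileNotFoundError
            return False
        if mtime_ns != self.last_mtime_ns:
            self.last_mtime_ns = mtime_ns
            return True
        return False


# Usage: watcher = _SketchMtimeWatcher("/path/to/tokens.json"); calling
# watcher.changed() once per auth flow costs a single stat() syscall.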


# ---------------------------------------------------------------------------
# HermesMCPOAuthProvider — OAuthClientProvider subclass with disk-watch
# ---------------------------------------------------------------------------


def _make_hermes_provider_class() -> Optional[type]:
    """Lazy-import the SDK base class and return our subclass.

    Wrapped in a function so this module imports cleanly even when the
    MCP SDK's OAuth module is unavailable (e.g. older mcp versions).
    """
    try:
        from mcp.client.auth.oauth2 import OAuthClientProvider
    except ImportError:  # pragma: no cover — SDK required in CI
        return None

    class HermesMCPOAuthProvider(OAuthClientProvider):
        """OAuthClientProvider with pre-flow disk-mtime reload.

        Before every ``async_auth_flow`` invocation, asks the manager to
        check whether the tokens file on disk has been modified externally.
        If so, the manager resets ``_initialized`` so the next flow
        re-reads from storage.

        This makes external-process refreshes (cron, another CLI instance)
        visible to the running MCP session without requiring a restart.

        Reference: Claude Code's ``invalidateOAuthCacheIfDiskChanged``
        (``src/utils/auth.ts:1320``, CC-1096 / GH#24317).
        """

        def __init__(self, *args: Any, server_name: str = "", **kwargs: Any):
            super().__init__(*args, **kwargs)
            self._hermes_server_name = server_name

        async def _initialize(self) -> None:
            """Load stored tokens + client info AND seed token_expiry_time.

            Also eagerly fetches OAuth authorization-server metadata (PRM +
            ASM) when we have stored tokens but no cached metadata, so the
            SDK's ``_refresh_token`` can build the correct token_endpoint
            URL on the preemptive-refresh path. Without this, the SDK
            falls back to ``{mcp_server_url}/token`` (wrong for providers
            whose AS is a different origin — BetterStack's MCP lives at
            ``https://mcp.betterstack.com`` but its token endpoint is at
            ``https://betterstack.com/oauth/token``), the refresh 404s, and
            we drop through to full browser reauth.

            The SDK's base ``_initialize`` populates ``current_tokens`` but
            does NOT call ``update_token_expiry``, so ``token_expiry_time``
            stays ``None`` and ``is_token_valid()`` returns True for any
            loaded token regardless of actual age. After a process restart
            this ships stale Bearer tokens to the server; some providers
            return HTTP 401 (caught by the 401 handler), others return 200
            with an app-level auth error (invisible to the transport layer,
            e.g. BetterStack returning "No teams found. Please check your
            authentication.").

            Seeding ``token_expiry_time`` from the reloaded token fixes that:
            ``is_token_valid()`` correctly reports False for expired tokens,
            ``async_auth_flow`` takes the ``can_refresh_token()`` branch,
            and the SDK quietly refreshes before the first real request.

            Paired with :class:`HermesTokenStorage` persisting an absolute
            ``expires_at`` timestamp (``mcp_oauth.py:set_tokens``) so the
            remaining TTL we compute here reflects real wall-clock age.
            """
            await super()._initialize()
            tokens = self.context.current_tokens
            if tokens is not None and tokens.expires_in is not None:
                self.context.update_token_expiry(tokens)

            # Pre-flight OAuth AS discovery so ``_refresh_token`` has a
            # correct ``token_endpoint`` before the first refresh attempt.
            # Only runs when we have tokens on cold-load but no cached
            # metadata — i.e. the exact scenario where the SDK's built-in
            # 401-branch discovery hasn't had a chance to run yet.
            if (
                tokens is not None
                and self.context.oauth_metadata is None
            ):
                try:
                    await self._prefetch_oauth_metadata()
                except Exception as exc:  # pragma: no cover — defensive
                    # Non-fatal: if discovery fails, the SDK's normal 401-
                    # branch discovery will run on the next request.
                    logger.debug(
                        "MCP OAuth '%s': pre-flight metadata discovery "
                        "failed (non-fatal): %s",
                        self._hermes_server_name, exc,
                    )

        async def _prefetch_oauth_metadata(self) -> None:
            """Fetch PRM + ASM from the well-known endpoints, cache on context.

            Mirrors the SDK's 401-branch discovery (oauth2.py ~line 511-551)
            but runs eagerly, before the first request, rather than inside
            the httpx auth_flow generator. Uses the SDK's own URL builders
            and response handlers so we stay in step with whatever the
            pinned SDK version expects.
            """
            import httpx  # local import: httpx is an MCP SDK dependency
            from mcp.client.auth.utils import (
                build_oauth_authorization_server_metadata_discovery_urls,
                build_protected_resource_metadata_discovery_urls,
                create_oauth_metadata_request,
                handle_auth_metadata_response,
                handle_protected_resource_response,
            )

            server_url = self.context.server_url
            async with httpx.AsyncClient(timeout=10.0) as client:
                # Step 1: PRM discovery to learn the authorization_server URL.
                for url in build_protected_resource_metadata_discovery_urls(
                    None, server_url
                ):
                    req = create_oauth_metadata_request(url)
                    try:
                        resp = await client.send(req)
                    except httpx.HTTPError as exc:
                        logger.debug(
                            "MCP OAuth '%s': PRM discovery to %s failed: %s",
                            self._hermes_server_name, url, exc,
                        )
                        continue
                    prm = await handle_protected_resource_response(resp)
                    if prm:
                        self.context.protected_resource_metadata = prm
                        if prm.authorization_servers:
                            self.context.auth_server_url = str(
                                prm.authorization_servers[0]
                            )
                        break

                # Step 2: ASM discovery against the auth_server_url (or
                # server_url fallback for legacy providers).
                for url in build_oauth_authorization_server_metadata_discovery_urls(
                    self.context.auth_server_url, server_url
                ):
                    req = create_oauth_metadata_request(url)
                    try:
                        resp = await client.send(req)
                    except httpx.HTTPError as exc:
                        logger.debug(
                            "MCP OAuth '%s': ASM discovery to %s failed: %s",
                            self._hermes_server_name, url, exc,
                        )
                        continue
                    ok, asm = await handle_auth_metadata_response(resp)
                    if not ok:
                        break
                    if asm:
                        self.context.oauth_metadata = asm
                        logger.debug(
                            "MCP OAuth '%s': pre-flight ASM discovered "
                            "token_endpoint=%s",
                            self._hermes_server_name, asm.token_endpoint,
                        )
                        break

        async def async_auth_flow(self, request):  # type: ignore[override]
            # Pre-flow hook: ask the manager to refresh from disk if needed.
            # Any failure here is non-fatal — we just log and proceed with
            # whatever state the SDK already has.
            try:
                await get_manager().invalidate_if_disk_changed(
                    self._hermes_server_name
                )
            except Exception as exc:  # pragma: no cover — defensive
                logger.debug(
                    "MCP OAuth '%s': pre-flow disk-watch failed (non-fatal): %s",
                    self._hermes_server_name, exc,
                )

            # Manually bridge the bidirectional generator protocol. httpx's
            # auth_flow driver (httpx._client._send_handling_auth) calls
            # ``auth_flow.asend(response)`` to feed HTTP responses back into
            # the generator. A naive wrapper using ``async for item in inner:
            # yield item`` DISCARDS those .asend(response) values and resumes
            # the inner generator with None, so the SDK's
            # ``response = yield request`` branch in
            # mcp/client/auth/oauth2.py sees response=None and crashes at
            # ``if response.status_code == 401`` with AttributeError.
            #
            # The bridge below forwards each .asend() value into the inner
            # generator via inner.asend(incoming), preserving the bidirectional
            # contract. Regression from PR #11383 caught by
            # tests/tools/test_mcp_oauth_bidirectional.py.
            inner = super().async_auth_flow(request)
            try:
                outgoing = await inner.__anext__()
                while True:
                    incoming = yield outgoing
                    outgoing = await inner.asend(incoming)
            except StopAsyncIteration:
                return

    return HermesMCPOAuthProvider
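

# Illustrative sketch, not used by the provider: why ``async_auth_flow``
# bridges ``.asend()`` by hand. ``_sketch_inner_flow`` is a hypothetical
# stand-in for the SDK flow: it yields a request, receives the response via
# ``.asend()``, and yields one retry on 401. ``_sketch_naive_wrapper``
# re-yields with ``async for``, so the inner generator is resumed with None
# and raises AttributeError (the original bug). ``_sketch_bridged_wrapper``
# uses the same forwarding loop as ``async_auth_flow``. ``_sketch_drive``
# plays the httpx role by sending a fake response object back in.
async def _sketch_inner_flow(request):
    response = yield request
    if response.status_code == 401:
        yield f"{request}+retry"


async def _sketch_naive_wrapper(request):
    async for item in _sketch_inner_flow(request):
        yield item  # whatever the driver .asend()s here is dropped


async def _sketch_bridged_wrapper(request):
    inner = _sketch_inner_flow(request)
    try:
        outgoing = await inner.__anext__()
        while True:
            incoming = yield outgoing
            outgoing = await inner.asend(incoming)
    except StopAsyncIteration:
        return


async def _sketch_drive(flow, status_code: int) -> list:
    """Drive ``flow`` like httpx: feed a response back after each yield."""
    from types import SimpleNamespace

    yielded = [await flow.__anext__()]
    try:
        while True:
            response = SimpleNamespace(status_code=status_code)
            yielded.append(await flow.asend(response))
    except StopAsyncIteration:
        pass
    return yielded


# Usage: asyncio.run(_sketch_drive(_sketch_bridged_wrapper("req"), 401))
# returns ["req", "req+retry"]; the naive wrapper raises AttributeError.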


# Cached at import time. Tested and used by :class:`MCPOAuthManager`.
_HERMES_PROVIDER_CLS: Optional[type] = _make_hermes_provider_class()
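

# Illustrative sketch of the storage half of the cold-load expiry fix. The
# real logic lives in ``HermesTokenStorage`` (``tools/mcp_oauth.py``), which
# this module does not show. Here tokens are plain dicts and both helper
# names are hypothetical. On write, an absolute ``expires_at`` is stamped
# next to the relative ``expires_in``. On read, ``expires_in`` is rebuilt as
# max(expires_at - now, 0). Legacy files without ``expires_at`` use the
# file's mtime as a stand-in for the write time.
def _sketch_stamp_expires_at(token: dict, now: float) -> dict:
    """Return a copy of ``token`` with an absolute ``expires_at`` added."""
    stamped = dict(token)
    if token.get("expires_in") is not None:
        stamped["expires_at"] = now + float(token["expires_in"])
    return stamped


def _sketch_reload_expires_in(stored: dict, now: float, file_mtime: float) -> dict:
    """Rebuild a token dict whose ``expires_in`` is the remaining TTL."""
    token = {k: v for k, v in stored.items() if k != "expires_at"}
    if stored.get("expires_in") is None:
        return token  # nothing to recompute
    expires_at = stored.get("expires_at")
    if expires_at is None:  # legacy file: approximate write time by mtime
        expires_at = file_mtime + float(stored["expires_in"])
    token["expires_in"] = max(int(expires_at - now), 0)  # clamp expired to 0
    return token


# Usage: a token stamped at now=1000.0 with expires_in=3600 reloads at
# now=2000.0 with expires_in == 2600, and at now=9999.0 with expires_in == 0.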


# ---------------------------------------------------------------------------
# Manager
# ---------------------------------------------------------------------------


class MCPOAuthManager:
    """Single source of truth for per-server MCP OAuth state.

    Thread-safe: the ``_entries`` dict is guarded by ``_entries_lock`` for
    get-or-create semantics. Per-entry state is guarded by the entry's own
    ``asyncio.Lock`` (used from the MCP event loop thread).
    """

    def __init__(self) -> None:
        self._entries: dict[str, _ProviderEntry] = {}
        self._entries_lock = threading.Lock()

    # -- Provider construction / caching -------------------------------------

    def get_or_build_provider(
        self,
        server_name: str,
        server_url: str,
        oauth_config: Optional[dict],
    ) -> Optional[Any]:
        """Return a cached OAuth provider for ``server_name`` or build one.

        Idempotent: repeat calls with the same name return the same instance.
        If ``server_url`` changes for a given name, the cached entry is
        discarded and a fresh provider is built.

        Returns None if the MCP SDK's OAuth support is unavailable.
        """
        with self._entries_lock:
            entry = self._entries.get(server_name)
            if entry is not None and entry.server_url != server_url:
                logger.info(
                    "MCP OAuth '%s': URL changed from %s to %s, discarding cache",
                    server_name, entry.server_url, server_url,
                )
                entry = None

            if entry is None:
                entry = _ProviderEntry(
                    server_url=server_url,
                    oauth_config=oauth_config,
                )
                self._entries[server_name] = entry

            if entry.provider is None:
                entry.provider = self._build_provider(server_name, entry)

            return entry.provider

    def _build_provider(
        self,
        server_name: str,
        entry: _ProviderEntry,
    ) -> Optional[Any]:
        """Build the underlying OAuth provider.

        Constructs :class:`HermesMCPOAuthProvider` directly using the helpers
        extracted from ``tools.mcp_oauth``. The subclass injects a pre-flow
        disk-watch hook so external token refreshes (cron, other CLI
        instances) are visible to running MCP sessions.

        Returns None if the MCP SDK's OAuth support is unavailable.
        """
        if _HERMES_PROVIDER_CLS is None:
            logger.warning(
                "MCP OAuth '%s': SDK auth module unavailable", server_name,
            )
            return None

        # Local imports avoid circular deps at module import time.
        from tools.mcp_oauth import (
            HermesTokenStorage,
            _OAUTH_AVAILABLE,
            _build_client_metadata,
            _configure_callback_port,
            _is_interactive,
            _maybe_preregister_client,
            _parse_base_url,
            _redirect_handler,
            _wait_for_callback,
        )

        if not _OAUTH_AVAILABLE:
            return None

        cfg = dict(entry.oauth_config or {})
        storage = HermesTokenStorage(server_name)

        if not _is_interactive() and not storage.has_cached_tokens():
            logger.warning(
                "MCP OAuth for '%s': non-interactive environment and no "
                "cached tokens found. Run interactively first to complete "
                "initial authorization.",
                server_name,
            )

        _configure_callback_port(cfg)
        client_metadata = _build_client_metadata(cfg)
        _maybe_preregister_client(storage, cfg, client_metadata)

        return _HERMES_PROVIDER_CLS(
            server_name=server_name,
            server_url=_parse_base_url(entry.server_url),
            client_metadata=client_metadata,
            storage=storage,
            redirect_handler=_redirect_handler,
            callback_handler=_wait_for_callback,
            timeout=float(cfg.get("timeout", 300)),
        )

    def remove(self, server_name: str) -> None:
        """Evict the provider from cache AND delete tokens from disk.

        Called by ``hermes mcp remove <name>`` and (indirectly) by
        ``hermes mcp login <name>`` during forced re-auth.
        """
        with self._entries_lock:
            self._entries.pop(server_name, None)

        from tools.mcp_oauth import remove_oauth_tokens
        remove_oauth_tokens(server_name)
        logger.info(
            "MCP OAuth '%s': evicted from cache and removed from disk",
            server_name,
        )

    # -- Disk watch ----------------------------------------------------------

    async def invalidate_if_disk_changed(self, server_name: str) -> bool:
        """If the tokens file on disk has a newer mtime than last-seen, force
        the MCP SDK provider to reload its in-memory state.

        Returns True if the cache was invalidated (mtime differed). This is
        the core fix for the external-refresh workflow: a cron job writes
        fresh tokens to disk, and on the next tool call the running MCP
        session picks them up without a restart.
        """
        from tools.mcp_oauth import _get_token_dir, _safe_filename

        entry = self._entries.get(server_name)
        if entry is None or entry.provider is None:
            return False

        async with entry.lock:
            tokens_path = _get_token_dir() / f"{_safe_filename(server_name)}.json"
            try:
                mtime_ns = tokens_path.stat().st_mtime_ns
            except (FileNotFoundError, OSError):
                return False

            if mtime_ns != entry.last_mtime_ns:
                old = entry.last_mtime_ns
                entry.last_mtime_ns = mtime_ns
                # Force the SDK's OAuthClientProvider to reload from storage
                # on its next auth flow. `_initialized` is private API but
                # stable across the MCP SDK versions we pin (>=1.26.0).
                if hasattr(entry.provider, "_initialized"):
                    entry.provider._initialized = False  # noqa: SLF001
                logger.info(
                    "MCP OAuth '%s': tokens file changed (mtime %d -> %d), "
                    "forcing reload",
                    server_name, old, mtime_ns,
                )
                return True
            return False

    # -- 401 handler (dedup'd) -----------------------------------------------

    async def handle_401(
        self,
        server_name: str,
        failed_access_token: Optional[str] = None,
    ) -> bool:
        """Handle a 401 from a tool call, deduplicated across concurrent callers.

        Returns:
            True if a (possibly new) access token is now available — caller
            should trigger a reconnect and retry the operation.
            False if no recovery path exists — caller should surface a
            ``needs_reauth`` error to the model so it stops hallucinating
            manual refresh attempts.

        Thundering-herd protection: if N concurrent tool calls hit 401 with
        the same ``failed_access_token``, only one recovery attempt fires.
        Others await the same future.
        """
        entry = self._entries.get(server_name)
        if entry is None or entry.provider is None:
            return False

        key = failed_access_token or "<unknown>"
        loop = asyncio.get_running_loop()

        async with entry.lock:
            pending = entry.pending_401.get(key)
            is_leader = pending is None
            if is_leader:
                pending = loop.create_future()
                entry.pending_401[key] = pending

        async def _do_handle() -> None:
            try:
                # Step 1: Did disk change? Picks up external refresh.
                disk_changed = await self.invalidate_if_disk_changed(
                    server_name
                )
                if disk_changed:
                    if not pending.done():
                        pending.set_result(True)
                    return

                # Step 2: No disk change — if the SDK can refresh
                # in-place, let the caller retry. The SDK's httpx.Auth
                # flow will issue the refresh on the next request.
                provider = entry.provider
                ctx = getattr(provider, "context", None)
                can_refresh = False
                if ctx is not None:
                    can_refresh_fn = getattr(ctx, "can_refresh_token", None)
                    if callable(can_refresh_fn):
                        try:
                            can_refresh = bool(can_refresh_fn())
                        except Exception:
                            can_refresh = False
                if not pending.done():
                    pending.set_result(can_refresh)
            except Exception as exc:  # pragma: no cover — defensive
                logger.warning(
                    "MCP OAuth '%s': 401 handler failed: %s",
                    server_name, exc,
                )
                if not pending.done():
                    pending.set_result(False)
            finally:
                entry.pending_401.pop(key, None)

        # Only the caller that created the future schedules a recovery
        # attempt; concurrent callers for the same token await its result.
        if is_leader:
            asyncio.create_task(_do_handle())

        try:
            return await pending
        except Exception as exc:  # pragma: no cover — defensive
            logger.warning(
                "MCP OAuth '%s': awaiting 401 handler failed: %s",
                server_name, exc,
            )
            return False
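

# Illustrative sketch of the thundering-herd guard used by ``handle_401``:
# the first caller for a given failed token creates a future and schedules
# one recovery task, and concurrent callers with the same token await that
# same future. ``_SketchDedup401`` and its ``recover`` callback (an async
# callable returning bool) are hypothetical. Construct the instance from
# inside the event loop that will use it, since it owns an ``asyncio.Lock``.
class _SketchDedup401:
    def __init__(self, recover) -> None:
        import asyncio

        self._recover = recover
        self._lock = asyncio.Lock()
        self._pending: dict = {}  # failed token -> shared future
        self._tasks: set = set()  # strong refs so tasks are not GC'd early

    async def handle(self, failed_token: str) -> bool:
        import asyncio

        async with self._lock:
            pending = self._pending.get(failed_token)
            if pending is None:  # leader: create the future, start recovery
                pending = asyncio.get_running_loop().create_future()
                self._pending[failed_token] = pending
                task = asyncio.create_task(self._run(failed_token, pending))
                self._tasks.add(task)
                task.add_done_callback(self._tasks.discard)
        return await pending

    async def _run(self, key: str, pending) -> None:
        try:
            ok = await self._recover()
            if not pending.done():
                pending.set_result(bool(ok))
        except Exception:
            if not pending.done():
                pending.set_result(False)
        finally:
            self._pending.pop(key, None)


# Usage: five concurrent dedup.handle("tok-A") calls run recover() once and
# all receive its result.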


# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------


_MANAGER: Optional[MCPOAuthManager] = None
_MANAGER_LOCK = threading.Lock()


def get_manager() -> MCPOAuthManager:
    """Return the process-wide :class:`MCPOAuthManager` singleton."""
    global _MANAGER
    with _MANAGER_LOCK:
        if _MANAGER is None:
            _MANAGER = MCPOAuthManager()
        return _MANAGER
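

# Illustrative sketch of the get-or-create rule in ``get_or_build_provider``:
# one cached provider per server name, rebuilt only when the URL registered
# for that name changes (or when a previous build returned None). The class
# name and its ``factory(name, url)`` callback are hypothetical stand-ins for
# ``MCPOAuthManager`` and ``_build_provider``.
class _SketchProviderCache:
    def __init__(self, factory) -> None:
        import threading

        self._factory = factory
        self._lock = threading.Lock()  # guards get-or-create across threads
        self._entries: dict = {}  # name -> (url, provider)

    def get(self, name: str, url: str):
        with self._lock:
            cached = self._entries.get(name)
            if cached is not None and cached[0] == url and cached[1] is not None:
                return cached[1]
            provider = self._factory(name, url)
            self._entries[name] = (url, provider)
            return provider


# Usage: cache.get("srv", url) returns the same object on repeat calls and a
# freshly built one after cache.get("srv", other_url).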


def reset_manager_for_tests() -> None:
    """Test-only helper: drop the singleton so fixtures start clean."""
    global _MANAGER
    with _MANAGER_LOCK:
        _MANAGER = None
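

# Illustrative sketch of the PRM candidate-URL loop in
# ``_prefetch_oauth_metadata``: try each well-known URL in order, skip
# transport failures, and stop at the first response that yields metadata.
# ``_sketch_first_metadata`` and its ``fetch`` callback (async, returning
# parsed metadata or None) are hypothetical. ``fetch`` stands in for
# ``client.send`` plus the SDK's response handler, and ``OSError`` stands in
# for ``httpx.HTTPError``.
async def _sketch_first_metadata(urls, fetch):
    for url in urls:
        try:
            metadata = await fetch(url)
        except OSError:
            continue  # unreachable candidate: try the next one
        if metadata:
            return url, metadata
    return None, None


# Usage: with candidates [refused, empty, ok], the third URL and its metadata
# are returned; if every candidate fails, (None, None) comes back.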