hermes-agent/tools/mcp_oauth.py
Teknium a3a4932405
fix(mcp-oauth): bidirectional auth_flow bridge + absolute expires_at (salvage #12025) (#12717)
* [verified] fix(mcp-oauth): bridge httpx auth_flow bidirectional generator

HermesMCPOAuthProvider.async_auth_flow wrapped the SDK's auth_flow with
'async for item in super().async_auth_flow(request): yield item', which
discards httpx's .asend(response) values and resumes the inner generator
with None. As a result, every OAuth MCP server failed on the first HTTP
response: the SDK crashed at mcp/client/auth/oauth2.py:505 with
"'NoneType' object has no attribute 'status_code'".

Replace with a manual bridge that forwards .asend() values into the
inner generator, preserving httpx's bidirectional auth_flow contract.
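
The difference between the two wrappers can be reproduced without httpx
or the MCP SDK. The following is a minimal sketch, not the code from this
commit: inner_flow, broken_wrapper, bridged_wrapper and drive are
hypothetical names, and responses are modelled as plain dicts.

```python
import asyncio


async def inner_flow(request):
    """Hypothetical stand-in for the SDK's auth flow (not the real SDK code)."""
    response = yield request  # expects the HTTP response to be sent back in
    if response is None:
        raise AttributeError("'NoneType' object has no attribute 'status_code'")
    if response["status_code"] == 401:
        # Retry once with a refreshed credential, then read the final response.
        yield {**request, "authorization": "Bearer refreshed"}


async def broken_wrapper(request):
    # 'async for' always resumes the inner generator with None, so anything
    # the caller passes to .asend() never reaches inner_flow.
    async for item in inner_flow(request):
        yield item


async def bridged_wrapper(request):
    # Forward each value sent into the wrapper on to the inner generator.
    inner = inner_flow(request)
    try:
        outgoing = await inner.__anext__()
        while True:
            incoming = yield outgoing
            outgoing = await inner.asend(incoming)
    except StopAsyncIteration:
        return
    finally:
        await inner.aclose()


async def drive(wrapper):
    """Play the caller's role: send a 401, then a 200, and record the requests."""
    flow = wrapper({"url": "https://mcp.example.com/mcp"})
    seen = [await flow.__anext__()]
    try:
        seen.append(await flow.asend({"status_code": 401}))
        await flow.asend({"status_code": 200})
    except StopAsyncIteration:
        pass
    return seen
```

Under these assumptions, asyncio.run(drive(bridged_wrapper)) returns the
original request plus the retried one, while
asyncio.run(drive(broken_wrapper)) raises the AttributeError because the
inner generator is resumed with None.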

Add tests/tools/test_mcp_oauth_bidirectional.py with two regression
tests that drive the flow through real .asend() round-trips. These
catch the bug at the unit level; prior tests only exercised
_initialize() and disk-watching, never the full generator protocol.

Verified against BetterStack MCP:
  Before: 'Connection failed (11564ms): NoneType...' after 3 retries
  After:  'Connected (2416ms); Tools discovered: 83'

Regression from #11383.

* [verified] fix(mcp-oauth): seed token_expiry_time + pre-flight AS discovery on cold-load

PR #11383's consolidation fixed external-refresh reloading and 401 dedup
but left two latent bugs that surfaced on BetterStack and any other OAuth
MCP with a split-origin authorization server:

1. HermesTokenStorage persisted only a relative 'expires_in', which is
   meaningless after a process restart. The MCP SDK's OAuthContext
   does NOT seed token_expiry_time in _initialize, so is_token_valid()
   returned True for any reloaded token regardless of age. Expired
   tokens shipped to servers, and app-level auth failures (e.g.
   BetterStack's 'No teams found. Please check your authentication.')
   were invisible to the transport-layer 401 handler.

2. Even once preemptive refresh did fire, the SDK's _refresh_token
   falls back to {server_url}/token when oauth_metadata isn't cached.
   For providers whose AS is at a different origin (BetterStack:
   mcp.betterstack.com for MCP, betterstack.com/oauth/token for the
   token endpoint), that fallback 404s and drops into full browser
   re-auth on every process restart.
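
To make the split-origin problem in item 2 concrete, here is a small
sketch. The helper names are hypothetical and the https scheme is an
assumption; only the two hosts come from the description above.

```python
from urllib.parse import urlparse


def fallback_token_url(server_url: str) -> str:
    """Guess the token endpoint as {server_url}/token (the fallback described above)."""
    return f"{server_url.rstrip('/')}/token"


def same_origin(url_a: str, url_b: str) -> bool:
    """Compare scheme and host[:port] of two URLs."""
    a, b = urlparse(url_a), urlparse(url_b)
    return (a.scheme, a.netloc) == (b.scheme, b.netloc)


# Assumed URLs built from the hosts named in the commit message.
guessed = fallback_token_url("https://mcp.betterstack.com")
actual = "https://betterstack.com/oauth/token"
print(guessed, same_origin(guessed, actual))
```

The guessed URL lives on the MCP host, not on the authorization server,
which is why the token endpoint has to come from discovered metadata.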

Fix set:

- HermesTokenStorage.set_tokens persists an absolute wall-clock
  expires_at alongside the SDK's OAuthToken JSON (time.time() + TTL
  at write time).
- HermesTokenStorage.get_tokens reconstructs expires_in from
  max(expires_at - now, 0), clamping expired tokens to zero TTL.
  Legacy files without expires_at fall back to file-mtime as a
  best-effort wall-clock proxy, self-healing on the next set_tokens.
- HermesMCPOAuthProvider._initialize calls super(), then
  update_token_expiry on the reloaded tokens so token_expiry_time
  reflects actual remaining TTL. If tokens are loaded but
  oauth_metadata is missing, pre-flight PRM + ASM discovery runs
  via httpx.AsyncClient using the MCP SDK's own URL builders and
  response handlers (build_protected_resource_metadata_discovery_urls,
  handle_auth_metadata_response, etc.) so the SDK sees the correct
  token_endpoint before the first refresh attempt. Pre-flight is
  skipped when there are no stored tokens to keep fresh-install
  paths zero-cost.
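
The expires_at arithmetic in the first two bullets can be sketched as
follows. This is an illustration on plain dicts with an injectable clock;
to_disk and from_disk are hypothetical names, not methods of
HermesTokenStorage.

```python
import time
from typing import Optional


def to_disk(token: dict, now: Optional[float] = None) -> dict:
    """Add an absolute expires_at next to the relative expires_in."""
    now = time.time() if now is None else now
    payload = dict(token)
    if payload.get("expires_in") is not None:
        payload["expires_at"] = now + int(payload["expires_in"])
    return payload


def from_disk(
    data: dict,
    now: Optional[float] = None,
    file_mtime: Optional[float] = None,
) -> dict:
    """Rebuild expires_in as the remaining TTL, clamped at zero."""
    now = time.time() if now is None else now
    data = dict(data)
    expires_at = data.pop("expires_at", None)
    if expires_at is not None:
        data["expires_in"] = int(max(expires_at - now, 0))
    elif data.get("expires_in") is not None and file_mtime is not None:
        # Legacy file: treat the file's mtime as the time the token was written.
        implied_expiry = file_mtime + int(data["expires_in"])
        data["expires_in"] = int(max(implied_expiry - now, 0))
    return data


saved = to_disk({"access_token": "t", "expires_in": 3600}, now=1000.0)
print(from_disk(saved, now=2000.0))   # 1000 s after the write
print(from_disk(saved, now=10000.0))  # long after expiry
```

With a token written at t=1000 and a 3600 s TTL, a reload at t=2000
leaves 2600 s, and a reload at t=10000 clamps to 0 so a refresh is
attempted before the first request.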

Test coverage (tests/tools/test_mcp_oauth_cold_load_expiry.py):
- set_tokens persists absolute expires_at
- set_tokens skips expires_at when token has no expires_in
- get_tokens round-trips expires_at -> remaining expires_in
- expired tokens reload with expires_in=0
- legacy files without expires_at fall back to mtime proxy
- _initialize seeds token_expiry_time from stored tokens
- _initialize flags expired-on-disk tokens as is_token_valid=False
- _initialize pre-flights PRM + ASM discovery with mock transport
- _initialize skips pre-flight when no tokens are stored

Verified against BetterStack MCP:
  hermes mcp test betterstack -> Connected (2508ms), 83 tools
  mcp_betterstack_telemetry_list_teams_tool -> real team data, not
    'No teams found. Please check your authentication.'

Reference: mcp-oauth-token-diagnosis skill, Fix A.

* chore: map hermes@noushq.ai to benbarclay in AUTHOR_MAP

Needed for CI attribution check on cherry-picked commits from PR #12025.

---------

Co-authored-by: Hermes Agent <hermes@noushq.ai>
2026-04-19 16:31:07 -07:00


#!/usr/bin/env python3
"""
MCP OAuth 2.1 Client Support

Implements the browser-based OAuth 2.1 authorization code flow with PKCE
for MCP servers that require OAuth authentication instead of static bearer
tokens.

Uses the MCP Python SDK's ``OAuthClientProvider`` (an ``httpx.Auth`` subclass)
which handles discovery, dynamic client registration, PKCE, token exchange,
refresh, and step-up authorization automatically.

This module provides the glue:
    - ``HermesTokenStorage``: persists tokens/client-info to disk so they
      survive across process restarts.
    - Callback server: ephemeral localhost HTTP server to capture the OAuth
      redirect with the authorization code.
    - ``build_oauth_auth()``: entry point called by ``mcp_tool.py`` that wires
      everything together and returns the ``httpx.Auth`` object.

Configuration in config.yaml::

    mcp_servers:
      my_server:
        url: "https://mcp.example.com/mcp"
        auth: oauth
        oauth:                                  # all fields optional
          client_id: "pre-registered-id"        # skip dynamic registration
          client_secret: "secret"               # confidential clients only
          scope: "read write"                   # default: server-provided
          redirect_port: 0                      # 0 = auto-pick free port
          client_name: "My Custom Client"       # default: "Hermes Agent"
"""
import asyncio
import json
import logging
import os
import re
import socket
import sys
import threading
import time
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Any
from urllib.parse import parse_qs, urlparse
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Lazy imports -- MCP SDK with OAuth support is optional
# ---------------------------------------------------------------------------
_OAUTH_AVAILABLE = False
try:
    from mcp.client.auth import OAuthClientProvider
    from mcp.shared.auth import (
        OAuthClientInformationFull,
        OAuthClientMetadata,
        OAuthToken,
    )
    from pydantic import AnyUrl

    _OAUTH_AVAILABLE = True
except ImportError:
    logger.debug("MCP OAuth types not available -- OAuth MCP auth disabled")
# ---------------------------------------------------------------------------
# Exceptions
# ---------------------------------------------------------------------------
class OAuthNonInteractiveError(RuntimeError):
    """Raised when OAuth requires browser interaction in a non-interactive env."""
# ---------------------------------------------------------------------------
# Module-level state
# ---------------------------------------------------------------------------
# Port used by the most recent build_oauth_auth() call. Exposed so that
# tests can verify the callback server and the redirect_uri share a port.
_oauth_port: int | None = None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_token_dir() -> Path:
    """Return the directory for MCP OAuth token files.

    Uses HERMES_HOME so each profile gets its own OAuth tokens.
    Layout: ``HERMES_HOME/mcp-tokens/``
    """
    try:
        from hermes_constants import get_hermes_home
        base = Path(get_hermes_home())
    except ImportError:
        base = Path(os.environ.get("HERMES_HOME", str(Path.home() / ".hermes")))
    return base / "mcp-tokens"


def _safe_filename(name: str) -> str:
    """Sanitize a server name for use as a filename (no path separators)."""
    return re.sub(r"[^\w\-]", "_", name).strip("_")[:128] or "default"


def _find_free_port() -> int:
    """Find an available TCP port on localhost."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def _is_interactive() -> bool:
    """Return True if we can reasonably expect to interact with a user."""
    try:
        return sys.stdin.isatty()
    except (AttributeError, ValueError):
        return False


def _can_open_browser() -> bool:
    """Return True if opening a browser is likely to work."""
    # Explicit SSH session → no local display
    if os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"):
        return False
    # macOS and Windows usually have a display
    if os.name == "nt":
        return True
    try:
        if os.uname().sysname == "Darwin":
            return True
    except AttributeError:
        pass
    # Linux/other posix: need DISPLAY or WAYLAND_DISPLAY
    if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"):
        return True
    return False


def _read_json(path: Path) -> dict | None:
    """Read a JSON file, returning None if it doesn't exist or is invalid."""
    if not path.exists():
        return None
    try:
        return json.loads(path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, OSError) as exc:
        logger.warning("Failed to read %s: %s", path, exc)
        return None


def _write_json(path: Path, data: dict) -> None:
    """Write a dict as JSON with restricted permissions (0o600)."""
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(".tmp")
    try:
        tmp.write_text(json.dumps(data, indent=2, default=str), encoding="utf-8")
        os.chmod(tmp, 0o600)
        # Path.replace() overwrites an existing target on every platform;
        # Path.rename() raises on Windows when the target already exists.
        tmp.replace(path)
    except OSError:
        tmp.unlink(missing_ok=True)
        raise
# ---------------------------------------------------------------------------
# HermesTokenStorage -- persistent token/client-info on disk
# ---------------------------------------------------------------------------
class HermesTokenStorage:
    """Persist OAuth tokens and client registration to JSON files.

    File layout::

        HERMES_HOME/mcp-tokens/<server_name>.json         -- tokens
        HERMES_HOME/mcp-tokens/<server_name>.client.json  -- client info
    """

    def __init__(self, server_name: str):
        self._server_name = _safe_filename(server_name)

    def _tokens_path(self) -> Path:
        return _get_token_dir() / f"{self._server_name}.json"

    def _client_info_path(self) -> Path:
        return _get_token_dir() / f"{self._server_name}.client.json"

    # -- tokens ------------------------------------------------------------

    async def get_tokens(self) -> "OAuthToken | None":
        data = _read_json(self._tokens_path())
        if data is None:
            return None
        # Hermes records an absolute wall-clock ``expires_at`` alongside the
        # SDK's serialized token (see ``set_tokens``). On read we rewrite
        # ``expires_in`` to the remaining seconds so the SDK's downstream
        # ``update_token_expiry`` computes the correct absolute time and
        # ``is_token_valid()`` correctly reports False for tokens that
        # expired while the process was down.
        #
        # Legacy token files (pre-Fix-A) have ``expires_in`` but no
        # ``expires_at``. We fall back to the file's mtime as a best-effort
        # wall-clock proxy for when the token was written: if (mtime +
        # expires_in) is in the past, clamp ``expires_in`` to zero so the
        # SDK refreshes before the first request. This self-heals one-time
        # on the next successful ``set_tokens``, which writes the new
        # ``expires_at`` field. The stored ``expires_at`` is stripped before
        # model_validate because it's not part of the SDK's OAuthToken schema.
        absolute_expiry = data.pop("expires_at", None)
        if absolute_expiry is not None:
            data["expires_in"] = int(max(absolute_expiry - time.time(), 0))
        elif data.get("expires_in") is not None:
            try:
                file_mtime = self._tokens_path().stat().st_mtime
            except OSError:
                file_mtime = None
            if file_mtime is not None:
                try:
                    implied_expiry = file_mtime + int(data["expires_in"])
                    data["expires_in"] = int(max(implied_expiry - time.time(), 0))
                except (TypeError, ValueError):
                    pass
        try:
            return OAuthToken.model_validate(data)
        except (ValueError, TypeError, KeyError) as exc:
            logger.warning("Corrupt tokens at %s -- ignoring: %s", self._tokens_path(), exc)
            return None

    async def set_tokens(self, tokens: "OAuthToken") -> None:
        payload = tokens.model_dump(exclude_none=True)
        # Persist an absolute ``expires_at`` so a process restart can
        # reconstruct the correct remaining TTL. Without this the MCP SDK's
        # ``_initialize`` reloads a relative ``expires_in`` which has no
        # wall-clock reference, leaving ``context.token_expiry_time=None``
        # and ``is_token_valid()`` falsely reporting True. See Fix A in
        # ``mcp-oauth-token-diagnosis`` skill + Claude Code's
        # ``OAuthTokens.expiresAt`` persistence (auth.ts ~180).
        expires_in = payload.get("expires_in")
        if expires_in is not None:
            try:
                payload["expires_at"] = time.time() + int(expires_in)
            except (TypeError, ValueError):
                # Mock tokens or unusual shapes: skip the expires_at write
                # rather than fail persistence.
                pass
        _write_json(self._tokens_path(), payload)
        logger.debug("OAuth tokens saved for %s", self._server_name)

    # -- client info -------------------------------------------------------

    async def get_client_info(self) -> "OAuthClientInformationFull | None":
        data = _read_json(self._client_info_path())
        if data is None:
            return None
        try:
            return OAuthClientInformationFull.model_validate(data)
        except (ValueError, TypeError, KeyError) as exc:
            logger.warning("Corrupt client info at %s -- ignoring: %s", self._client_info_path(), exc)
            return None

    async def set_client_info(self, client_info: "OAuthClientInformationFull") -> None:
        _write_json(self._client_info_path(), client_info.model_dump(exclude_none=True))
        logger.debug("OAuth client info saved for %s", self._server_name)

    # -- cleanup -----------------------------------------------------------

    def remove(self) -> None:
        """Delete all stored OAuth state for this server."""
        for p in (self._tokens_path(), self._client_info_path()):
            p.unlink(missing_ok=True)

    def has_cached_tokens(self) -> bool:
        """Return True if we have tokens on disk (may be expired)."""
        return self._tokens_path().exists()
# ---------------------------------------------------------------------------
# Callback handler factory -- each invocation gets its own result dict
# ---------------------------------------------------------------------------
def _make_callback_handler() -> tuple[type, dict]:
    """Create a per-flow callback HTTP handler class with its own result dict.

    Returns ``(HandlerClass, result_dict)`` where *result_dict* is a mutable
    dict that the handler writes ``auth_code`` and ``state`` into when the
    OAuth redirect arrives. Each call returns a fresh pair so concurrent
    flows don't stomp on each other.
    """
    result: dict[str, Any] = {"auth_code": None, "state": None, "error": None}

    class _Handler(BaseHTTPRequestHandler):
        def do_GET(self) -> None:  # noqa: N802
            params = parse_qs(urlparse(self.path).query)
            code = params.get("code", [None])[0]
            state = params.get("state", [None])[0]
            error = params.get("error", [None])[0]
            result["auth_code"] = code
            result["state"] = state
            result["error"] = error
            body = (
                "<html><body><h2>Authorization Successful</h2>"
                "<p>You can close this tab and return to Hermes.</p></body></html>"
            ) if code else (
                "<html><body><h2>Authorization Failed</h2>"
                f"<p>Error: {error or 'unknown'}</p></body></html>"
            )
            self.send_response(200)
            self.send_header("Content-Type", "text/html; charset=utf-8")
            self.end_headers()
            self.wfile.write(body.encode())

        def log_message(self, fmt: str, *args: Any) -> None:
            logger.debug("OAuth callback: %s", fmt % args)

    return _Handler, result
# ---------------------------------------------------------------------------
# Async redirect + callback handlers for OAuthClientProvider
# ---------------------------------------------------------------------------
async def _redirect_handler(authorization_url: str) -> None:
    """Show the authorization URL to the user.

    Opens the browser automatically when possible; always prints the URL
    as a fallback for headless/SSH/gateway environments.
    """
    msg = (
        f"\n MCP OAuth: authorization required.\n"
        f" Open this URL in your browser:\n\n"
        f" {authorization_url}\n"
    )
    print(msg, file=sys.stderr)
    if _can_open_browser():
        try:
            opened = webbrowser.open(authorization_url)
            if opened:
                print(" (Browser opened automatically.)\n", file=sys.stderr)
            else:
                print(" (Could not open browser -- please open the URL manually.)\n", file=sys.stderr)
        except Exception:
            print(" (Could not open browser -- please open the URL manually.)\n", file=sys.stderr)
    else:
        print(" (Headless environment detected -- open the URL manually.)\n", file=sys.stderr)


async def _wait_for_callback() -> tuple[str, str | None]:
    """Wait for the OAuth callback to arrive on the local callback server.

    Uses the module-level ``_oauth_port`` which is set by ``build_oauth_auth``
    before this is ever called. Polls for the result without blocking the
    event loop.

    Raises:
        OAuthNonInteractiveError: If the callback port cannot be bound, or
            the callback times out (no user present to complete the
            browser auth).
        RuntimeError: If the authorization server redirects back with an
            ``error`` parameter.
    """
    assert _oauth_port is not None, "OAuth callback port not set"
    handler_cls, result = _make_callback_handler()
    # Bind a one-shot callback server on the port the redirect_uri points at.
    try:
        server = HTTPServer(("127.0.0.1", _oauth_port), handler_cls)
    except OSError as exc:
        # The port could not be bound (for example, another flow is already
        # listening on it), so no callback can be received here.
        raise OAuthNonInteractiveError(
            "OAuth callback failed: could not bind callback port. "
            "Complete the authorization in a browser first, then retry."
        ) from exc
    # handle_request() serves exactly one request, then the thread exits.
    server_thread = threading.Thread(target=server.handle_request, daemon=True)
    server_thread.start()
    timeout = 300.0
    poll_interval = 0.5
    elapsed = 0.0
    try:
        # Poll the shared result dict; sleeping keeps the event loop free.
        while elapsed < timeout:
            if result["auth_code"] is not None or result["error"] is not None:
                break
            await asyncio.sleep(poll_interval)
            elapsed += poll_interval
    finally:
        server.server_close()
    if result["error"]:
        raise RuntimeError(f"OAuth authorization failed: {result['error']}")
    if result["auth_code"] is None:
        raise OAuthNonInteractiveError(
            "OAuth callback timed out: no authorization code received. "
            "Ensure you completed the browser authorization flow."
        )
    return result["auth_code"], result["state"]
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def remove_oauth_tokens(server_name: str) -> None:
    """Delete stored OAuth tokens and client info for a server."""
    storage = HermesTokenStorage(server_name)
    storage.remove()
    logger.info("OAuth tokens removed for '%s'", server_name)
# ---------------------------------------------------------------------------
# Extracted helpers (Task 3 of MCP OAuth consolidation)
#
# These compose into ``build_oauth_auth`` below, and are also used by
# ``tools.mcp_oauth_manager.MCPOAuthManager._build_provider`` so the two
# construction paths share one implementation.
# ---------------------------------------------------------------------------
def _configure_callback_port(cfg: dict) -> int:
    """Pick or validate the OAuth callback port.

    Stores the resolved port into ``cfg['_resolved_port']`` so sibling
    helpers (and the manager) can read it from the same dict. Returns the
    resolved port.

    NOTE: also sets the legacy module-level ``_oauth_port`` so existing
    calls to ``_wait_for_callback`` keep working. The legacy global is
    the root cause of issue #5344 (port collision on concurrent OAuth
    flows); replacing it with a ContextVar is out of scope for this
    consolidation PR.
    """
    global _oauth_port
    requested = int(cfg.get("redirect_port", 0))
    port = _find_free_port() if requested == 0 else requested
    cfg["_resolved_port"] = port
    _oauth_port = port  # legacy consumer: _wait_for_callback reads this
    return port


def _build_client_metadata(cfg: dict) -> "OAuthClientMetadata":
    """Build OAuthClientMetadata from the oauth config dict.

    Requires ``cfg['_resolved_port']`` to have been populated by
    :func:`_configure_callback_port` first.
    """
    port = cfg.get("_resolved_port")
    if port is None:
        raise ValueError(
            "_configure_callback_port() must be called before _build_client_metadata()"
        )
    client_name = cfg.get("client_name", "Hermes Agent")
    scope = cfg.get("scope")
    redirect_uri = f"http://127.0.0.1:{port}/callback"
    metadata_kwargs: dict[str, Any] = {
        "client_name": client_name,
        "redirect_uris": [AnyUrl(redirect_uri)],
        "grant_types": ["authorization_code", "refresh_token"],
        "response_types": ["code"],
        "token_endpoint_auth_method": "none",
    }
    if scope:
        metadata_kwargs["scope"] = scope
    if cfg.get("client_secret"):
        metadata_kwargs["token_endpoint_auth_method"] = "client_secret_post"
    return OAuthClientMetadata.model_validate(metadata_kwargs)


def _maybe_preregister_client(
    storage: "HermesTokenStorage",
    cfg: dict,
    client_metadata: "OAuthClientMetadata",
) -> None:
    """If cfg has a pre-registered client_id, persist it to storage."""
    client_id = cfg.get("client_id")
    if not client_id:
        return
    port = cfg["_resolved_port"]
    redirect_uri = f"http://127.0.0.1:{port}/callback"
    info_dict: dict[str, Any] = {
        "client_id": client_id,
        "redirect_uris": [redirect_uri],
        "grant_types": client_metadata.grant_types,
        "response_types": client_metadata.response_types,
        "token_endpoint_auth_method": client_metadata.token_endpoint_auth_method,
    }
    if cfg.get("client_secret"):
        info_dict["client_secret"] = cfg["client_secret"]
    if cfg.get("client_name"):
        info_dict["client_name"] = cfg["client_name"]
    if cfg.get("scope"):
        info_dict["scope"] = cfg["scope"]
    client_info = OAuthClientInformationFull.model_validate(info_dict)
    _write_json(storage._client_info_path(), client_info.model_dump(exclude_none=True))
    logger.debug("Pre-registered client_id=%s for '%s'", client_id, storage._server_name)


def _parse_base_url(server_url: str) -> str:
    """Strip path component from server URL, returning the base origin."""
    parsed = urlparse(server_url)
    return f"{parsed.scheme}://{parsed.netloc}"


def build_oauth_auth(
    server_name: str,
    server_url: str,
    oauth_config: dict | None = None,
) -> "OAuthClientProvider | None":
    """Build an ``httpx.Auth``-compatible OAuth handler for an MCP server.

    Public API preserved for backwards compatibility. New code should use
    :func:`tools.mcp_oauth_manager.get_manager` so OAuth state is shared
    across config-time, runtime, and reconnect paths.

    Args:
        server_name: Server key in mcp_servers config (used for storage).
        server_url: MCP server endpoint URL.
        oauth_config: Optional dict from the ``oauth:`` block in config.yaml.

    Returns:
        An ``OAuthClientProvider`` instance, or None if the MCP SDK lacks
        OAuth support.
    """
    if not _OAUTH_AVAILABLE:
        logger.warning(
            "MCP OAuth requested for '%s' but SDK auth types are not available. "
            "Install with: pip install 'mcp>=1.26.0'",
            server_name,
        )
        return None
    cfg = dict(oauth_config or {})  # copy: we mutate _resolved_port
    storage = HermesTokenStorage(server_name)
    if not _is_interactive() and not storage.has_cached_tokens():
        logger.warning(
            "MCP OAuth for '%s': non-interactive environment and no cached tokens "
            "found. The OAuth flow requires browser authorization. Run "
            "interactively first to complete the initial authorization, then "
            "cached tokens will be reused.",
            server_name,
        )
    _configure_callback_port(cfg)
    client_metadata = _build_client_metadata(cfg)
    _maybe_preregister_client(storage, cfg, client_metadata)
    return OAuthClientProvider(
        server_url=_parse_base_url(server_url),
        client_metadata=client_metadata,
        storage=storage,
        redirect_handler=_redirect_handler,
        callback_handler=_wait_for_callback,
        timeout=float(cfg.get("timeout", 300)),
    )