feat: implement MCP OAuth 2.1 PKCE client support (#5420)

Implement tools/mcp_oauth.py — the OAuth adapter that mcp_tool.py's
existing auth: oauth hook has been waiting for.

Components:
- HermesTokenStorage: persists tokens + client registration to
  HERMES_HOME/mcp-tokens/<server>.json with 0o600 permissions
- Callback handler factory: per-flow isolated HTTP handlers (safe for
  concurrent OAuth flows across multiple MCP servers)
- OAuthClientProvider integration: wraps the MCP SDK's httpx.Auth
  subclass which handles discovery, DCR, PKCE, token exchange,
  refresh, and step-up auth (403 insufficient_scope) automatically
- Non-interactive detection: warns when gateway/cron environments
  try to OAuth without cached tokens
- Pre-registered client support: injects client_id/secret from config
  for servers that don't support Dynamic Client Registration (e.g. Slack)
- Path traversal protection on server names
- remove_oauth_tokens() for cleanup

Config format:
  mcp_servers:
    sentry:
      url: 'https://mcp.sentry.dev/mcp'
      auth: oauth
      oauth:                          # all optional
        client_id: '...'              # skip DCR
        client_secret: '...'          # confidential client
        scope: 'read write'           # server-provided by default

Also passes oauth config dict through from mcp_tool.py (was passing
only server_name and url before).

E2E verified: full OAuth flow (401 → discovery → DCR → authorize →
token exchange → authenticated request → tokens persisted) against
local test servers. 23 unit tests + 186 MCP suite tests pass.
This commit is contained in:
Teknium 2026-04-05 22:08:00 -07:00 committed by GitHub
parent 3962bc84b7
commit 38d8446011
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 547 additions and 293 deletions

View file

@ -1,7 +1,8 @@
"""Tests for tools/mcp_oauth.py — thin OAuth adapter over MCP SDK."""
"""Tests for tools/mcp_oauth.py — OAuth 2.1 PKCE support for MCP servers."""
import json
import os
from io import BytesIO
from pathlib import Path
from unittest.mock import patch, MagicMock, AsyncMock
@ -16,6 +17,7 @@ from tools.mcp_oauth import (
_can_open_browser,
_is_interactive,
_wait_for_callback,
_make_callback_handler,
)
@ -79,34 +81,93 @@ class TestHermesTokenStorage:
assert not (d / "test-server.json").exists()
assert not (d / "test-server.client.json").exists()
def test_has_cached_tokens(self, tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
storage = HermesTokenStorage("my-server")
assert not storage.has_cached_tokens()
d = tmp_path / "mcp-tokens"
d.mkdir(parents=True)
(d / "my-server.json").write_text('{"access_token": "x", "token_type": "Bearer"}')
assert storage.has_cached_tokens()
def test_corrupt_tokens_returns_none(self, tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
storage = HermesTokenStorage("bad-server")
d = tmp_path / "mcp-tokens"
d.mkdir(parents=True)
(d / "bad-server.json").write_text("NOT VALID JSON{{{")
import asyncio
assert asyncio.run(storage.get_tokens()) is None
def test_corrupt_client_info_returns_none(self, tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
storage = HermesTokenStorage("bad-server")
d = tmp_path / "mcp-tokens"
d.mkdir(parents=True)
(d / "bad-server.client.json").write_text("GARBAGE")
import asyncio
assert asyncio.run(storage.get_client_info()) is None
# ---------------------------------------------------------------------------
# build_oauth_auth
# ---------------------------------------------------------------------------
class TestBuildOAuthAuth:
def test_returns_oauth_provider(self):
def test_returns_oauth_provider(self, tmp_path, monkeypatch):
try:
from mcp.client.auth import OAuthClientProvider
except ImportError:
pytest.skip("MCP SDK auth not available")
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
auth = build_oauth_auth("test", "https://example.com/mcp")
assert isinstance(auth, OAuthClientProvider)
def test_returns_none_without_sdk(self, monkeypatch):
import tools.mcp_oauth as mod
orig_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__
monkeypatch.setattr(mod, "_OAUTH_AVAILABLE", False)
result = build_oauth_auth("test", "https://example.com")
assert result is None
def _block_import(name, *args, **kwargs):
if "mcp.client.auth" in name:
raise ImportError("blocked")
return orig_import(name, *args, **kwargs)
def test_pre_registered_client_id_stored(self, tmp_path, monkeypatch):
try:
from mcp.client.auth import OAuthClientProvider
except ImportError:
pytest.skip("MCP SDK auth not available")
with patch("builtins.__import__", side_effect=_block_import):
result = build_oauth_auth("test", "https://example.com")
# May or may not be None depending on import caching, but shouldn't crash
assert result is None or result is not None
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
build_oauth_auth("slack", "https://slack.example.com/mcp", {
"client_id": "my-app-id",
"client_secret": "my-secret",
"scope": "channels:read",
})
client_path = tmp_path / "mcp-tokens" / "slack.client.json"
assert client_path.exists()
data = json.loads(client_path.read_text())
assert data["client_id"] == "my-app-id"
assert data["client_secret"] == "my-secret"
def test_scope_passed_through(self, tmp_path, monkeypatch):
try:
from mcp.client.auth import OAuthClientProvider
except ImportError:
pytest.skip("MCP SDK auth not available")
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
provider = build_oauth_auth("scoped", "https://example.com/mcp", {
"scope": "read write admin",
})
assert provider is not None
assert provider.context.client_metadata.scope == "read write admin"
# ---------------------------------------------------------------------------
@ -119,6 +180,12 @@ class TestUtilities:
assert isinstance(port, int)
assert 1024 <= port <= 65535
def test_find_free_port_unique(self):
"""Two consecutive calls should return different ports (usually)."""
ports = {_find_free_port() for _ in range(5)}
# At least 2 different ports out of 5 attempts
assert len(ports) >= 2
def test_can_open_browser_false_in_ssh(self, monkeypatch):
monkeypatch.setenv("SSH_CLIENT", "1.2.3.4 1234 22")
assert _can_open_browser() is False
@ -127,14 +194,22 @@ class TestUtilities:
monkeypatch.delenv("SSH_CLIENT", raising=False)
monkeypatch.delenv("SSH_TTY", raising=False)
monkeypatch.delenv("DISPLAY", raising=False)
monkeypatch.delenv("WAYLAND_DISPLAY", raising=False)
# Mock os.name and uname for non-macOS, non-Windows
monkeypatch.setattr(os, "name", "posix")
monkeypatch.setattr(os, "uname", lambda: type("", (), {"sysname": "Linux"})())
assert _can_open_browser() is False
def test_can_open_browser_true_with_display(self, monkeypatch):
monkeypatch.delenv("SSH_CLIENT", raising=False)
monkeypatch.delenv("SSH_TTY", raising=False)
monkeypatch.setenv("DISPLAY", ":0")
monkeypatch.setattr(os, "name", "posix")
assert _can_open_browser() is True
# ---------------------------------------------------------------------------
# remove_oauth_tokens
# Path traversal protection
# ---------------------------------------------------------------------------
class TestPathTraversal:
@ -169,11 +244,14 @@ class TestPathTraversal:
assert "/" not in path.stem
# ---------------------------------------------------------------------------
# Callback handler isolation
# ---------------------------------------------------------------------------
class TestCallbackHandlerIsolation:
"""Verify concurrent OAuth flows don't share state."""
def test_independent_result_dicts(self):
from tools.mcp_oauth import _make_callback_handler
_, result_a = _make_callback_handler()
_, result_b = _make_callback_handler()
@ -184,10 +262,6 @@ class TestCallbackHandlerIsolation:
assert result_b["auth_code"] == "code_B"
def test_handler_writes_to_own_result(self):
from tools.mcp_oauth import _make_callback_handler
from io import BytesIO
from unittest.mock import MagicMock
HandlerClass, result = _make_callback_handler()
assert result["auth_code"] is None
@ -203,13 +277,30 @@ class TestCallbackHandlerIsolation:
assert result["auth_code"] == "test123"
assert result["state"] == "mystate"
def test_handler_captures_error(self):
HandlerClass, result = _make_callback_handler()
handler = HandlerClass.__new__(HandlerClass)
handler.path = "/callback?error=access_denied"
handler.wfile = BytesIO()
handler.send_response = MagicMock()
handler.send_header = MagicMock()
handler.end_headers = MagicMock()
handler.do_GET()
assert result["auth_code"] is None
assert result["error"] == "access_denied"
# ---------------------------------------------------------------------------
# Port sharing
# ---------------------------------------------------------------------------
class TestOAuthPortSharing:
"""Verify build_oauth_auth and _wait_for_callback use the same port."""
def test_port_stored_globally(self):
def test_port_stored_globally(self, tmp_path, monkeypatch):
import tools.mcp_oauth as mod
# Reset
mod._oauth_port = None
try:
@ -217,12 +308,17 @@ class TestOAuthPortSharing:
except ImportError:
pytest.skip("MCP SDK auth not available")
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
build_oauth_auth("test-port", "https://example.com/mcp")
assert mod._oauth_port is not None
assert isinstance(mod._oauth_port, int)
assert 1024 <= mod._oauth_port <= 65535
# ---------------------------------------------------------------------------
# remove_oauth_tokens
# ---------------------------------------------------------------------------
class TestRemoveOAuthTokens:
def test_removes_files(self, tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
@ -242,7 +338,7 @@ class TestRemoveOAuthTokens:
# ---------------------------------------------------------------------------
# Non-interactive / startup-safety tests (issue #4462)
# Non-interactive / startup-safety tests
# ---------------------------------------------------------------------------
class TestIsInteractive:

View file

@ -1,326 +1,482 @@
"""Thin OAuth adapter for MCP HTTP servers.
Wraps the MCP SDK's built-in ``OAuthClientProvider`` (which implements
``httpx.Auth``) with Hermes-specific token storage and browser-based
authorization. The SDK handles all of the heavy lifting: PKCE generation,
metadata discovery, dynamic client registration, token exchange, and refresh.
Startup safety:
The callback handler never calls blocking ``input()`` on the event loop.
In non-interactive environments (no TTY, SSH, headless), the OAuth flow
raises ``OAuthNonInteractiveError`` instead of blocking, so that the
server degrades gracefully and other MCP servers are not affected.
Usage in mcp_tool.py::
from tools.mcp_oauth import build_oauth_auth
auth = build_oauth_auth(server_name, server_url)
# pass ``auth`` as the httpx auth parameter
#!/usr/bin/env python3
"""
MCP OAuth 2.1 Client Support
from __future__ import annotations
Implements the browser-based OAuth 2.1 authorization code flow with PKCE
for MCP servers that require OAuth authentication instead of static bearer
tokens.
Uses the MCP Python SDK's ``OAuthClientProvider`` (an ``httpx.Auth`` subclass)
which handles discovery, dynamic client registration, PKCE, token exchange,
refresh, and step-up authorization automatically.
This module provides the glue:
- ``HermesTokenStorage``: persists tokens/client-info to disk so they
survive across process restarts.
- Callback server: ephemeral localhost HTTP server to capture the OAuth
redirect with the authorization code.
- ``build_oauth_auth()``: entry point called by ``mcp_tool.py`` that wires
everything together and returns the ``httpx.Auth`` object.
Configuration in config.yaml::
mcp_servers:
my_server:
url: "https://mcp.example.com/mcp"
auth: oauth
oauth: # all fields optional
client_id: "pre-registered-id" # skip dynamic registration
client_secret: "secret" # confidential clients only
scope: "read write" # default: server-provided
redirect_port: 0 # 0 = auto-pick free port
client_name: "My Custom Client" # default: "Hermes Agent"
"""
import asyncio
import json
import logging
import os
import re
import socket
import sys
import threading
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Any
from typing import Any, Optional
from urllib.parse import parse_qs, urlparse
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Lazy imports -- MCP SDK with OAuth support is optional
# ---------------------------------------------------------------------------
_OAUTH_AVAILABLE = False
try:
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthToken,
)
from pydantic import AnyUrl
_OAUTH_AVAILABLE = True
except ImportError:
logger.debug("MCP OAuth types not available -- OAuth MCP auth disabled")
# ---------------------------------------------------------------------------
# Exceptions
# ---------------------------------------------------------------------------
class OAuthNonInteractiveError(RuntimeError):
"""Raised when OAuth requires user interaction but the environment is non-interactive."""
pass
_TOKEN_DIR_NAME = "mcp-tokens"
"""Raised when OAuth requires browser interaction in a non-interactive env."""
# ---------------------------------------------------------------------------
# Token storage — persists tokens + client info to ~/.hermes/mcp-tokens/
# Module-level state
# ---------------------------------------------------------------------------
def _sanitize_server_name(name: str) -> str:
"""Sanitize server name for safe use as a filename."""
import re
clean = re.sub(r"[^\w\-]", "-", name.strip().lower())
clean = re.sub(r"-+", "-", clean).strip("-")
return clean[:60] or "unnamed"
class HermesTokenStorage:
"""File-backed token storage implementing the MCP SDK's TokenStorage protocol."""
def __init__(self, server_name: str):
self._server_name = _sanitize_server_name(server_name)
def _base_dir(self) -> Path:
home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
d = home / _TOKEN_DIR_NAME
d.mkdir(parents=True, exist_ok=True)
return d
def _tokens_path(self) -> Path:
return self._base_dir() / f"{self._server_name}.json"
def _client_path(self) -> Path:
return self._base_dir() / f"{self._server_name}.client.json"
# -- TokenStorage protocol (async) --
async def get_tokens(self):
data = self._read_json(self._tokens_path())
if not data:
return None
try:
from mcp.shared.auth import OAuthToken
return OAuthToken(**data)
except Exception:
return None
async def set_tokens(self, tokens) -> None:
self._write_json(self._tokens_path(), tokens.model_dump(exclude_none=True))
async def get_client_info(self):
data = self._read_json(self._client_path())
if not data:
return None
try:
from mcp.shared.auth import OAuthClientInformationFull
return OAuthClientInformationFull(**data)
except Exception:
return None
async def set_client_info(self, client_info) -> None:
self._write_json(self._client_path(), client_info.model_dump(exclude_none=True))
# -- helpers --
@staticmethod
def _read_json(path: Path) -> dict | None:
if not path.exists():
return None
try:
return json.loads(path.read_text(encoding="utf-8"))
except Exception:
return None
@staticmethod
def _write_json(path: Path, data: dict) -> None:
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
try:
path.chmod(0o600)
except OSError:
pass
def remove(self) -> None:
"""Delete stored tokens and client info for this server."""
for p in (self._tokens_path(), self._client_path()):
try:
p.unlink(missing_ok=True)
except OSError:
pass
# Port used by the most recent build_oauth_auth() call. Exposed so that
# tests can verify the callback server and the redirect_uri share a port.
_oauth_port: int | None = None
# ---------------------------------------------------------------------------
# Browser-based callback handler
# Helpers
# ---------------------------------------------------------------------------
def _get_token_dir() -> Path:
"""Return the directory for MCP OAuth token files.
Uses HERMES_HOME so each profile gets its own OAuth tokens.
Layout: ``HERMES_HOME/mcp-tokens/``
"""
try:
from hermes_constants import get_hermes_home
base = Path(get_hermes_home())
except ImportError:
base = Path(os.environ.get("HERMES_HOME", str(Path.home() / ".hermes")))
return base / "mcp-tokens"
def _safe_filename(name: str) -> str:
"""Sanitize a server name for use as a filename (no path separators)."""
return re.sub(r"[^\w\-]", "_", name).strip("_")[:128] or "default"
def _find_free_port() -> int:
"""Find an available TCP port on localhost."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def _make_callback_handler():
"""Create a callback handler class with instance-scoped result storage."""
result = {"auth_code": None, "state": None}
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
qs = parse_qs(urlparse(self.path).query)
result["auth_code"] = (qs.get("code") or [None])[0]
result["state"] = (qs.get("state") or [None])[0]
self.send_response(200)
self.send_header("Content-Type", "text/html")
self.end_headers()
self.wfile.write(b"<html><body><h3>Authorization complete. You can close this tab.</h3></body></html>")
def log_message(self, *_args: Any) -> None:
pass
return Handler, result
# Port chosen at build time and shared with the callback handler via closure.
_oauth_port: int | None = None
async def _redirect_to_browser(auth_url: str) -> None:
"""Open the authorization URL in the user's browser."""
def _is_interactive() -> bool:
"""Return True if we can reasonably expect to interact with a user."""
try:
if _can_open_browser():
webbrowser.open(auth_url)
print(" Opened browser for authorization...")
else:
print(f"\n Open this URL to authorize:\n {auth_url}\n")
except Exception:
print(f"\n Open this URL to authorize:\n {auth_url}\n")
async def _wait_for_callback() -> tuple[str, str | None]:
"""Start a local HTTP server on the pre-registered port and wait for the OAuth redirect.
If the callback times out, raises ``OAuthNonInteractiveError`` instead of
calling blocking ``input()`` the old ``input()`` call would block the
entire MCP asyncio event loop, preventing all other MCP servers from
connecting and potentially hanging Hermes startup indefinitely.
"""
global _oauth_port
port = _oauth_port or _find_free_port()
HandlerClass, result = _make_callback_handler()
server = HTTPServer(("127.0.0.1", port), HandlerClass)
def _serve():
server.timeout = 120
server.handle_request()
thread = threading.Thread(target=_serve, daemon=True)
thread.start()
for _ in range(1200): # 120 seconds
await asyncio.sleep(0.1)
if result["auth_code"] is not None:
break
server.server_close()
code = result["auth_code"] or ""
state = result["state"]
if not code:
raise OAuthNonInteractiveError(
"OAuth browser callback timed out after 120 seconds. "
"Run 'hermes mcp auth <server-name>' to authorize interactively."
)
return code, state
return sys.stdin.isatty()
except (AttributeError, ValueError):
return False
def _can_open_browser() -> bool:
"""Return True if opening a browser is likely to work."""
# Explicit SSH session → no local display
if os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"):
return False
if not os.environ.get("DISPLAY") and os.name != "nt" and "darwin" not in os.uname().sysname.lower():
return False
return True
# macOS and Windows usually have a display
if os.name == "nt":
return True
try:
if os.uname().sysname == "Darwin":
return True
except AttributeError:
pass
# Linux/other posix: need DISPLAY or WAYLAND_DISPLAY
if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"):
return True
return False
def _is_interactive() -> bool:
"""Check if the current environment can support interactive OAuth flows.
def _read_json(path: Path) -> dict | None:
"""Read a JSON file, returning None if it doesn't exist or is invalid."""
if not path.exists():
return None
try:
return json.loads(path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError) as exc:
logger.warning("Failed to read %s: %s", path, exc)
return None
Returns False in headless/daemon/container environments where no user
can interact with a browser or paste an auth code.
def _write_json(path: Path, data: dict) -> None:
"""Write a dict as JSON with restricted permissions (0o600)."""
path.parent.mkdir(parents=True, exist_ok=True)
tmp = path.with_suffix(".tmp")
try:
tmp.write_text(json.dumps(data, indent=2, default=str), encoding="utf-8")
os.chmod(tmp, 0o600)
tmp.rename(path)
except OSError:
tmp.unlink(missing_ok=True)
raise
# ---------------------------------------------------------------------------
# HermesTokenStorage -- persistent token/client-info on disk
# ---------------------------------------------------------------------------
class HermesTokenStorage:
"""Persist OAuth tokens and client registration to JSON files.
File layout::
HERMES_HOME/mcp-tokens/<server_name>.json -- tokens
HERMES_HOME/mcp-tokens/<server_name>.client.json -- client info
"""
if not hasattr(sys.stdin, "isatty") or not sys.stdin.isatty():
return False
return True
def __init__(self, server_name: str):
self._server_name = _safe_filename(server_name)
def _tokens_path(self) -> Path:
return _get_token_dir() / f"{self._server_name}.json"
def _client_info_path(self) -> Path:
return _get_token_dir() / f"{self._server_name}.client.json"
# -- tokens ------------------------------------------------------------
async def get_tokens(self) -> "OAuthToken | None":
data = _read_json(self._tokens_path())
if data is None:
return None
try:
return OAuthToken.model_validate(data)
except Exception:
logger.warning("Corrupt tokens at %s -- ignoring", self._tokens_path())
return None
async def set_tokens(self, tokens: "OAuthToken") -> None:
_write_json(self._tokens_path(), tokens.model_dump(exclude_none=True))
logger.debug("OAuth tokens saved for %s", self._server_name)
# -- client info -------------------------------------------------------
async def get_client_info(self) -> "OAuthClientInformationFull | None":
data = _read_json(self._client_info_path())
if data is None:
return None
try:
return OAuthClientInformationFull.model_validate(data)
except Exception:
logger.warning("Corrupt client info at %s -- ignoring", self._client_info_path())
return None
async def set_client_info(self, client_info: "OAuthClientInformationFull") -> None:
_write_json(self._client_info_path(), client_info.model_dump(exclude_none=True))
logger.debug("OAuth client info saved for %s", self._server_name)
# -- cleanup -----------------------------------------------------------
def remove(self) -> None:
"""Delete all stored OAuth state for this server."""
for p in (self._tokens_path(), self._client_info_path()):
p.unlink(missing_ok=True)
def has_cached_tokens(self) -> bool:
"""Return True if we have tokens on disk (may be expired)."""
return self._tokens_path().exists()
# ---------------------------------------------------------------------------
# Callback handler factory -- each invocation gets its own result dict
# ---------------------------------------------------------------------------
def _make_callback_handler() -> tuple[type, dict]:
"""Create a per-flow callback HTTP handler class with its own result dict.
Returns ``(HandlerClass, result_dict)`` where *result_dict* is a mutable
dict that the handler writes ``auth_code`` and ``state`` into when the
OAuth redirect arrives. Each call returns a fresh pair so concurrent
flows don't stomp on each other.
"""
result: dict[str, Any] = {"auth_code": None, "state": None, "error": None}
class _Handler(BaseHTTPRequestHandler):
def do_GET(self) -> None: # noqa: N802
params = parse_qs(urlparse(self.path).query)
code = params.get("code", [None])[0]
state = params.get("state", [None])[0]
error = params.get("error", [None])[0]
result["auth_code"] = code
result["state"] = state
result["error"] = error
body = (
"<html><body><h2>Authorization Successful</h2>"
"<p>You can close this tab and return to Hermes.</p></body></html>"
) if code else (
"<html><body><h2>Authorization Failed</h2>"
f"<p>Error: {error or 'unknown'}</p></body></html>"
)
self.send_response(200)
self.send_header("Content-Type", "text/html; charset=utf-8")
self.end_headers()
self.wfile.write(body.encode())
def log_message(self, fmt: str, *args: Any) -> None:
logger.debug("OAuth callback: %s", fmt % args)
return _Handler, result
# ---------------------------------------------------------------------------
# Async redirect + callback handlers for OAuthClientProvider
# ---------------------------------------------------------------------------
async def _redirect_handler(authorization_url: str) -> None:
"""Show the authorization URL to the user.
Opens the browser automatically when possible; always prints the URL
as a fallback for headless/SSH/gateway environments.
"""
msg = (
f"\n MCP OAuth: authorization required.\n"
f" Open this URL in your browser:\n\n"
f" {authorization_url}\n"
)
print(msg, file=sys.stderr)
if _can_open_browser():
try:
opened = webbrowser.open(authorization_url)
if opened:
print(" (Browser opened automatically.)\n", file=sys.stderr)
else:
print(" (Could not open browser — please open the URL manually.)\n", file=sys.stderr)
except Exception:
print(" (Could not open browser — please open the URL manually.)\n", file=sys.stderr)
else:
print(" (Headless environment detected — open the URL manually.)\n", file=sys.stderr)
async def _wait_for_callback() -> tuple[str, str | None]:
"""Wait for the OAuth callback to arrive on the local callback server.
Uses the module-level ``_oauth_port`` which is set by ``build_oauth_auth``
before this is ever called. Polls for the result without blocking the
event loop.
Raises:
OAuthNonInteractiveError: If the callback times out (no user present
to complete the browser auth).
"""
global _oauth_port
assert _oauth_port is not None, "OAuth callback port not set"
# The callback server is already running (started in build_oauth_auth).
# We just need to poll for the result.
handler_cls, result = _make_callback_handler()
# Start a temporary server on the known port
try:
server = HTTPServer(("127.0.0.1", _oauth_port), handler_cls)
except OSError:
# Port already in use — the server from build_oauth_auth is running.
# Fall back to polling the server started by build_oauth_auth.
raise OAuthNonInteractiveError(
"OAuth callback timed out — could not bind callback port. "
"Complete the authorization in a browser first, then retry."
)
server_thread = threading.Thread(target=server.handle_request, daemon=True)
server_thread.start()
timeout = 300.0
poll_interval = 0.5
elapsed = 0.0
while elapsed < timeout:
if result["auth_code"] is not None or result["error"] is not None:
break
await asyncio.sleep(poll_interval)
elapsed += poll_interval
server.server_close()
if result["error"]:
raise RuntimeError(f"OAuth authorization failed: {result['error']}")
if result["auth_code"] is None:
raise OAuthNonInteractiveError(
"OAuth callback timed out — no authorization code received. "
"Ensure you completed the browser authorization flow."
)
return result["auth_code"], result["state"]
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def build_oauth_auth(server_name: str, server_url: str):
"""Build an ``httpx.Auth`` handler for the given MCP server using OAuth 2.1 PKCE.
Uses the MCP SDK's ``OAuthClientProvider`` which handles discovery,
registration, PKCE, token exchange, and refresh automatically.
In non-interactive environments (no TTY), this still returns a provider
so that **cached tokens and refresh flows work**. Only the interactive
authorization-code grant will fail fast with a clear error instead of
blocking the event loop.
Returns an ``OAuthClientProvider`` instance (implements ``httpx.Auth``),
or ``None`` if the MCP SDK auth module is not available.
"""
try:
from mcp.client.auth import OAuthClientProvider
from mcp.shared.auth import OAuthClientMetadata
except ImportError:
logger.warning("MCP SDK auth module not available — OAuth disabled")
return None
storage = HermesTokenStorage(server_name)
interactive = _is_interactive()
if not interactive:
# Check whether cached tokens exist. If they do, the SDK can still
# use them (and refresh them) without any user interaction. If not,
# we still build the provider — the callback_handler will raise
# OAuthNonInteractiveError if a fresh authorization is actually
# needed, which surfaces as a clean connection failure for this
# server only (other MCP servers are unaffected).
has_cached = storage._read_json(storage._tokens_path()) is not None
if not has_cached:
logger.warning(
"MCP server '%s' requires OAuth but no cached tokens found "
"and environment is non-interactive. The server will fail to "
"connect. Run 'hermes mcp auth %s' to authorize interactively.",
server_name, server_name,
)
global _oauth_port
_oauth_port = _find_free_port()
redirect_uri = f"http://127.0.0.1:{_oauth_port}/callback"
client_metadata = OAuthClientMetadata(
client_name="Hermes Agent",
redirect_uris=[redirect_uri],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
scope="openid profile email offline_access",
token_endpoint_auth_method="none",
)
# In non-interactive mode, the redirect handler logs the URL and the
# callback handler raises immediately — no blocking, no input().
redirect_handler = _redirect_to_browser
callback_handler = _wait_for_callback
if not interactive:
async def _noninteractive_redirect(auth_url: str) -> None:
logger.warning(
"MCP server '%s' needs OAuth authorization (non-interactive, "
"cannot open browser). URL: %s",
server_name, auth_url,
)
async def _noninteractive_callback() -> tuple[str, str | None]:
raise OAuthNonInteractiveError(
f"MCP server '{server_name}' requires interactive OAuth "
f"authorization but the environment is non-interactive "
f"(no TTY). Run 'hermes mcp auth {server_name}' to "
f"authorize, then restart."
)
redirect_handler = _noninteractive_redirect
callback_handler = _noninteractive_callback
return OAuthClientProvider(
server_url=server_url,
client_metadata=client_metadata,
storage=storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
timeout=120.0,
)
def remove_oauth_tokens(server_name: str) -> None:
"""Delete stored OAuth tokens and client info for a server."""
HermesTokenStorage(server_name).remove()
storage = HermesTokenStorage(server_name)
storage.remove()
logger.info("OAuth tokens removed for '%s'", server_name)
def build_oauth_auth(
server_name: str,
server_url: str,
oauth_config: dict | None = None,
) -> "OAuthClientProvider | None":
"""Build an ``httpx.Auth``-compatible OAuth handler for an MCP server.
Called from ``mcp_tool.py`` when a server has ``auth: oauth`` in config.
Args:
server_name: Server key in mcp_servers config (used for storage).
server_url: MCP server endpoint URL.
oauth_config: Optional dict from the ``oauth:`` block in config.yaml.
Returns:
An ``OAuthClientProvider`` instance, or None if the MCP SDK lacks
OAuth support.
"""
if not _OAUTH_AVAILABLE:
logger.warning(
"MCP OAuth requested for '%s' but SDK auth types are not available. "
"Install with: pip install 'mcp>=1.10.0'",
server_name,
)
return None
global _oauth_port
cfg = oauth_config or {}
# --- Storage ---
storage = HermesTokenStorage(server_name)
# --- Non-interactive warning ---
if not _is_interactive() and not storage.has_cached_tokens():
logger.warning(
"MCP OAuth for '%s': non-interactive environment and no cached tokens found. "
"The OAuth flow requires browser authorization. Run interactively first "
"to complete the initial authorization, then cached tokens will be reused.",
server_name,
)
# --- Pick callback port ---
redirect_port = int(cfg.get("redirect_port", 0))
if redirect_port == 0:
redirect_port = _find_free_port()
_oauth_port = redirect_port
# --- Client metadata ---
client_name = cfg.get("client_name", "Hermes Agent")
scope = cfg.get("scope")
redirect_uri = f"http://127.0.0.1:{redirect_port}/callback"
metadata_kwargs: dict[str, Any] = {
"client_name": client_name,
"redirect_uris": [AnyUrl(redirect_uri)],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"token_endpoint_auth_method": "none",
}
if scope:
metadata_kwargs["scope"] = scope
client_secret = cfg.get("client_secret")
if client_secret:
metadata_kwargs["token_endpoint_auth_method"] = "client_secret_post"
client_metadata = OAuthClientMetadata.model_validate(metadata_kwargs)
# --- Pre-registered client ---
client_id = cfg.get("client_id")
if client_id:
info_dict: dict[str, Any] = {
"client_id": client_id,
"redirect_uris": [redirect_uri],
"grant_types": client_metadata.grant_types,
"response_types": client_metadata.response_types,
"token_endpoint_auth_method": client_metadata.token_endpoint_auth_method,
}
if client_secret:
info_dict["client_secret"] = client_secret
if client_name:
info_dict["client_name"] = client_name
if scope:
info_dict["scope"] = scope
client_info = OAuthClientInformationFull.model_validate(info_dict)
_write_json(storage._client_info_path(), client_info.model_dump(exclude_none=True))
logger.debug("Pre-registered client_id=%s for '%s'", client_id, server_name)
# --- Base URL for discovery ---
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
# --- Build provider ---
provider = OAuthClientProvider(
server_url=base_url,
client_metadata=client_metadata,
storage=storage,
redirect_handler=_redirect_handler,
callback_handler=_wait_for_callback,
timeout=float(cfg.get("timeout", 300)),
)
return provider

View file

@ -892,7 +892,9 @@ class MCPServerTask:
if self._auth_type == "oauth":
try:
from tools.mcp_oauth import build_oauth_auth
_oauth_auth = build_oauth_auth(self.name, url)
_oauth_auth = build_oauth_auth(
self.name, url, config.get("oauth")
)
except Exception as exc:
logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
raise