diff --git a/tests/tools/test_mcp_oauth.py b/tests/tools/test_mcp_oauth.py index 19c588e58..8643c26b3 100644 --- a/tests/tools/test_mcp_oauth.py +++ b/tests/tools/test_mcp_oauth.py @@ -1,7 +1,8 @@ -"""Tests for tools/mcp_oauth.py — thin OAuth adapter over MCP SDK.""" +"""Tests for tools/mcp_oauth.py — OAuth 2.1 PKCE support for MCP servers.""" import json import os +from io import BytesIO from pathlib import Path from unittest.mock import patch, MagicMock, AsyncMock @@ -16,6 +17,7 @@ from tools.mcp_oauth import ( _can_open_browser, _is_interactive, _wait_for_callback, + _make_callback_handler, ) @@ -79,34 +81,93 @@ class TestHermesTokenStorage: assert not (d / "test-server.json").exists() assert not (d / "test-server.client.json").exists() + def test_has_cached_tokens(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + storage = HermesTokenStorage("my-server") + + assert not storage.has_cached_tokens() + + d = tmp_path / "mcp-tokens" + d.mkdir(parents=True) + (d / "my-server.json").write_text('{"access_token": "x", "token_type": "Bearer"}') + + assert storage.has_cached_tokens() + + def test_corrupt_tokens_returns_none(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + storage = HermesTokenStorage("bad-server") + + d = tmp_path / "mcp-tokens" + d.mkdir(parents=True) + (d / "bad-server.json").write_text("NOT VALID JSON{{{") + + import asyncio + assert asyncio.run(storage.get_tokens()) is None + + def test_corrupt_client_info_returns_none(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + storage = HermesTokenStorage("bad-server") + + d = tmp_path / "mcp-tokens" + d.mkdir(parents=True) + (d / "bad-server.client.json").write_text("GARBAGE") + + import asyncio + assert asyncio.run(storage.get_client_info()) is None + # --------------------------------------------------------------------------- # build_oauth_auth # --------------------------------------------------------------------------- class TestBuildOAuthAuth: - def test_returns_oauth_provider(self): + def test_returns_oauth_provider(self, tmp_path, monkeypatch): try: from mcp.client.auth import OAuthClientProvider except ImportError: pytest.skip("MCP SDK auth not available") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) auth = build_oauth_auth("test", "https://example.com/mcp") assert isinstance(auth, OAuthClientProvider) def test_returns_none_without_sdk(self, monkeypatch): import tools.mcp_oauth as mod - orig_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__ + monkeypatch.setattr(mod, "_OAUTH_AVAILABLE", False) + result = build_oauth_auth("test", "https://example.com") + assert result is None - def _block_import(name, *args, **kwargs): - if "mcp.client.auth" in name: - raise ImportError("blocked") - return orig_import(name, *args, **kwargs) + def test_pre_registered_client_id_stored(self, tmp_path, monkeypatch): + try: + from mcp.client.auth import OAuthClientProvider + except ImportError: + pytest.skip("MCP SDK auth not available") - with patch("builtins.__import__", side_effect=_block_import): - result = build_oauth_auth("test", "https://example.com") - # May or may not be None depending on import caching, but shouldn't crash - assert result is None or result is not None + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + build_oauth_auth("slack", "https://slack.example.com/mcp", { + "client_id": "my-app-id", + "client_secret": "my-secret", + "scope": "channels:read", + }) + + client_path = tmp_path / "mcp-tokens" / "slack.client.json" + assert client_path.exists() + data = json.loads(client_path.read_text()) + assert data["client_id"] == "my-app-id" + assert data["client_secret"] == "my-secret" + + def test_scope_passed_through(self, tmp_path, monkeypatch): + try: + from mcp.client.auth import OAuthClientProvider + except ImportError: + pytest.skip("MCP SDK auth not available") + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + provider = build_oauth_auth("scoped", "https://example.com/mcp", { + "scope": "read write admin", + }) + assert provider is not None + assert provider.context.client_metadata.scope == "read write admin" # --------------------------------------------------------------------------- @@ -119,6 +180,12 @@ class TestUtilities: assert isinstance(port, int) assert 1024 <= port <= 65535 + def test_find_free_port_unique(self): + """Two consecutive calls should return different ports (usually).""" + ports = {_find_free_port() for _ in range(5)} + # At least 2 different ports out of 5 attempts + assert len(ports) >= 2 + def test_can_open_browser_false_in_ssh(self, monkeypatch): monkeypatch.setenv("SSH_CLIENT", "1.2.3.4 1234 22") assert _can_open_browser() is False @@ -127,14 +194,22 @@ class TestUtilities: monkeypatch.delenv("SSH_CLIENT", raising=False) monkeypatch.delenv("SSH_TTY", raising=False) monkeypatch.delenv("DISPLAY", raising=False) + monkeypatch.delenv("WAYLAND_DISPLAY", raising=False) # Mock os.name and uname for non-macOS, non-Windows monkeypatch.setattr(os, "name", "posix") monkeypatch.setattr(os, "uname", lambda: type("", (), {"sysname": "Linux"})()) assert _can_open_browser() is False + def test_can_open_browser_true_with_display(self, monkeypatch): + monkeypatch.delenv("SSH_CLIENT", raising=False) + monkeypatch.delenv("SSH_TTY", raising=False) + monkeypatch.setenv("DISPLAY", ":0") + monkeypatch.setattr(os, "name", "posix") + assert _can_open_browser() is True + # --------------------------------------------------------------------------- -# remove_oauth_tokens +# Path traversal protection # --------------------------------------------------------------------------- class TestPathTraversal: @@ -169,11 +244,14 @@ class TestPathTraversal: assert "/" not in path.stem +# --------------------------------------------------------------------------- +# Callback handler isolation +# --------------------------------------------------------------------------- + class TestCallbackHandlerIsolation: """Verify concurrent OAuth flows don't share state.""" def test_independent_result_dicts(self): - from tools.mcp_oauth import _make_callback_handler _, result_a = _make_callback_handler() _, result_b = _make_callback_handler() @@ -184,10 +262,6 @@ class TestCallbackHandlerIsolation: assert result_b["auth_code"] == "code_B" def test_handler_writes_to_own_result(self): - from tools.mcp_oauth import _make_callback_handler - from io import BytesIO - from unittest.mock import MagicMock - HandlerClass, result = _make_callback_handler() assert result["auth_code"] is None @@ -203,13 +277,30 @@ class TestCallbackHandlerIsolation: assert result["auth_code"] == "test123" assert result["state"] == "mystate" + def test_handler_captures_error(self): + HandlerClass, result = _make_callback_handler() + + handler = HandlerClass.__new__(HandlerClass) + handler.path = "/callback?error=access_denied" + handler.wfile = BytesIO() + handler.send_response = MagicMock() + handler.send_header = MagicMock() + handler.end_headers = MagicMock() + handler.do_GET() + + assert result["auth_code"] is None + assert result["error"] == "access_denied" + + +# --------------------------------------------------------------------------- +# Port sharing +# --------------------------------------------------------------------------- class TestOAuthPortSharing: """Verify build_oauth_auth and _wait_for_callback use the same port.""" - def test_port_stored_globally(self): + def test_port_stored_globally(self, tmp_path, monkeypatch): import tools.mcp_oauth as mod - # Reset mod._oauth_port = None try: @@ -217,12 +308,17 @@ class TestOAuthPortSharing: except ImportError: pytest.skip("MCP SDK auth not available") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) build_oauth_auth("test-port", "https://example.com/mcp") assert mod._oauth_port is not None assert isinstance(mod._oauth_port, int) assert 1024 <= mod._oauth_port <= 65535 +# --------------------------------------------------------------------------- +# remove_oauth_tokens +# --------------------------------------------------------------------------- + class TestRemoveOAuthTokens: def test_removes_files(self, tmp_path, monkeypatch): monkeypatch.setenv("HERMES_HOME", str(tmp_path)) @@ -242,7 +338,7 @@ class TestRemoveOAuthTokens: # --------------------------------------------------------------------------- -# Non-interactive / startup-safety tests (issue #4462) +# Non-interactive / startup-safety tests # --------------------------------------------------------------------------- class TestIsInteractive: diff --git a/tools/mcp_oauth.py b/tools/mcp_oauth.py index b614826a8..00172f340 100644 --- a/tools/mcp_oauth.py +++ b/tools/mcp_oauth.py @@ -1,326 +1,482 @@ -"""Thin OAuth adapter for MCP HTTP servers. - -Wraps the MCP SDK's built-in ``OAuthClientProvider`` (which implements -``httpx.Auth``) with Hermes-specific token storage and browser-based -authorization. The SDK handles all of the heavy lifting: PKCE generation, -metadata discovery, dynamic client registration, token exchange, and refresh. - -Startup safety: - The callback handler never calls blocking ``input()`` on the event loop. - In non-interactive environments (no TTY, SSH, headless), the OAuth flow - raises ``OAuthNonInteractiveError`` instead of blocking, so that the - server degrades gracefully and other MCP servers are not affected. - -Usage in mcp_tool.py:: - - from tools.mcp_oauth import build_oauth_auth - auth = build_oauth_auth(server_name, server_url) - # pass ``auth`` as the httpx auth parameter +#!/usr/bin/env python3 """ +MCP OAuth 2.1 Client Support -from __future__ import annotations +Implements the browser-based OAuth 2.1 authorization code flow with PKCE +for MCP servers that require OAuth authentication instead of static bearer +tokens. + +Uses the MCP Python SDK's ``OAuthClientProvider`` (an ``httpx.Auth`` subclass) +which handles discovery, dynamic client registration, PKCE, token exchange, +refresh, and step-up authorization automatically. + +This module provides the glue: + - ``HermesTokenStorage``: persists tokens/client-info to disk so they + survive across process restarts. + - Callback server: ephemeral localhost HTTP server to capture the OAuth + redirect with the authorization code. + - ``build_oauth_auth()``: entry point called by ``mcp_tool.py`` that wires + everything together and returns the ``httpx.Auth`` object. + +Configuration in config.yaml:: + + mcp_servers: + my_server: + url: "https://mcp.example.com/mcp" + auth: oauth + oauth: # all fields optional + client_id: "pre-registered-id" # skip dynamic registration + client_secret: "secret" # confidential clients only + scope: "read write" # default: server-provided + redirect_port: 0 # 0 = auto-pick free port + client_name: "My Custom Client" # default: "Hermes Agent" +""" import asyncio import json import logging import os +import re import socket import sys import threading import webbrowser from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path -from typing import Any +from typing import Any, Optional from urllib.parse import parse_qs, urlparse logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Lazy imports -- MCP SDK with OAuth support is optional +# --------------------------------------------------------------------------- + +_OAUTH_AVAILABLE = False +try: + from mcp.client.auth import OAuthClientProvider, TokenStorage + from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthToken, + ) + from pydantic import AnyUrl + + _OAUTH_AVAILABLE = True +except ImportError: + logger.debug("MCP OAuth types not available -- OAuth MCP auth disabled") + + +# --------------------------------------------------------------------------- +# Exceptions +# --------------------------------------------------------------------------- + class OAuthNonInteractiveError(RuntimeError): - """Raised when OAuth requires user interaction but the environment is non-interactive.""" - pass - -_TOKEN_DIR_NAME = "mcp-tokens" + """Raised when OAuth requires browser interaction in a non-interactive env.""" # --------------------------------------------------------------------------- -# Token storage — persists tokens + client info to ~/.hermes/mcp-tokens/ +# Module-level state # --------------------------------------------------------------------------- -def _sanitize_server_name(name: str) -> str: - """Sanitize server name for safe use as a filename.""" - import re - clean = re.sub(r"[^\w\-]", "-", name.strip().lower()) - clean = re.sub(r"-+", "-", clean).strip("-") - return clean[:60] or "unnamed" - - -class HermesTokenStorage: - """File-backed token storage implementing the MCP SDK's TokenStorage protocol.""" - - def __init__(self, server_name: str): - self._server_name = _sanitize_server_name(server_name) - - def _base_dir(self) -> Path: - home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) - d = home / _TOKEN_DIR_NAME - d.mkdir(parents=True, exist_ok=True) - return d - - def _tokens_path(self) -> Path: - return self._base_dir() / f"{self._server_name}.json" - - def _client_path(self) -> Path: - return self._base_dir() / f"{self._server_name}.client.json" - - # -- TokenStorage protocol (async) -- - - async def get_tokens(self): - data = self._read_json(self._tokens_path()) - if not data: - return None - try: - from mcp.shared.auth import OAuthToken - return OAuthToken(**data) - except Exception: - return None - - async def set_tokens(self, tokens) -> None: - self._write_json(self._tokens_path(), tokens.model_dump(exclude_none=True)) - - async def get_client_info(self): - data = self._read_json(self._client_path()) - if not data: - return None - try: - from mcp.shared.auth import OAuthClientInformationFull - return OAuthClientInformationFull(**data) - except Exception: - return None - - async def set_client_info(self, client_info) -> None: - self._write_json(self._client_path(), client_info.model_dump(exclude_none=True)) - - # -- helpers -- - - @staticmethod - def _read_json(path: Path) -> dict | None: - if not path.exists(): - return None - try: - return json.loads(path.read_text(encoding="utf-8")) - except Exception: - return None - - @staticmethod - def _write_json(path: Path, data: dict) -> None: - path.write_text(json.dumps(data, indent=2), encoding="utf-8") - try: - path.chmod(0o600) - except OSError: - pass - - def remove(self) -> None: - """Delete stored tokens and client info for this server.""" - for p in (self._tokens_path(), self._client_path()): - try: - p.unlink(missing_ok=True) - except OSError: - pass +# Port used by the most recent build_oauth_auth() call. Exposed so that +# tests can verify the callback server and the redirect_uri share a port. +_oauth_port: int | None = None # --------------------------------------------------------------------------- -# Browser-based callback handler +# Helpers # --------------------------------------------------------------------------- + +def _get_token_dir() -> Path: + """Return the directory for MCP OAuth token files. + + Uses HERMES_HOME so each profile gets its own OAuth tokens. + Layout: ``HERMES_HOME/mcp-tokens/`` + """ + try: + from hermes_constants import get_hermes_home + base = Path(get_hermes_home()) + except ImportError: + base = Path(os.environ.get("HERMES_HOME", str(Path.home() / ".hermes"))) + return base / "mcp-tokens" + + +def _safe_filename(name: str) -> str: + """Sanitize a server name for use as a filename (no path separators).""" + return re.sub(r"[^\w\-]", "_", name).strip("_")[:128] or "default" + + def _find_free_port() -> int: + """Find an available TCP port on localhost.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) return s.getsockname()[1] -def _make_callback_handler(): - """Create a callback handler class with instance-scoped result storage.""" - result = {"auth_code": None, "state": None} - - class Handler(BaseHTTPRequestHandler): - def do_GET(self): - qs = parse_qs(urlparse(self.path).query) - result["auth_code"] = (qs.get("code") or [None])[0] - result["state"] = (qs.get("state") or [None])[0] - self.send_response(200) - self.send_header("Content-Type", "text/html") - self.end_headers() - self.wfile.write(b"

Authorization complete. You can close this tab.

") - - def log_message(self, *_args: Any) -> None: - pass - - return Handler, result - - -# Port chosen at build time and shared with the callback handler via closure. -_oauth_port: int | None = None - - -async def _redirect_to_browser(auth_url: str) -> None: - """Open the authorization URL in the user's browser.""" +def _is_interactive() -> bool: + """Return True if we can reasonably expect to interact with a user.""" try: - if _can_open_browser(): - webbrowser.open(auth_url) - print(" Opened browser for authorization...") - else: - print(f"\n Open this URL to authorize:\n {auth_url}\n") - except Exception: - print(f"\n Open this URL to authorize:\n {auth_url}\n") - - -async def _wait_for_callback() -> tuple[str, str | None]: - """Start a local HTTP server on the pre-registered port and wait for the OAuth redirect. - - If the callback times out, raises ``OAuthNonInteractiveError`` instead of - calling blocking ``input()`` — the old ``input()`` call would block the - entire MCP asyncio event loop, preventing all other MCP servers from - connecting and potentially hanging Hermes startup indefinitely. - """ - global _oauth_port - port = _oauth_port or _find_free_port() - HandlerClass, result = _make_callback_handler() - server = HTTPServer(("127.0.0.1", port), HandlerClass) - - def _serve(): - server.timeout = 120 - server.handle_request() - - thread = threading.Thread(target=_serve, daemon=True) - thread.start() - - for _ in range(1200): # 120 seconds - await asyncio.sleep(0.1) - if result["auth_code"] is not None: - break - - server.server_close() - code = result["auth_code"] or "" - state = result["state"] - if not code: - raise OAuthNonInteractiveError( - "OAuth browser callback timed out after 120 seconds. " - "Run 'hermes mcp auth ' to authorize interactively." - ) - return code, state + return sys.stdin.isatty() + except (AttributeError, ValueError): + return False def _can_open_browser() -> bool: + """Return True if opening a browser is likely to work.""" + # Explicit SSH session → no local display if os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"): return False - if not os.environ.get("DISPLAY") and os.name != "nt" and "darwin" not in os.uname().sysname.lower(): - return False - return True + # macOS and Windows usually have a display + if os.name == "nt": + return True + try: + if os.uname().sysname == "Darwin": + return True + except AttributeError: + pass + # Linux/other posix: need DISPLAY or WAYLAND_DISPLAY + if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"): + return True + return False -def _is_interactive() -> bool: - """Check if the current environment can support interactive OAuth flows. +def _read_json(path: Path) -> dict | None: + """Read a JSON file, returning None if it doesn't exist or is invalid.""" + if not path.exists(): + return None + try: + return json.loads(path.read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError) as exc: + logger.warning("Failed to read %s: %s", path, exc) + return None - Returns False in headless/daemon/container environments where no user - can interact with a browser or paste an auth code. + +def _write_json(path: Path, data: dict) -> None: + """Write a dict as JSON with restricted permissions (0o600).""" + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(".tmp") + try: + tmp.write_text(json.dumps(data, indent=2, default=str), encoding="utf-8") + os.chmod(tmp, 0o600) + tmp.rename(path) + except OSError: + tmp.unlink(missing_ok=True) + raise + + +# --------------------------------------------------------------------------- +# HermesTokenStorage -- persistent token/client-info on disk +# --------------------------------------------------------------------------- + + +class HermesTokenStorage: + """Persist OAuth tokens and client registration to JSON files. + + File layout:: + + HERMES_HOME/mcp-tokens/.json -- tokens + HERMES_HOME/mcp-tokens/.client.json -- client info """ - if not hasattr(sys.stdin, "isatty") or not sys.stdin.isatty(): - return False - return True + + def __init__(self, server_name: str): + self._server_name = _safe_filename(server_name) + + def _tokens_path(self) -> Path: + return _get_token_dir() / f"{self._server_name}.json" + + def _client_info_path(self) -> Path: + return _get_token_dir() / f"{self._server_name}.client.json" + + # -- tokens ------------------------------------------------------------ + + async def get_tokens(self) -> "OAuthToken | None": + data = _read_json(self._tokens_path()) + if data is None: + return None + try: + return OAuthToken.model_validate(data) + except Exception: + logger.warning("Corrupt tokens at %s -- ignoring", self._tokens_path()) + return None + + async def set_tokens(self, tokens: "OAuthToken") -> None: + _write_json(self._tokens_path(), tokens.model_dump(exclude_none=True)) + logger.debug("OAuth tokens saved for %s", self._server_name) + + # -- client info ------------------------------------------------------- + + async def get_client_info(self) -> "OAuthClientInformationFull | None": + data = _read_json(self._client_info_path()) + if data is None: + return None + try: + return OAuthClientInformationFull.model_validate(data) + except Exception: + logger.warning("Corrupt client info at %s -- ignoring", self._client_info_path()) + return None + + async def set_client_info(self, client_info: "OAuthClientInformationFull") -> None: + _write_json(self._client_info_path(), client_info.model_dump(exclude_none=True)) + logger.debug("OAuth client info saved for %s", self._server_name) + + # -- cleanup ----------------------------------------------------------- + + def remove(self) -> None: + """Delete all stored OAuth state for this server.""" + for p in (self._tokens_path(), self._client_info_path()): + p.unlink(missing_ok=True) + + def has_cached_tokens(self) -> bool: + """Return True if we have tokens on disk (may be expired).""" + return self._tokens_path().exists() + + +# --------------------------------------------------------------------------- +# Callback handler factory -- each invocation gets its own result dict +# --------------------------------------------------------------------------- + + +def _make_callback_handler() -> tuple[type, dict]: + """Create a per-flow callback HTTP handler class with its own result dict. + + Returns ``(HandlerClass, result_dict)`` where *result_dict* is a mutable + dict that the handler writes ``auth_code`` and ``state`` into when the + OAuth redirect arrives. Each call returns a fresh pair so concurrent + flows don't stomp on each other. + """ + result: dict[str, Any] = {"auth_code": None, "state": None, "error": None} + + class _Handler(BaseHTTPRequestHandler): + def do_GET(self) -> None: # noqa: N802 + params = parse_qs(urlparse(self.path).query) + code = params.get("code", [None])[0] + state = params.get("state", [None])[0] + error = params.get("error", [None])[0] + + result["auth_code"] = code + result["state"] = state + result["error"] = error + + body = ( + "

Authorization Successful

" + "

You can close this tab and return to Hermes.

" + ) if code else ( + "

Authorization Failed

" + f"

Error: {error or 'unknown'}

" + ) + self.send_response(200) + self.send_header("Content-Type", "text/html; charset=utf-8") + self.end_headers() + self.wfile.write(body.encode()) + + def log_message(self, fmt: str, *args: Any) -> None: + logger.debug("OAuth callback: %s", fmt % args) + + return _Handler, result + + +# --------------------------------------------------------------------------- +# Async redirect + callback handlers for OAuthClientProvider +# --------------------------------------------------------------------------- + + +async def _redirect_handler(authorization_url: str) -> None: + """Show the authorization URL to the user. + + Opens the browser automatically when possible; always prints the URL + as a fallback for headless/SSH/gateway environments. + """ + msg = ( + f"\n MCP OAuth: authorization required.\n" + f" Open this URL in your browser:\n\n" + f" {authorization_url}\n" + ) + print(msg, file=sys.stderr) + + if _can_open_browser(): + try: + opened = webbrowser.open(authorization_url) + if opened: + print(" (Browser opened automatically.)\n", file=sys.stderr) + else: + print(" (Could not open browser — please open the URL manually.)\n", file=sys.stderr) + except Exception: + print(" (Could not open browser — please open the URL manually.)\n", file=sys.stderr) + else: + print(" (Headless environment detected — open the URL manually.)\n", file=sys.stderr) + + +async def _wait_for_callback() -> tuple[str, str | None]: + """Wait for the OAuth callback to arrive on the local callback server. + + Uses the module-level ``_oauth_port`` which is set by ``build_oauth_auth`` + before this is ever called. Polls for the result without blocking the + event loop. + + Raises: + OAuthNonInteractiveError: If the callback times out (no user present + to complete the browser auth). + """ + global _oauth_port + assert _oauth_port is not None, "OAuth callback port not set" + + # The callback server is already running (started in build_oauth_auth). + # We just need to poll for the result. + handler_cls, result = _make_callback_handler() + + # Start a temporary server on the known port + try: + server = HTTPServer(("127.0.0.1", _oauth_port), handler_cls) + except OSError: + # Port already in use — the server from build_oauth_auth is running. + # Fall back to polling the server started by build_oauth_auth. + raise OAuthNonInteractiveError( + "OAuth callback timed out — could not bind callback port. " + "Complete the authorization in a browser first, then retry." + ) + + server_thread = threading.Thread(target=server.handle_request, daemon=True) + server_thread.start() + + timeout = 300.0 + poll_interval = 0.5 + elapsed = 0.0 + while elapsed < timeout: + if result["auth_code"] is not None or result["error"] is not None: + break + await asyncio.sleep(poll_interval) + elapsed += poll_interval + + server.server_close() + + if result["error"]: + raise RuntimeError(f"OAuth authorization failed: {result['error']}") + if result["auth_code"] is None: + raise OAuthNonInteractiveError( + "OAuth callback timed out — no authorization code received. " + "Ensure you completed the browser authorization flow." + ) + + return result["auth_code"], result["state"] # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- -def build_oauth_auth(server_name: str, server_url: str): - """Build an ``httpx.Auth`` handler for the given MCP server using OAuth 2.1 PKCE. - - Uses the MCP SDK's ``OAuthClientProvider`` which handles discovery, - registration, PKCE, token exchange, and refresh automatically. - - In non-interactive environments (no TTY), this still returns a provider - so that **cached tokens and refresh flows work**. Only the interactive - authorization-code grant will fail fast with a clear error instead of - blocking the event loop. - - Returns an ``OAuthClientProvider`` instance (implements ``httpx.Auth``), - or ``None`` if the MCP SDK auth module is not available. - """ - try: - from mcp.client.auth import OAuthClientProvider - from mcp.shared.auth import OAuthClientMetadata - except ImportError: - logger.warning("MCP SDK auth module not available — OAuth disabled") - return None - - storage = HermesTokenStorage(server_name) - interactive = _is_interactive() - - if not interactive: - # Check whether cached tokens exist. If they do, the SDK can still - # use them (and refresh them) without any user interaction. If not, - # we still build the provider — the callback_handler will raise - # OAuthNonInteractiveError if a fresh authorization is actually - # needed, which surfaces as a clean connection failure for this - # server only (other MCP servers are unaffected). - has_cached = storage._read_json(storage._tokens_path()) is not None - if not has_cached: - logger.warning( - "MCP server '%s' requires OAuth but no cached tokens found " - "and environment is non-interactive. The server will fail to " - "connect. Run 'hermes mcp auth %s' to authorize interactively.", - server_name, server_name, - ) - - global _oauth_port - _oauth_port = _find_free_port() - redirect_uri = f"http://127.0.0.1:{_oauth_port}/callback" - - client_metadata = OAuthClientMetadata( - client_name="Hermes Agent", - redirect_uris=[redirect_uri], - grant_types=["authorization_code", "refresh_token"], - response_types=["code"], - scope="openid profile email offline_access", - token_endpoint_auth_method="none", - ) - - # In non-interactive mode, the redirect handler logs the URL and the - # callback handler raises immediately — no blocking, no input(). - redirect_handler = _redirect_to_browser - callback_handler = _wait_for_callback - - if not interactive: - async def _noninteractive_redirect(auth_url: str) -> None: - logger.warning( - "MCP server '%s' needs OAuth authorization (non-interactive, " - "cannot open browser). URL: %s", - server_name, auth_url, - ) - - async def _noninteractive_callback() -> tuple[str, str | None]: - raise OAuthNonInteractiveError( - f"MCP server '{server_name}' requires interactive OAuth " - f"authorization but the environment is non-interactive " - f"(no TTY). Run 'hermes mcp auth {server_name}' to " - f"authorize, then restart." - ) - - redirect_handler = _noninteractive_redirect - callback_handler = _noninteractive_callback - - return OAuthClientProvider( - server_url=server_url, - client_metadata=client_metadata, - storage=storage, - redirect_handler=redirect_handler, - callback_handler=callback_handler, - timeout=120.0, - ) - def remove_oauth_tokens(server_name: str) -> None: """Delete stored OAuth tokens and client info for a server.""" - HermesTokenStorage(server_name).remove() + storage = HermesTokenStorage(server_name) + storage.remove() + logger.info("OAuth tokens removed for '%s'", server_name) + + +def build_oauth_auth( + server_name: str, + server_url: str, + oauth_config: dict | None = None, +) -> "OAuthClientProvider | None": + """Build an ``httpx.Auth``-compatible OAuth handler for an MCP server. + + Called from ``mcp_tool.py`` when a server has ``auth: oauth`` in config. + + Args: + server_name: Server key in mcp_servers config (used for storage). + server_url: MCP server endpoint URL. + oauth_config: Optional dict from the ``oauth:`` block in config.yaml. + + Returns: + An ``OAuthClientProvider`` instance, or None if the MCP SDK lacks + OAuth support. + """ + if not _OAUTH_AVAILABLE: + logger.warning( + "MCP OAuth requested for '%s' but SDK auth types are not available. " + "Install with: pip install 'mcp>=1.10.0'", + server_name, + ) + return None + + global _oauth_port + + cfg = oauth_config or {} + + # --- Storage --- + storage = HermesTokenStorage(server_name) + + # --- Non-interactive warning --- + if not _is_interactive() and not storage.has_cached_tokens(): + logger.warning( + "MCP OAuth for '%s': non-interactive environment and no cached tokens found. " + "The OAuth flow requires browser authorization. Run interactively first " + "to complete the initial authorization, then cached tokens will be reused.", + server_name, + ) + + # --- Pick callback port --- + redirect_port = int(cfg.get("redirect_port", 0)) + if redirect_port == 0: + redirect_port = _find_free_port() + _oauth_port = redirect_port + + # --- Client metadata --- + client_name = cfg.get("client_name", "Hermes Agent") + scope = cfg.get("scope") + redirect_uri = f"http://127.0.0.1:{redirect_port}/callback" + + metadata_kwargs: dict[str, Any] = { + "client_name": client_name, + "redirect_uris": [AnyUrl(redirect_uri)], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "none", + } + if scope: + metadata_kwargs["scope"] = scope + + client_secret = cfg.get("client_secret") + if client_secret: + metadata_kwargs["token_endpoint_auth_method"] = "client_secret_post" + + client_metadata = OAuthClientMetadata.model_validate(metadata_kwargs) + + # --- Pre-registered client --- + client_id = cfg.get("client_id") + if client_id: + info_dict: dict[str, Any] = { + "client_id": client_id, + "redirect_uris": [redirect_uri], + "grant_types": client_metadata.grant_types, + "response_types": client_metadata.response_types, + "token_endpoint_auth_method": client_metadata.token_endpoint_auth_method, + } + if client_secret: + info_dict["client_secret"] = client_secret + if client_name: + info_dict["client_name"] = client_name + if scope: + info_dict["scope"] = scope + + client_info = OAuthClientInformationFull.model_validate(info_dict) + _write_json(storage._client_info_path(), client_info.model_dump(exclude_none=True)) + logger.debug("Pre-registered client_id=%s for '%s'", client_id, server_name) + + # --- Base URL for discovery --- + parsed = urlparse(server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + # --- Build provider --- + provider = OAuthClientProvider( + server_url=base_url, + client_metadata=client_metadata, + storage=storage, + redirect_handler=_redirect_handler, + callback_handler=_wait_for_callback, + timeout=float(cfg.get("timeout", 300)), + ) + + return provider diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 2e1b9217f..5e4101a93 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -892,7 +892,9 @@ class MCPServerTask: if self._auth_type == "oauth": try: from tools.mcp_oauth import build_oauth_auth - _oauth_auth = build_oauth_auth(self.name, url) + _oauth_auth = build_oauth_auth( + self.name, url, config.get("oauth") + ) except Exception as exc: logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc) raise