diff --git a/tests/tools/test_mcp_oauth.py b/tests/tools/test_mcp_oauth.py index 19c588e58..8643c26b3 100644 --- a/tests/tools/test_mcp_oauth.py +++ b/tests/tools/test_mcp_oauth.py @@ -1,7 +1,8 @@ -"""Tests for tools/mcp_oauth.py — thin OAuth adapter over MCP SDK.""" +"""Tests for tools/mcp_oauth.py — OAuth 2.1 PKCE support for MCP servers.""" import json import os +from io import BytesIO from pathlib import Path from unittest.mock import patch, MagicMock, AsyncMock @@ -16,6 +17,7 @@ from tools.mcp_oauth import ( _can_open_browser, _is_interactive, _wait_for_callback, + _make_callback_handler, ) @@ -79,34 +81,93 @@ class TestHermesTokenStorage: assert not (d / "test-server.json").exists() assert not (d / "test-server.client.json").exists() + def test_has_cached_tokens(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + storage = HermesTokenStorage("my-server") + + assert not storage.has_cached_tokens() + + d = tmp_path / "mcp-tokens" + d.mkdir(parents=True) + (d / "my-server.json").write_text('{"access_token": "x", "token_type": "Bearer"}') + + assert storage.has_cached_tokens() + + def test_corrupt_tokens_returns_none(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + storage = HermesTokenStorage("bad-server") + + d = tmp_path / "mcp-tokens" + d.mkdir(parents=True) + (d / "bad-server.json").write_text("NOT VALID JSON{{{") + + import asyncio + assert asyncio.run(storage.get_tokens()) is None + + def test_corrupt_client_info_returns_none(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + storage = HermesTokenStorage("bad-server") + + d = tmp_path / "mcp-tokens" + d.mkdir(parents=True) + (d / "bad-server.client.json").write_text("GARBAGE") + + import asyncio + assert asyncio.run(storage.get_client_info()) is None + # --------------------------------------------------------------------------- # build_oauth_auth # --------------------------------------------------------------------------- class TestBuildOAuthAuth: - def test_returns_oauth_provider(self): + def test_returns_oauth_provider(self, tmp_path, monkeypatch): try: from mcp.client.auth import OAuthClientProvider except ImportError: pytest.skip("MCP SDK auth not available") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) auth = build_oauth_auth("test", "https://example.com/mcp") assert isinstance(auth, OAuthClientProvider) def test_returns_none_without_sdk(self, monkeypatch): import tools.mcp_oauth as mod - orig_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__ + monkeypatch.setattr(mod, "_OAUTH_AVAILABLE", False) + result = build_oauth_auth("test", "https://example.com") + assert result is None - def _block_import(name, *args, **kwargs): - if "mcp.client.auth" in name: - raise ImportError("blocked") - return orig_import(name, *args, **kwargs) + def test_pre_registered_client_id_stored(self, tmp_path, monkeypatch): + try: + from mcp.client.auth import OAuthClientProvider + except ImportError: + pytest.skip("MCP SDK auth not available") - with patch("builtins.__import__", side_effect=_block_import): - result = build_oauth_auth("test", "https://example.com") - # May or may not be None depending on import caching, but shouldn't crash - assert result is None or result is not None + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + build_oauth_auth("slack", "https://slack.example.com/mcp", { + "client_id": "my-app-id", + "client_secret": "my-secret", + "scope": "channels:read", + }) + + client_path = tmp_path / "mcp-tokens" / "slack.client.json" + assert client_path.exists() + data = json.loads(client_path.read_text()) + assert data["client_id"] == "my-app-id" + assert data["client_secret"] == "my-secret" + + def test_scope_passed_through(self, tmp_path, monkeypatch): + try: + from mcp.client.auth import OAuthClientProvider + except ImportError: + pytest.skip("MCP SDK auth not available") + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + provider = build_oauth_auth("scoped", "https://example.com/mcp", { + "scope": "read write admin", + }) + assert provider is not None + assert provider.context.client_metadata.scope == "read write admin" # --------------------------------------------------------------------------- @@ -119,6 +180,12 @@ class TestUtilities: assert isinstance(port, int) assert 1024 <= port <= 65535 + def test_find_free_port_unique(self): + """Two consecutive calls should return different ports (usually).""" + ports = {_find_free_port() for _ in range(5)} + # At least 2 different ports out of 5 attempts + assert len(ports) >= 2 + def test_can_open_browser_false_in_ssh(self, monkeypatch): monkeypatch.setenv("SSH_CLIENT", "1.2.3.4 1234 22") assert _can_open_browser() is False @@ -127,14 +194,22 @@ class TestUtilities: monkeypatch.delenv("SSH_CLIENT", raising=False) monkeypatch.delenv("SSH_TTY", raising=False) monkeypatch.delenv("DISPLAY", raising=False) + monkeypatch.delenv("WAYLAND_DISPLAY", raising=False) # Mock os.name and uname for non-macOS, non-Windows monkeypatch.setattr(os, "name", "posix") monkeypatch.setattr(os, "uname", lambda: type("", (), {"sysname": "Linux"})()) assert _can_open_browser() is False + def test_can_open_browser_true_with_display(self, monkeypatch): + monkeypatch.delenv("SSH_CLIENT", raising=False) + monkeypatch.delenv("SSH_TTY", raising=False) + monkeypatch.setenv("DISPLAY", ":0") + monkeypatch.setattr(os, "name", "posix") + assert _can_open_browser() is True + # --------------------------------------------------------------------------- -# remove_oauth_tokens +# Path traversal protection # --------------------------------------------------------------------------- class TestPathTraversal: @@ -169,11 +244,14 @@ class TestPathTraversal: assert "/" not in path.stem +# --------------------------------------------------------------------------- +# Callback handler isolation +# --------------------------------------------------------------------------- + class TestCallbackHandlerIsolation: """Verify concurrent OAuth flows don't share state.""" def test_independent_result_dicts(self): - from tools.mcp_oauth import _make_callback_handler _, result_a = _make_callback_handler() _, result_b = _make_callback_handler() @@ -184,10 +262,6 @@ class TestCallbackHandlerIsolation: assert result_b["auth_code"] == "code_B" def test_handler_writes_to_own_result(self): - from tools.mcp_oauth import _make_callback_handler - from io import BytesIO - from unittest.mock import MagicMock - HandlerClass, result = _make_callback_handler() assert result["auth_code"] is None @@ -203,13 +277,30 @@ class TestCallbackHandlerIsolation: assert result["auth_code"] == "test123" assert result["state"] == "mystate" + def test_handler_captures_error(self): + HandlerClass, result = _make_callback_handler() + + handler = HandlerClass.__new__(HandlerClass) + handler.path = "/callback?error=access_denied" + handler.wfile = BytesIO() + handler.send_response = MagicMock() + handler.send_header = MagicMock() + handler.end_headers = MagicMock() + handler.do_GET() + + assert result["auth_code"] is None + assert result["error"] == "access_denied" + + +# --------------------------------------------------------------------------- +# Port sharing +# --------------------------------------------------------------------------- class TestOAuthPortSharing: """Verify build_oauth_auth and _wait_for_callback use the same port.""" - def test_port_stored_globally(self): + def test_port_stored_globally(self, tmp_path, monkeypatch): import tools.mcp_oauth as mod - # Reset mod._oauth_port = None try: @@ -217,12 +308,17 @@ class TestOAuthPortSharing: except ImportError: pytest.skip("MCP SDK auth not available") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) build_oauth_auth("test-port", "https://example.com/mcp") assert mod._oauth_port is not None assert isinstance(mod._oauth_port, int) assert 1024 <= mod._oauth_port <= 65535 +# --------------------------------------------------------------------------- +# remove_oauth_tokens +# --------------------------------------------------------------------------- + class TestRemoveOAuthTokens: def test_removes_files(self, tmp_path, monkeypatch): monkeypatch.setenv("HERMES_HOME", str(tmp_path)) @@ -242,7 +338,7 @@ class TestRemoveOAuthTokens: # --------------------------------------------------------------------------- -# Non-interactive / startup-safety tests (issue #4462) +# Non-interactive / startup-safety tests # --------------------------------------------------------------------------- class TestIsInteractive: diff --git a/tools/mcp_oauth.py b/tools/mcp_oauth.py index b614826a8..00172f340 100644 --- a/tools/mcp_oauth.py +++ b/tools/mcp_oauth.py @@ -1,326 +1,482 @@ -"""Thin OAuth adapter for MCP HTTP servers. - -Wraps the MCP SDK's built-in ``OAuthClientProvider`` (which implements -``httpx.Auth``) with Hermes-specific token storage and browser-based -authorization. The SDK handles all of the heavy lifting: PKCE generation, -metadata discovery, dynamic client registration, token exchange, and refresh. - -Startup safety: - The callback handler never calls blocking ``input()`` on the event loop. - In non-interactive environments (no TTY, SSH, headless), the OAuth flow - raises ``OAuthNonInteractiveError`` instead of blocking, so that the - server degrades gracefully and other MCP servers are not affected. - -Usage in mcp_tool.py:: - - from tools.mcp_oauth import build_oauth_auth - auth = build_oauth_auth(server_name, server_url) - # pass ``auth`` as the httpx auth parameter +#!/usr/bin/env python3 """ +MCP OAuth 2.1 Client Support -from __future__ import annotations +Implements the browser-based OAuth 2.1 authorization code flow with PKCE +for MCP servers that require OAuth authentication instead of static bearer +tokens. + +Uses the MCP Python SDK's ``OAuthClientProvider`` (an ``httpx.Auth`` subclass) +which handles discovery, dynamic client registration, PKCE, token exchange, +refresh, and step-up authorization automatically. + +This module provides the glue: + - ``HermesTokenStorage``: persists tokens/client-info to disk so they + survive across process restarts. + - Callback server: ephemeral localhost HTTP server to capture the OAuth + redirect with the authorization code. + - ``build_oauth_auth()``: entry point called by ``mcp_tool.py`` that wires + everything together and returns the ``httpx.Auth`` object. + +Configuration in config.yaml:: + + mcp_servers: + my_server: + url: "https://mcp.example.com/mcp" + auth: oauth + oauth: # all fields optional + client_id: "pre-registered-id" # skip dynamic registration + client_secret: "secret" # confidential clients only + scope: "read write" # default: server-provided + redirect_port: 0 # 0 = auto-pick free port + client_name: "My Custom Client" # default: "Hermes Agent" +""" import asyncio import json import logging import os +import re import socket import sys import threading import webbrowser from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path -from typing import Any +from typing import Any, Optional from urllib.parse import parse_qs, urlparse logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Lazy imports -- MCP SDK with OAuth support is optional +# --------------------------------------------------------------------------- + +_OAUTH_AVAILABLE = False +try: + from mcp.client.auth import OAuthClientProvider, TokenStorage + from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthToken, + ) + from pydantic import AnyUrl + + _OAUTH_AVAILABLE = True +except ImportError: + logger.debug("MCP OAuth types not available -- OAuth MCP auth disabled") + + +# --------------------------------------------------------------------------- +# Exceptions +# --------------------------------------------------------------------------- + class OAuthNonInteractiveError(RuntimeError): - """Raised when OAuth requires user interaction but the environment is non-interactive.""" - pass - -_TOKEN_DIR_NAME = "mcp-tokens" + """Raised when OAuth requires browser interaction in a non-interactive env.""" # --------------------------------------------------------------------------- -# Token storage — persists tokens + client info to ~/.hermes/mcp-tokens/ +# Module-level state # --------------------------------------------------------------------------- -def _sanitize_server_name(name: str) -> str: - """Sanitize server name for safe use as a filename.""" - import re - clean = re.sub(r"[^\w\-]", "-", name.strip().lower()) - clean = re.sub(r"-+", "-", clean).strip("-") - return clean[:60] or "unnamed" - - -class HermesTokenStorage: - """File-backed token storage implementing the MCP SDK's TokenStorage protocol.""" - - def __init__(self, server_name: str): - self._server_name = _sanitize_server_name(server_name) - - def _base_dir(self) -> Path: - home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) - d = home / _TOKEN_DIR_NAME - d.mkdir(parents=True, exist_ok=True) - return d - - def _tokens_path(self) -> Path: - return self._base_dir() / f"{self._server_name}.json" - - def _client_path(self) -> Path: - return self._base_dir() / f"{self._server_name}.client.json" - - # -- TokenStorage protocol (async) -- - - async def get_tokens(self): - data = self._read_json(self._tokens_path()) - if not data: - return None - try: - from mcp.shared.auth import OAuthToken - return OAuthToken(**data) - except Exception: - return None - - async def set_tokens(self, tokens) -> None: - self._write_json(self._tokens_path(), tokens.model_dump(exclude_none=True)) - - async def get_client_info(self): - data = self._read_json(self._client_path()) - if not data: - return None - try: - from mcp.shared.auth import OAuthClientInformationFull - return OAuthClientInformationFull(**data) - except Exception: - return None - - async def set_client_info(self, client_info) -> None: - self._write_json(self._client_path(), client_info.model_dump(exclude_none=True)) - - # -- helpers -- - - @staticmethod - def _read_json(path: Path) -> dict | None: - if not path.exists(): - return None - try: - return json.loads(path.read_text(encoding="utf-8")) - except Exception: - return None - - @staticmethod - def _write_json(path: Path, data: dict) -> None: - path.write_text(json.dumps(data, indent=2), encoding="utf-8") - try: - path.chmod(0o600) - except OSError: - pass - - def remove(self) -> None: - """Delete stored tokens and client info for this server.""" - for p in (self._tokens_path(), self._client_path()): - try: - p.unlink(missing_ok=True) - except OSError: - pass +# Port used by the most recent build_oauth_auth() call. Exposed so that +# tests can verify the callback server and the redirect_uri share a port. +_oauth_port: int | None = None # --------------------------------------------------------------------------- -# Browser-based callback handler +# Helpers # --------------------------------------------------------------------------- + +def _get_token_dir() -> Path: + """Return the directory for MCP OAuth token files. + + Uses HERMES_HOME so each profile gets its own OAuth tokens. + Layout: ``HERMES_HOME/mcp-tokens/`` + """ + try: + from hermes_constants import get_hermes_home + base = Path(get_hermes_home()) + except ImportError: + base = Path(os.environ.get("HERMES_HOME", str(Path.home() / ".hermes"))) + return base / "mcp-tokens" + + +def _safe_filename(name: str) -> str: + """Sanitize a server name for use as a filename (no path separators).""" + return re.sub(r"[^\w\-]", "_", name).strip("_")[:128] or "default" + + def _find_free_port() -> int: + """Find an available TCP port on localhost.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) return s.getsockname()[1] -def _make_callback_handler(): - """Create a callback handler class with instance-scoped result storage.""" - result = {"auth_code": None, "state": None} - - class Handler(BaseHTTPRequestHandler): - def do_GET(self): - qs = parse_qs(urlparse(self.path).query) - result["auth_code"] = (qs.get("code") or [None])[0] - result["state"] = (qs.get("state") or [None])[0] - self.send_response(200) - self.send_header("Content-Type", "text/html") - self.end_headers() - self.wfile.write(b"
You can close this tab and return to Hermes.
" + ) if code else ( + "Error: {error or 'unknown'}
" + ) + self.send_response(200) + self.send_header("Content-Type", "text/html; charset=utf-8") + self.end_headers() + self.wfile.write(body.encode()) + + def log_message(self, fmt: str, *args: Any) -> None: + logger.debug("OAuth callback: %s", fmt % args) + + return _Handler, result + + +# --------------------------------------------------------------------------- +# Async redirect + callback handlers for OAuthClientProvider +# --------------------------------------------------------------------------- + + +async def _redirect_handler(authorization_url: str) -> None: + """Show the authorization URL to the user. + + Opens the browser automatically when possible; always prints the URL + as a fallback for headless/SSH/gateway environments. + """ + msg = ( + f"\n MCP OAuth: authorization required.\n" + f" Open this URL in your browser:\n\n" + f" {authorization_url}\n" + ) + print(msg, file=sys.stderr) + + if _can_open_browser(): + try: + opened = webbrowser.open(authorization_url) + if opened: + print(" (Browser opened automatically.)\n", file=sys.stderr) + else: + print(" (Could not open browser — please open the URL manually.)\n", file=sys.stderr) + except Exception: + print(" (Could not open browser — please open the URL manually.)\n", file=sys.stderr) + else: + print(" (Headless environment detected — open the URL manually.)\n", file=sys.stderr) + + +async def _wait_for_callback() -> tuple[str, str | None]: + """Wait for the OAuth callback to arrive on the local callback server. + + Uses the module-level ``_oauth_port`` which is set by ``build_oauth_auth`` + before this is ever called. Polls for the result without blocking the + event loop. + + Raises: + OAuthNonInteractiveError: If the callback times out (no user present + to complete the browser auth). + """ + global _oauth_port + assert _oauth_port is not None, "OAuth callback port not set" + + # The callback server is already running (started in build_oauth_auth). + # We just need to poll for the result. + handler_cls, result = _make_callback_handler() + + # Start a temporary server on the known port + try: + server = HTTPServer(("127.0.0.1", _oauth_port), handler_cls) + except OSError: + # Port already in use — the server from build_oauth_auth is running. + # Fall back to polling the server started by build_oauth_auth. + raise OAuthNonInteractiveError( + "OAuth callback timed out — could not bind callback port. " + "Complete the authorization in a browser first, then retry." + ) + + server_thread = threading.Thread(target=server.handle_request, daemon=True) + server_thread.start() + + timeout = 300.0 + poll_interval = 0.5 + elapsed = 0.0 + while elapsed < timeout: + if result["auth_code"] is not None or result["error"] is not None: + break + await asyncio.sleep(poll_interval) + elapsed += poll_interval + + server.server_close() + + if result["error"]: + raise RuntimeError(f"OAuth authorization failed: {result['error']}") + if result["auth_code"] is None: + raise OAuthNonInteractiveError( + "OAuth callback timed out — no authorization code received. " + "Ensure you completed the browser authorization flow." + ) + + return result["auth_code"], result["state"] # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- -def build_oauth_auth(server_name: str, server_url: str): - """Build an ``httpx.Auth`` handler for the given MCP server using OAuth 2.1 PKCE. - - Uses the MCP SDK's ``OAuthClientProvider`` which handles discovery, - registration, PKCE, token exchange, and refresh automatically. - - In non-interactive environments (no TTY), this still returns a provider - so that **cached tokens and refresh flows work**. Only the interactive - authorization-code grant will fail fast with a clear error instead of - blocking the event loop. - - Returns an ``OAuthClientProvider`` instance (implements ``httpx.Auth``), - or ``None`` if the MCP SDK auth module is not available. - """ - try: - from mcp.client.auth import OAuthClientProvider - from mcp.shared.auth import OAuthClientMetadata - except ImportError: - logger.warning("MCP SDK auth module not available — OAuth disabled") - return None - - storage = HermesTokenStorage(server_name) - interactive = _is_interactive() - - if not interactive: - # Check whether cached tokens exist. If they do, the SDK can still - # use them (and refresh them) without any user interaction. If not, - # we still build the provider — the callback_handler will raise - # OAuthNonInteractiveError if a fresh authorization is actually - # needed, which surfaces as a clean connection failure for this - # server only (other MCP servers are unaffected). - has_cached = storage._read_json(storage._tokens_path()) is not None - if not has_cached: - logger.warning( - "MCP server '%s' requires OAuth but no cached tokens found " - "and environment is non-interactive. The server will fail to " - "connect. Run 'hermes mcp auth %s' to authorize interactively.", - server_name, server_name, - ) - - global _oauth_port - _oauth_port = _find_free_port() - redirect_uri = f"http://127.0.0.1:{_oauth_port}/callback" - - client_metadata = OAuthClientMetadata( - client_name="Hermes Agent", - redirect_uris=[redirect_uri], - grant_types=["authorization_code", "refresh_token"], - response_types=["code"], - scope="openid profile email offline_access", - token_endpoint_auth_method="none", - ) - - # In non-interactive mode, the redirect handler logs the URL and the - # callback handler raises immediately — no blocking, no input(). - redirect_handler = _redirect_to_browser - callback_handler = _wait_for_callback - - if not interactive: - async def _noninteractive_redirect(auth_url: str) -> None: - logger.warning( - "MCP server '%s' needs OAuth authorization (non-interactive, " - "cannot open browser). URL: %s", - server_name, auth_url, - ) - - async def _noninteractive_callback() -> tuple[str, str | None]: - raise OAuthNonInteractiveError( - f"MCP server '{server_name}' requires interactive OAuth " - f"authorization but the environment is non-interactive " - f"(no TTY). Run 'hermes mcp auth {server_name}' to " - f"authorize, then restart." - ) - - redirect_handler = _noninteractive_redirect - callback_handler = _noninteractive_callback - - return OAuthClientProvider( - server_url=server_url, - client_metadata=client_metadata, - storage=storage, - redirect_handler=redirect_handler, - callback_handler=callback_handler, - timeout=120.0, - ) - def remove_oauth_tokens(server_name: str) -> None: """Delete stored OAuth tokens and client info for a server.""" - HermesTokenStorage(server_name).remove() + storage = HermesTokenStorage(server_name) + storage.remove() + logger.info("OAuth tokens removed for '%s'", server_name) + + +def build_oauth_auth( + server_name: str, + server_url: str, + oauth_config: dict | None = None, +) -> "OAuthClientProvider | None": + """Build an ``httpx.Auth``-compatible OAuth handler for an MCP server. + + Called from ``mcp_tool.py`` when a server has ``auth: oauth`` in config. + + Args: + server_name: Server key in mcp_servers config (used for storage). + server_url: MCP server endpoint URL. + oauth_config: Optional dict from the ``oauth:`` block in config.yaml. + + Returns: + An ``OAuthClientProvider`` instance, or None if the MCP SDK lacks + OAuth support. + """ + if not _OAUTH_AVAILABLE: + logger.warning( + "MCP OAuth requested for '%s' but SDK auth types are not available. " + "Install with: pip install 'mcp>=1.10.0'", + server_name, + ) + return None + + global _oauth_port + + cfg = oauth_config or {} + + # --- Storage --- + storage = HermesTokenStorage(server_name) + + # --- Non-interactive warning --- + if not _is_interactive() and not storage.has_cached_tokens(): + logger.warning( + "MCP OAuth for '%s': non-interactive environment and no cached tokens found. " + "The OAuth flow requires browser authorization. Run interactively first " + "to complete the initial authorization, then cached tokens will be reused.", + server_name, + ) + + # --- Pick callback port --- + redirect_port = int(cfg.get("redirect_port", 0)) + if redirect_port == 0: + redirect_port = _find_free_port() + _oauth_port = redirect_port + + # --- Client metadata --- + client_name = cfg.get("client_name", "Hermes Agent") + scope = cfg.get("scope") + redirect_uri = f"http://127.0.0.1:{redirect_port}/callback" + + metadata_kwargs: dict[str, Any] = { + "client_name": client_name, + "redirect_uris": [AnyUrl(redirect_uri)], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "none", + } + if scope: + metadata_kwargs["scope"] = scope + + client_secret = cfg.get("client_secret") + if client_secret: + metadata_kwargs["token_endpoint_auth_method"] = "client_secret_post" + + client_metadata = OAuthClientMetadata.model_validate(metadata_kwargs) + + # --- Pre-registered client --- + client_id = cfg.get("client_id") + if client_id: + info_dict: dict[str, Any] = { + "client_id": client_id, + "redirect_uris": [redirect_uri], + "grant_types": client_metadata.grant_types, + "response_types": client_metadata.response_types, + "token_endpoint_auth_method": client_metadata.token_endpoint_auth_method, + } + if client_secret: + info_dict["client_secret"] = client_secret + if client_name: + info_dict["client_name"] = client_name + if scope: + info_dict["scope"] = scope + + client_info = OAuthClientInformationFull.model_validate(info_dict) + _write_json(storage._client_info_path(), client_info.model_dump(exclude_none=True)) + logger.debug("Pre-registered client_id=%s for '%s'", client_id, server_name) + + # --- Base URL for discovery --- + parsed = urlparse(server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + # --- Build provider --- + provider = OAuthClientProvider( + server_url=base_url, + client_metadata=client_metadata, + storage=storage, + redirect_handler=_redirect_handler, + callback_handler=_wait_for_callback, + timeout=float(cfg.get("timeout", 300)), + ) + + return provider diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 2e1b9217f..5e4101a93 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -892,7 +892,9 @@ class MCPServerTask: if self._auth_type == "oauth": try: from tools.mcp_oauth import build_oauth_auth - _oauth_auth = build_oauth_auth(self.name, url) + _oauth_auth = build_oauth_auth( + self.name, url, config.get("oauth") + ) except Exception as exc: logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc) raise