diff --git a/tests/tools/test_mcp_client_cert.py b/tests/tools/test_mcp_client_cert.py new file mode 100644 index 00000000000..67663414a23 --- /dev/null +++ b/tests/tools/test_mcp_client_cert.py @@ -0,0 +1,522 @@ +"""Tests for mTLS client certificate config on MCP HTTP/SSE transports. + +Covers: + +1. ``_resolve_client_cert`` helper — string, tuple, encrypted-key, validation + errors, missing-file errors. + +2. HTTP (new SDK ``streamable_http_client``) path forwards ``cert=`` into the + user-owned ``httpx.AsyncClient``. + +3. SSE path forwards ``cert`` and ``ssl_verify`` via an ``httpx_client_factory`` + without breaking the OAuth/headers/timeout passthrough. +""" + +from __future__ import annotations + +import asyncio +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# _resolve_client_cert helper +# --------------------------------------------------------------------------- + + +class TestResolveClientCert: + def test_returns_none_when_unset(self): + from tools.mcp_tool import _resolve_client_cert + + assert _resolve_client_cert("srv", {}) is None + assert _resolve_client_cert("srv", {"url": "https://x"}) is None + + def test_string_form_single_pem(self, tmp_path): + from tools.mcp_tool import _resolve_client_cert + + pem = tmp_path / "combined.pem" + pem.write_text("dummy") + + result = _resolve_client_cert("srv", {"client_cert": str(pem)}) + assert result == str(pem) + + def test_string_cert_with_separate_key(self, tmp_path): + from tools.mcp_tool import _resolve_client_cert + + cert = tmp_path / "client.crt" + key = tmp_path / "client.key" + cert.write_text("cert") + key.write_text("key") + + result = _resolve_client_cert("srv", { + "client_cert": str(cert), + "client_key": str(key), + }) + assert result == (str(cert), str(key)) + + def test_list_form_two_elements(self, tmp_path): + from tools.mcp_tool import _resolve_client_cert + + cert = tmp_path / "client.crt" + key = tmp_path / "client.key" + cert.write_text("cert") + key.write_text("key") + + result = _resolve_client_cert("srv", { + "client_cert": [str(cert), str(key)], + }) + assert result == (str(cert), str(key)) + + def test_list_form_with_passphrase(self, tmp_path): + from tools.mcp_tool import _resolve_client_cert + + cert = tmp_path / "client.crt" + key = tmp_path / "client.key" + cert.write_text("cert") + key.write_text("key") + + result = _resolve_client_cert("srv", { + "client_cert": [str(cert), str(key), "passphrase"], + }) + assert result == (str(cert), str(key), "passphrase") + + def test_tilde_expansion(self, tmp_path, monkeypatch): + from tools.mcp_tool import _resolve_client_cert + + monkeypatch.setenv("HOME", str(tmp_path)) + pem = tmp_path / "client.pem" + pem.write_text("dummy") + + result = _resolve_client_cert("srv", {"client_cert": "~/client.pem"}) + assert result == str(pem) + + def test_missing_file_raises(self, tmp_path): + from tools.mcp_tool import _resolve_client_cert + + with pytest.raises(FileNotFoundError, match=r"srv.*client_cert.*not found"): + _resolve_client_cert("srv", { + "client_cert": str(tmp_path / "nope.pem"), + }) + + def test_missing_key_file_raises(self, tmp_path): + from tools.mcp_tool import _resolve_client_cert + + cert = tmp_path / "client.crt" + cert.write_text("cert") + + with pytest.raises(FileNotFoundError, match=r"srv.*client_key.*not found"): + _resolve_client_cert("srv", { + "client_cert": str(cert), + "client_key": str(tmp_path / "missing.key"), + }) + + def test_list_with_bad_length_raises(self, tmp_path): + from tools.mcp_tool import _resolve_client_cert + + with pytest.raises(ValueError, match=r"list form must have 2 or 3"): + _resolve_client_cert("srv", {"client_cert": [str(tmp_path / "x")]}) + + def test_list_plus_client_key_rejected(self, tmp_path): + from tools.mcp_tool import _resolve_client_cert + + cert = tmp_path / "client.crt" + key = tmp_path / "client.key" + cert.write_text("cert") + key.write_text("key") + + with pytest.raises(ValueError, match=r"either client_cert as a list"): + _resolve_client_cert("srv", { + "client_cert": [str(cert), str(key)], + "client_key": str(key), + }) + + def test_non_string_path_rejected(self): + from tools.mcp_tool import _resolve_client_cert + + with pytest.raises(ValueError, match=r"client_cert must be a non-empty string"): + _resolve_client_cert("srv", {"client_cert": 123}) + + def test_password_must_be_string(self, tmp_path): + from tools.mcp_tool import _resolve_client_cert + + cert = tmp_path / "client.crt" + key = tmp_path / "client.key" + cert.write_text("cert") + key.write_text("key") + + with pytest.raises(ValueError, match=r"key passphrase.*must be a string"): + _resolve_client_cert("srv", { + "client_cert": [str(cert), str(key), 42], + }) + + +# --------------------------------------------------------------------------- +# HTTP transport — cert forwarded into httpx.AsyncClient +# --------------------------------------------------------------------------- + + +class TestHTTPClientCert: + def test_cert_forwarded_to_async_client(self, tmp_path): + """When client_cert is set, the new-SDK HTTP path passes ``cert=`` + into ``httpx.AsyncClient``.""" + from tools.mcp_tool import MCPServerTask + + cert = tmp_path / "client.pem" + cert.write_text("dummy") + + server = MCPServerTask("remote") + captured: dict = {} + + class DummyAsyncClient: + def __init__(self, **kwargs): + captured.update(kwargs) + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + class DummyTransportCtx: + async def __aenter__(self): + return MagicMock(), MagicMock(), (lambda: None) + + async def __aexit__(self, *a): + return False + + class DummySession: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + async def initialize(self): + return None + + async def _discover_tools(self): + self._shutdown_event.set() + + async def _drive(): + with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \ + patch("tools.mcp_tool._MCP_NEW_HTTP", True), \ + patch("httpx.AsyncClient", DummyAsyncClient), \ + patch("tools.mcp_tool.streamable_http_client", + return_value=DummyTransportCtx()), \ + patch("tools.mcp_tool.ClientSession", DummySession), \ + patch.object(MCPServerTask, "_discover_tools", _discover_tools): + await server._run_http({ + "url": "https://example.com/mcp", + "client_cert": str(cert), + }) + + asyncio.run(_drive()) + assert captured.get("cert") == str(cert) + + def test_cert_tuple_forwarded(self, tmp_path): + """List/tuple form resolves to a tuple in ``cert=``.""" + from tools.mcp_tool import MCPServerTask + + cert = tmp_path / "client.crt" + key = tmp_path / "client.key" + cert.write_text("cert") + key.write_text("key") + + server = MCPServerTask("remote") + captured: dict = {} + + class DummyAsyncClient: + def __init__(self, **kwargs): + captured.update(kwargs) + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + class DummyTransportCtx: + async def __aenter__(self): + return MagicMock(), MagicMock(), (lambda: None) + + async def __aexit__(self, *a): + return False + + class DummySession: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + async def initialize(self): + return None + + async def _discover_tools(self): + self._shutdown_event.set() + + async def _drive(): + with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \ + patch("tools.mcp_tool._MCP_NEW_HTTP", True), \ + patch("httpx.AsyncClient", DummyAsyncClient), \ + patch("tools.mcp_tool.streamable_http_client", + return_value=DummyTransportCtx()), \ + patch("tools.mcp_tool.ClientSession", DummySession), \ + patch.object(MCPServerTask, "_discover_tools", _discover_tools): + await server._run_http({ + "url": "https://example.com/mcp", + "client_cert": [str(cert), str(key)], + }) + + asyncio.run(_drive()) + assert captured.get("cert") == (str(cert), str(key)) + + def test_no_cert_means_no_cert_kwarg(self): + """When client_cert is unset, ``cert`` is not passed to ``httpx.AsyncClient`` + (matches SDK defaults).""" + from tools.mcp_tool import MCPServerTask + + server = MCPServerTask("remote") + captured: dict = {} + + class DummyAsyncClient: + def __init__(self, **kwargs): + captured.update(kwargs) + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + class DummyTransportCtx: + async def __aenter__(self): + return MagicMock(), MagicMock(), (lambda: None) + + async def __aexit__(self, *a): + return False + + class DummySession: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + async def initialize(self): + return None + + async def _discover_tools(self): + self._shutdown_event.set() + + async def _drive(): + with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \ + patch("tools.mcp_tool._MCP_NEW_HTTP", True), \ + patch("httpx.AsyncClient", DummyAsyncClient), \ + patch("tools.mcp_tool.streamable_http_client", + return_value=DummyTransportCtx()), \ + patch("tools.mcp_tool.ClientSession", DummySession), \ + patch.object(MCPServerTask, "_discover_tools", _discover_tools): + await server._run_http({"url": "https://example.com/mcp"}) + + asyncio.run(_drive()) + assert "cert" not in captured + + def test_missing_cert_file_surfaces_clear_error(self, tmp_path): + """A missing cert file fails fast with a server-scoped error message.""" + from tools.mcp_tool import MCPServerTask + + server = MCPServerTask("remote") + + async def _drive(): + with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \ + patch("tools.mcp_tool._MCP_NEW_HTTP", True): + await server._run_http({ + "url": "https://example.com/mcp", + "client_cert": str(tmp_path / "nope.pem"), + }) + + with pytest.raises(FileNotFoundError, match=r"remote.*client_cert.*not found"): + asyncio.run(_drive()) + + +# --------------------------------------------------------------------------- +# SSE transport — cert + verify routed via httpx_client_factory +# --------------------------------------------------------------------------- + + +@pytest.fixture +def patch_sse_client(): + """Replace ``sse_client`` with a MagicMock that records its kwargs. + + Returns the captured kwargs dict so tests can assert how ``_run_http`` + called it. + """ + captured_kwargs: dict = {} + + class _FakeStream: + def __init__(self): + self._read = AsyncMock() + self._write = AsyncMock() + + async def __aenter__(self): + return (self._read, self._write) + + async def __aexit__(self, *a): + return False + + def fake_sse_client(**kwargs): + captured_kwargs.clear() + captured_kwargs.update(kwargs) + return _FakeStream() + + class _FakeSession: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + mock_session = MagicMock() + mock_session.initialize = AsyncMock() + return mock_session + + async def __aexit__(self, *a): + return False + + with patch("tools.mcp_tool.sse_client", new=fake_sse_client), \ + patch("tools.mcp_tool.ClientSession", new=_FakeSession): + yield captured_kwargs + + +class TestSSEClientCert: + def test_no_factory_when_defaults(self, patch_sse_client): + """With no cert and ssl_verify=True (default), the SDK's own factory is + used — we don't inject one.""" + from tools.mcp_tool import MCPServerTask + + server = MCPServerTask("sse-test") + server._auth_type = "" + server._sampling = None + + async def drive(): + with patch.object(MCPServerTask, "_wait_for_lifecycle_event", + new=AsyncMock(return_value="shutdown")), \ + patch.object(MCPServerTask, "_discover_tools", new=AsyncMock()): + try: + await asyncio.wait_for( + server._run_http({ + "url": "https://example.com/mcp/sse", + "transport": "sse", + }), + timeout=2.0, + ) + except (asyncio.TimeoutError, StopAsyncIteration, Exception): + pass + + asyncio.run(drive()) + assert "httpx_client_factory" not in patch_sse_client + + def test_factory_injected_when_cert_set(self, patch_sse_client, tmp_path): + """With client_cert set, an httpx_client_factory is injected that + applies the cert (and follow_redirects=True to match the SDK).""" + from tools.mcp_tool import MCPServerTask + + cert = tmp_path / "client.pem" + cert.write_text("dummy") + + server = MCPServerTask("sse-test") + server._auth_type = "" + server._sampling = None + + async def drive(): + with patch.object(MCPServerTask, "_wait_for_lifecycle_event", + new=AsyncMock(return_value="shutdown")), \ + patch.object(MCPServerTask, "_discover_tools", new=AsyncMock()): + try: + await asyncio.wait_for( + server._run_http({ + "url": "https://example.com/mcp/sse", + "transport": "sse", + "client_cert": str(cert), + }), + timeout=2.0, + ) + except (asyncio.TimeoutError, StopAsyncIteration, Exception): + pass + + asyncio.run(drive()) + + factory = patch_sse_client.get("httpx_client_factory") + assert factory is not None, "expected httpx_client_factory to be injected" + + # Invoke the factory the way the SDK would; capture the resulting + # httpx.AsyncClient kwargs. + captured_client_kwargs: dict = {} + + class DummyAsyncClient: + def __init__(self, **kwargs): + captured_client_kwargs.update(kwargs) + + import httpx + with patch.object(httpx, "AsyncClient", DummyAsyncClient): + factory(headers={"x": "y"}, timeout=httpx.Timeout(30.0), auth=None) + + assert captured_client_kwargs["cert"] == str(cert) + assert captured_client_kwargs["verify"] is True + assert captured_client_kwargs["follow_redirects"] is True + assert captured_client_kwargs["headers"] == {"x": "y"} + + def test_factory_forwards_custom_ca_bundle(self, patch_sse_client, tmp_path): + """ssl_verify as a path is forwarded to the factory's httpx client.""" + from tools.mcp_tool import MCPServerTask + + ca_bundle = tmp_path / "ca.pem" + ca_bundle.write_text("dummy") + + server = MCPServerTask("sse-test") + server._auth_type = "" + server._sampling = None + + async def drive(): + with patch.object(MCPServerTask, "_wait_for_lifecycle_event", + new=AsyncMock(return_value="shutdown")), \ + patch.object(MCPServerTask, "_discover_tools", new=AsyncMock()): + try: + await asyncio.wait_for( + server._run_http({ + "url": "https://example.com/mcp/sse", + "transport": "sse", + "ssl_verify": str(ca_bundle), + }), + timeout=2.0, + ) + except (asyncio.TimeoutError, StopAsyncIteration, Exception): + pass + + asyncio.run(drive()) + + factory = patch_sse_client.get("httpx_client_factory") + assert factory is not None + + captured_client_kwargs: dict = {} + + class DummyAsyncClient: + def __init__(self, **kwargs): + captured_client_kwargs.update(kwargs) + + import httpx + with patch.object(httpx, "AsyncClient", DummyAsyncClient): + factory(headers=None, timeout=None, auth=None) + + assert captured_client_kwargs["verify"] == str(ca_bundle) + assert "cert" not in captured_client_kwargs diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 75c1c5e8633..157f79c1c52 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -559,6 +559,79 @@ def _validate_remote_mcp_url(server_name: str, url: Any) -> str: return stripped +def _resolve_client_cert(server_name: str, config: dict): + """Resolve the ``client_cert`` / ``client_key`` config for mTLS. + + Returns whatever ``httpx``'s ``cert=`` parameter accepts, or ``None`` when + no client certificate is configured: + + - ``None`` if neither ``client_cert`` nor ``client_key`` is set. + - A single absolute path string if ``client_cert`` is a string and + ``client_key`` is unset (PEM file with cert + key combined). + - A ``(cert_path, key_path)`` tuple when both are set, or when + ``client_cert`` is a 2-element list/tuple. + - A ``(cert_path, key_path, password)`` tuple when ``client_cert`` is + a 3-element list/tuple — the third element is the key passphrase. + + User paths support ``~`` expansion. Missing files raise ``FileNotFoundError`` + with a server-scoped message so the failure surfaces as a clear setup + error rather than an opaque TLS handshake error. + """ + raw_cert = config.get("client_cert") + raw_key = config.get("client_key") + + if raw_cert is None and raw_key is None: + return None + + def _expand(path: Any, label: str) -> str: + if not isinstance(path, str) or not path.strip(): + raise ValueError( + f"MCP server '{server_name}': {label} must be a non-empty " + f"string path (got {type(path).__name__})" + ) + expanded = os.path.expanduser(path.strip()) + if not os.path.isfile(expanded): + raise FileNotFoundError( + f"MCP server '{server_name}': {label} not found at " + f"{expanded!r}" + ) + return expanded + + # Tuple/list form for client_cert — (cert, key) or (cert, key, password). + if isinstance(raw_cert, (list, tuple)): + if raw_key is not None: + raise ValueError( + f"MCP server '{server_name}': specify either client_cert as " + f"a list [cert, key] OR client_cert + client_key, not both" + ) + if len(raw_cert) == 2: + cert_path = _expand(raw_cert[0], "client_cert[0]") + key_path = _expand(raw_cert[1], "client_cert[1]") + return (cert_path, key_path) + if len(raw_cert) == 3: + cert_path = _expand(raw_cert[0], "client_cert[0]") + key_path = _expand(raw_cert[1], "client_cert[1]") + password = raw_cert[2] + if not isinstance(password, str): + raise ValueError( + f"MCP server '{server_name}': client_cert[2] (key " + f"passphrase) must be a string" + ) + return (cert_path, key_path, password) + raise ValueError( + f"MCP server '{server_name}': client_cert list form must have 2 " + f"or 3 elements (got {len(raw_cert)})" + ) + + # String form for client_cert. + cert_path = _expand(raw_cert, "client_cert") + if raw_key is not None: + key_path = _expand(raw_key, "client_key") + return (cert_path, key_path) + # Single combined PEM file (cert + key in one file). + return cert_path + + def _format_connect_error(exc: BaseException) -> str: """Render nested MCP connection errors into an actionable short message.""" @@ -1362,6 +1435,7 @@ class MCPServerTask: headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) ssl_verify = config.get("ssl_verify", True) + client_cert = _resolve_client_cert(self.name, config) # OAuth 2.1 PKCE: route through the central MCPOAuthManager so the # same provider instance is reused across reconnects, pre-flow @@ -1413,6 +1487,37 @@ class MCPServerTask: # behind OAuth 2.1 PKCE work. Previously built but never # forwarded — SSE OAuth would silently fail with 401s. _sse_kwargs["auth"] = _oauth_auth + if client_cert is not None or ssl_verify is not True: + # SSE transport doesn't expose verify/cert as kwargs, so route + # them through an httpx_client_factory that wraps the SDK's + # defaults (follow_redirects=True) and adds our TLS settings. + # The SDK calls the factory with (headers, auth, timeout); we + # forward all of those and layer verify/cert on top. + import httpx as _httpx_mod + + _cert_for_factory = client_cert + _verify_for_factory = ssl_verify + + def _mcp_http_client_factory( + headers=None, timeout=None, auth=None, + ): + kwargs: dict = { + "follow_redirects": True, + "verify": _verify_for_factory, + } + if timeout is not None: + kwargs["timeout"] = timeout + else: + kwargs["timeout"] = _httpx_mod.Timeout(30.0, read=300.0) + if headers is not None: + kwargs["headers"] = headers + if auth is not None: + kwargs["auth"] = auth + if _cert_for_factory is not None: + kwargs["cert"] = _cert_for_factory + return _httpx_mod.AsyncClient(**kwargs) + + _sse_kwargs["httpx_client_factory"] = _mcp_http_client_factory async with sse_client(**_sse_kwargs) as (read_stream, write_stream): async with ClientSession( read_stream, write_stream, **sampling_kwargs @@ -1456,6 +1561,8 @@ class MCPServerTask: client_kwargs["headers"] = headers if _oauth_auth is not None: client_kwargs["auth"] = _oauth_auth + if client_cert is not None: + client_kwargs["cert"] = client_cert # Caller owns the client lifecycle — the SDK skips cleanup when # http_client is provided, so we wrap in async-with. diff --git a/website/docs/reference/mcp-config-reference.md b/website/docs/reference/mcp-config-reference.md index 86bbf78c61c..44d0d4512a9 100644 --- a/website/docs/reference/mcp-config-reference.md +++ b/website/docs/reference/mcp-config-reference.md @@ -25,6 +25,11 @@ mcp_servers: url: "..." # HTTP servers headers: {} + # Optional HTTP/SSE TLS settings: + ssl_verify: true # bool or path to a CA bundle (PEM) + client_cert: "/path/to/cert.pem" # mTLS client certificate (see below) + # client_key: "/path/to/key.pem" # optional, when key lives in a separate file + enabled: true timeout: 120 connect_timeout: 60 @@ -45,6 +50,9 @@ mcp_servers: | `env` | mapping | stdio | Environment passed to the subprocess | | `url` | string | HTTP | Remote MCP endpoint | | `headers` | mapping | HTTP | Headers for remote server requests | +| `ssl_verify` | bool or string | HTTP | TLS verification. `true` (default) uses system CAs, `false` disables verification (insecure), or a string path to a custom CA bundle (PEM) | +| `client_cert` | string or list | HTTP | mTLS client certificate. String = path to a PEM file containing cert + key. List `[cert, key]` = separate files. List `[cert, key, password]` = encrypted key | +| `client_key` | string | HTTP | Path to the client private key, when `client_cert` is a string and the key is in a separate file | | `enabled` | bool | both | Skip the server entirely when false | | `timeout` | number | both | Tool call timeout | | `connect_timeout` | number | both | Initial connection timeout | @@ -191,6 +199,40 @@ mcp_servers: prompts: false ``` +### TLS client certificate (mTLS) + +For HTTP/SSE servers that require a client certificate, set `client_cert` (and optionally `client_key`): + +```yaml +mcp_servers: + # Combined cert + key in a single PEM file + internal_api: + url: "https://mcp.internal.example.com/mcp" + client_cert: "~/secrets/mcp-client.pem" + + # Separate cert and key files + partner_api: + url: "https://mcp.partner.example.com/mcp" + client_cert: "~/secrets/client.crt" + client_key: "~/secrets/client.key" + + # Encrypted key with a passphrase (3-element list form) + bank_api: + url: "https://mcp.bank.example.com/mcp" + client_cert: ["~/secrets/client.crt", "~/secrets/client.key", "my-passphrase"] + + # Custom CA bundle (private CA / self-signed server) + lab_api: + url: "https://mcp.lab.local/mcp" + ssl_verify: "~/secrets/lab-ca.pem" + client_cert: "~/secrets/lab-client.pem" +``` + +Notes: +- Paths support `~` expansion. Missing files fail fast at connect time with a server-scoped error message. +- `ssl_verify: false` disables server certificate verification entirely. Don't use this with real services. +- Works on both Streamable HTTP and SSE transports. + ## Reloading config After changing MCP config, reload servers with: