hermes-agent/tests/tools/test_mcp_client_cert.py
Teknium 87e5b2fae0
feat(mcp): support TLS client certificates (mTLS) for HTTP and SSE servers (#33721)
Adds first-class `client_cert` / `client_key` config keys so MCP servers
behind mTLS work without an external TLS-terminating proxy. Resolves an
inbound community question (Jeremy W.).

Schema (per `mcp_servers.<name>`, HTTP/SSE only):

- `client_cert: "/path/to/combined.pem"` — single PEM with cert + key
- `client_cert: "/path/to/cert"` + `client_key: "/path/to/key"` — separate
- `client_cert: [cert, key]` or `[cert, key, password]` — list form,
  with optional passphrase for encrypted keys

Paths support `~` expansion. Missing files raise a server-scoped
`FileNotFoundError` at connect time rather than failing later with an
opaque TLS handshake error.

Wiring:

- New SDK HTTP path (mcp >= 1.24): `cert=` is passed to the user-owned
  `httpx.AsyncClient`, alongside the existing `verify=` handling.
- SSE path: routed through an `httpx_client_factory` that wraps the
  SDK's defaults (follow_redirects=True) and layers `verify` + `cert`
  on top. The factory is only injected when needed, so the SDK's
  built-in `create_mcp_http_client` keeps being used in the default
  case.
- Deprecated mcp<1.24 path left untouched — that SDK's
  `streamablehttp_client` signature doesn't expose `cert`, and adding
  it would be dead code.

Also documents the previously-undocumented `ssl_verify` key (bool or
CA bundle path) in the MCP config reference.

Tests:

- `tests/tools/test_mcp_client_cert.py` (new, 19 tests):
  - `_resolve_client_cert` helper: all three input forms, `~` expansion,
    missing-file and validation errors.
  - HTTP transport: `cert=` forwarded into `httpx.AsyncClient` for
    string and tuple forms; absent when unset; missing-file error
    propagates.
  - SSE transport: factory only injected when cert or non-default
    verify is set; factory applies cert, custom CA bundle, and
    preserves `follow_redirects=True` + forwarded headers/auth.
- Existing tests: 200/200 in `test_mcp_tool.py` + `test_mcp_sse_transport.py`
  still pass.
2026-05-28 00:55:55 -07:00

522 lines
18 KiB
Python

"""Tests for mTLS client certificate config on MCP HTTP/SSE transports.
Covers:
1. ``_resolve_client_cert`` helper — string, tuple, encrypted-key, validation
errors, missing-file errors.
2. HTTP (new SDK ``streamable_http_client``) path forwards ``cert=`` into the
user-owned ``httpx.AsyncClient``.
3. SSE path forwards ``cert`` and ``ssl_verify`` via an ``httpx_client_factory``
without breaking the OAuth/headers/timeout passthrough.
"""
from __future__ import annotations
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# _resolve_client_cert helper
# ---------------------------------------------------------------------------
class TestResolveClientCert:
    """Unit tests for ``tools.mcp_tool._resolve_client_cert``."""

    @staticmethod
    def _write_cert_and_key(tmp_path):
        """Write dummy ``client.crt`` / ``client.key`` files under *tmp_path*.

        Returns ``(cert_path, key_path)`` as strings.
        """
        cert = tmp_path / "client.crt"
        key = tmp_path / "client.key"
        cert.write_text("cert")
        key.write_text("key")
        return str(cert), str(key)

    def test_returns_none_when_unset(self):
        from tools.mcp_tool import _resolve_client_cert
        assert _resolve_client_cert("srv", {}) is None
        assert _resolve_client_cert("srv", {"url": "https://x"}) is None

    def test_string_form_single_pem(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert
        pem = tmp_path / "combined.pem"
        pem.write_text("dummy")
        result = _resolve_client_cert("srv", {"client_cert": str(pem)})
        assert result == str(pem)

    def test_string_cert_with_separate_key(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert
        cert, key = self._write_cert_and_key(tmp_path)
        result = _resolve_client_cert("srv", {
            "client_cert": cert,
            "client_key": key,
        })
        assert result == (cert, key)

    def test_list_form_two_elements(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert
        cert, key = self._write_cert_and_key(tmp_path)
        result = _resolve_client_cert("srv", {
            "client_cert": [cert, key],
        })
        assert result == (cert, key)

    def test_list_form_with_passphrase(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert
        cert, key = self._write_cert_and_key(tmp_path)
        result = _resolve_client_cert("srv", {
            "client_cert": [cert, key, "passphrase"],
        })
        assert result == (cert, key, "passphrase")

    def test_tilde_expansion(self, tmp_path, monkeypatch):
        from tools.mcp_tool import _resolve_client_cert
        # os.path.expanduser reads HOME on POSIX but USERPROFILE on Windows
        # (Python 3.8+); point both at tmp_path so the test is portable.
        monkeypatch.setenv("HOME", str(tmp_path))
        monkeypatch.setenv("USERPROFILE", str(tmp_path))
        pem = tmp_path / "client.pem"
        pem.write_text("dummy")
        result = _resolve_client_cert("srv", {"client_cert": "~/client.pem"})
        assert isinstance(result, str)
        assert not result.startswith("~")
        # Compare by file identity rather than exact string: on Windows the
        # expanded path keeps the "/" separator from "~/client.pem".
        assert os.path.samefile(result, pem)

    def test_missing_file_raises(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert
        with pytest.raises(FileNotFoundError, match=r"srv.*client_cert.*not found"):
            _resolve_client_cert("srv", {
                "client_cert": str(tmp_path / "nope.pem"),
            })

    def test_missing_key_file_raises(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert
        cert = tmp_path / "client.crt"
        cert.write_text("cert")
        with pytest.raises(FileNotFoundError, match=r"srv.*client_key.*not found"):
            _resolve_client_cert("srv", {
                "client_cert": str(cert),
                "client_key": str(tmp_path / "missing.key"),
            })

    def test_list_with_bad_length_raises(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert
        with pytest.raises(ValueError, match=r"list form must have 2 or 3"):
            _resolve_client_cert("srv", {"client_cert": [str(tmp_path / "x")]})

    def test_list_plus_client_key_rejected(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert
        cert, key = self._write_cert_and_key(tmp_path)
        with pytest.raises(ValueError, match=r"either client_cert as a list"):
            _resolve_client_cert("srv", {
                "client_cert": [cert, key],
                "client_key": key,
            })

    def test_non_string_path_rejected(self):
        from tools.mcp_tool import _resolve_client_cert
        with pytest.raises(ValueError, match=r"client_cert must be a non-empty string"):
            _resolve_client_cert("srv", {"client_cert": 123})

    def test_password_must_be_string(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert
        cert, key = self._write_cert_and_key(tmp_path)
        with pytest.raises(ValueError, match=r"key passphrase.*must be a string"):
            _resolve_client_cert("srv", {
                "client_cert": [cert, key, 42],
            })
# ---------------------------------------------------------------------------
# HTTP transport — cert forwarded into httpx.AsyncClient
# ---------------------------------------------------------------------------
class TestHTTPClientCert:
    """New-SDK HTTP path: the resolved client cert reaches ``httpx.AsyncClient``."""

    @staticmethod
    def _capture_async_client_kwargs(config: dict) -> dict:
        """Run ``MCPServerTask("remote")._run_http(config)`` against stubs.

        ``httpx.AsyncClient``, ``streamable_http_client`` and ``ClientSession``
        are replaced with minimal fakes, and ``_discover_tools`` is replaced
        with a coroutine that sets ``_shutdown_event``.

        Returns the merged keyword arguments passed to ``httpx.AsyncClient``.
        Fails the test if ``httpx.AsyncClient`` was never constructed, so
        assertions that a kwarg is *absent* cannot pass vacuously.
        """
        from tools.mcp_tool import MCPServerTask

        server = MCPServerTask("remote")
        client_calls: list = []

        class DummyAsyncClient:
            def __init__(self, **kwargs):
                client_calls.append(kwargs)

            async def __aenter__(self):
                return self

            async def __aexit__(self, *a):
                return False

        class DummyTransportCtx:
            async def __aenter__(self):
                return MagicMock(), MagicMock(), (lambda: None)

            async def __aexit__(self, *a):
                return False

        class DummySession:
            def __init__(self, *args, **kwargs):
                pass

            async def __aenter__(self):
                return self

            async def __aexit__(self, *a):
                return False

            async def initialize(self):
                return None

        async def _discover_tools(self):
            # Signal shutdown so the run can finish
            # (presumably _run_http waits on _shutdown_event).
            self._shutdown_event.set()

        async def _drive():
            with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \
                    patch("tools.mcp_tool._MCP_NEW_HTTP", True), \
                    patch("httpx.AsyncClient", DummyAsyncClient), \
                    patch("tools.mcp_tool.streamable_http_client",
                          return_value=DummyTransportCtx()), \
                    patch("tools.mcp_tool.ClientSession", DummySession), \
                    patch.object(MCPServerTask, "_discover_tools", _discover_tools):
                await server._run_http(config)

        asyncio.run(_drive())
        assert client_calls, "httpx.AsyncClient was never constructed"
        captured: dict = {}
        for kwargs in client_calls:
            captured.update(kwargs)
        return captured

    def test_cert_forwarded_to_async_client(self, tmp_path):
        """When client_cert is set, the new-SDK HTTP path passes ``cert=``
        into ``httpx.AsyncClient``."""
        cert = tmp_path / "client.pem"
        cert.write_text("dummy")
        captured = self._capture_async_client_kwargs({
            "url": "https://example.com/mcp",
            "client_cert": str(cert),
        })
        assert captured.get("cert") == str(cert)

    def test_cert_tuple_forwarded(self, tmp_path):
        """List/tuple form resolves to a tuple in ``cert=``."""
        cert = tmp_path / "client.crt"
        key = tmp_path / "client.key"
        cert.write_text("cert")
        key.write_text("key")
        captured = self._capture_async_client_kwargs({
            "url": "https://example.com/mcp",
            "client_cert": [str(cert), str(key)],
        })
        assert captured.get("cert") == (str(cert), str(key))

    def test_no_cert_means_no_cert_kwarg(self):
        """When client_cert is unset, ``cert`` is not passed to ``httpx.AsyncClient``
        (matches SDK defaults)."""
        captured = self._capture_async_client_kwargs({"url": "https://example.com/mcp"})
        assert "cert" not in captured

    def test_missing_cert_file_surfaces_clear_error(self, tmp_path):
        """A missing cert file fails fast with a server-scoped error message."""
        from tools.mcp_tool import MCPServerTask
        server = MCPServerTask("remote")

        async def _drive():
            with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \
                    patch("tools.mcp_tool._MCP_NEW_HTTP", True):
                await server._run_http({
                    "url": "https://example.com/mcp",
                    "client_cert": str(tmp_path / "nope.pem"),
                })

        with pytest.raises(FileNotFoundError, match=r"remote.*client_cert.*not found"):
            asyncio.run(_drive())
# ---------------------------------------------------------------------------
# SSE transport — cert + verify routed via httpx_client_factory
# ---------------------------------------------------------------------------
@pytest.fixture
def patch_sse_client():
    """Replace ``tools.mcp_tool.sse_client`` and ``ClientSession`` with fakes.

    The fake ``sse_client`` is a plain function that accepts keyword arguments
    only. Each call overwrites the captured dict with that call's kwargs and
    returns an async context manager yielding a ``(read, write)`` pair of
    ``AsyncMock`` objects. The fake ``ClientSession`` yields a ``MagicMock``
    session whose ``initialize`` is an ``AsyncMock``.

    Yields the captured kwargs dict so tests can assert how ``_run_http``
    called ``sse_client``. An empty dict means ``sse_client`` was never called
    (or was called with no arguments).
    """
    captured_kwargs: dict = {}

    class _FakeStream:
        def __init__(self):
            self._read = AsyncMock()
            self._write = AsyncMock()

        async def __aenter__(self):
            return (self._read, self._write)

        async def __aexit__(self, *a):
            return False

    def fake_sse_client(**kwargs):
        # Keep only the most recent call's kwargs.
        captured_kwargs.clear()
        captured_kwargs.update(kwargs)
        return _FakeStream()

    class _FakeSession:
        def __init__(self, *args, **kwargs):
            pass

        async def __aenter__(self):
            mock_session = MagicMock()
            mock_session.initialize = AsyncMock()
            return mock_session

        async def __aexit__(self, *a):
            return False

    with patch("tools.mcp_tool.sse_client", new=fake_sse_client), \
            patch("tools.mcp_tool.ClientSession", new=_FakeSession):
        yield captured_kwargs
class TestSSEClientCert:
    """SSE path: ``cert`` / ``ssl_verify`` are applied via ``httpx_client_factory``."""

    @staticmethod
    def _run_sse(extra_config: dict) -> None:
        """Drive ``_run_http`` once for an SSE server config.

        Must be used together with the ``patch_sse_client`` fixture, which
        fakes ``sse_client`` and records the kwargs it receives.
        """
        from tools.mcp_tool import MCPServerTask

        server = MCPServerTask("sse-test")
        server._auth_type = ""
        server._sampling = None
        config = {
            "url": "https://example.com/mcp/sse",
            "transport": "sse",
            **extra_config,
        }

        async def drive():
            with patch.object(MCPServerTask, "_wait_for_lifecycle_event",
                              new=AsyncMock(return_value="shutdown")), \
                    patch.object(MCPServerTask, "_discover_tools", new=AsyncMock()):
                try:
                    await asyncio.wait_for(server._run_http(config), timeout=2.0)
                except Exception:
                    # Tolerate errors from the minimal fakes. The tests only
                    # inspect the kwargs sse_client() received, and they check
                    # that it was actually called.
                    pass

        asyncio.run(drive())

    @staticmethod
    def _invoke_factory(factory, **factory_kwargs) -> dict:
        """Call *factory* the way the SDK would.

        Returns the kwargs the factory passed to ``httpx.AsyncClient``.
        """
        import httpx

        captured_client_kwargs: dict = {}

        class DummyAsyncClient:
            def __init__(self, **kwargs):
                captured_client_kwargs.update(kwargs)

        with patch.object(httpx, "AsyncClient", DummyAsyncClient):
            factory(**factory_kwargs)
        return captured_client_kwargs

    def test_no_factory_when_defaults(self, patch_sse_client):
        """With no cert and ssl_verify=True (default), the SDK's own factory is
        used — we don't inject one."""
        self._run_sse({})
        # An empty dict would mean sse_client() was never reached, which would
        # make the absence check below pass vacuously.
        assert patch_sse_client, "sse_client was never called"
        assert "httpx_client_factory" not in patch_sse_client

    def test_factory_injected_when_cert_set(self, patch_sse_client, tmp_path):
        """With client_cert set, an httpx_client_factory is injected that
        applies the cert (and follow_redirects=True to match the SDK)."""
        import httpx

        cert = tmp_path / "client.pem"
        cert.write_text("dummy")
        self._run_sse({"client_cert": str(cert)})

        factory = patch_sse_client.get("httpx_client_factory")
        assert factory is not None, "expected httpx_client_factory to be injected"
        client_kwargs = self._invoke_factory(
            factory, headers={"x": "y"}, timeout=httpx.Timeout(30.0), auth=None,
        )
        assert client_kwargs["cert"] == str(cert)
        assert client_kwargs["verify"] is True
        assert client_kwargs["follow_redirects"] is True
        assert client_kwargs["headers"] == {"x": "y"}

    def test_factory_forwards_custom_ca_bundle(self, patch_sse_client, tmp_path):
        """ssl_verify as a path is forwarded to the factory's httpx client."""
        ca_bundle = tmp_path / "ca.pem"
        ca_bundle.write_text("dummy")
        self._run_sse({"ssl_verify": str(ca_bundle)})

        factory = patch_sse_client.get("httpx_client_factory")
        assert factory is not None
        client_kwargs = self._invoke_factory(factory, headers=None, timeout=None, auth=None)
        assert client_kwargs["verify"] == str(ca_bundle)
        assert "cert" not in client_kwargs