hermes-agent/tests/tools/test_mcp_client_cert.py
kshitijk4poor 66827f8947 chore: prune unused imports and duplicate import redefinitions
Remove unused imports (F401) and duplicate/shadowed import
redefinitions (F811) across the codebase using ruff's safe
autofixes. No behavioral changes -- imports only.

- ~1400 safe autofixes applied across 644 files (net -1072 lines)
- __init__.py re-exports preserved (excluded from F401 removal so
  public re-export surfaces stay intact)
- Re-exports that are imported or monkeypatched by tests but look
  unused in their defining module are kept with explicit # noqa:
  F401 (gateway/run.py load_dotenv; run_agent re-exports from
  agent.message_sanitization, agent.context_compressor,
  agent.retry_utils, agent.prompt_builder, agent.process_bootstrap,
  agent.codex_responses_adapter)
- Unsafe F841 (unused-variable) fixes deliberately skipped -- those
  can change behavior when the RHS has side effects
- ruff lints remain disabled in pyproject.toml (only PLW1514 is
  selected); this is a one-time cleanup, not a config change

Verification:
- python -m compileall: clean
- pytest --collect-only: all 27161 tests collect (zero import errors)
- core entry points import clean (run_agent, model_tools, cli,
  toolsets, hermes_state, batch_runner, gateway)
- static scan: every name any test imports directly from an edited
  module still resolves
2026-05-28 22:26:25 -07:00

521 lines
18 KiB
Python

"""Tests for mTLS client certificate config on MCP HTTP/SSE transports.
Covers:
1. ``_resolve_client_cert`` helper — string, tuple, encrypted-key, validation
errors, missing-file errors.
2. HTTP (new SDK ``streamable_http_client``) path forwards ``cert=`` into the
user-owned ``httpx.AsyncClient``.
3. SSE path forwards ``cert`` and ``ssl_verify`` via an ``httpx_client_factory``
without breaking the OAuth/headers/timeout passthrough.
"""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# _resolve_client_cert helper
# ---------------------------------------------------------------------------
class TestResolveClientCert:
def test_returns_none_when_unset(self):
from tools.mcp_tool import _resolve_client_cert
assert _resolve_client_cert("srv", {}) is None
assert _resolve_client_cert("srv", {"url": "https://x"}) is None
def test_string_form_single_pem(self, tmp_path):
from tools.mcp_tool import _resolve_client_cert
pem = tmp_path / "combined.pem"
pem.write_text("dummy")
result = _resolve_client_cert("srv", {"client_cert": str(pem)})
assert result == str(pem)
def test_string_cert_with_separate_key(self, tmp_path):
from tools.mcp_tool import _resolve_client_cert
cert = tmp_path / "client.crt"
key = tmp_path / "client.key"
cert.write_text("cert")
key.write_text("key")
result = _resolve_client_cert("srv", {
"client_cert": str(cert),
"client_key": str(key),
})
assert result == (str(cert), str(key))
def test_list_form_two_elements(self, tmp_path):
from tools.mcp_tool import _resolve_client_cert
cert = tmp_path / "client.crt"
key = tmp_path / "client.key"
cert.write_text("cert")
key.write_text("key")
result = _resolve_client_cert("srv", {
"client_cert": [str(cert), str(key)],
})
assert result == (str(cert), str(key))
def test_list_form_with_passphrase(self, tmp_path):
from tools.mcp_tool import _resolve_client_cert
cert = tmp_path / "client.crt"
key = tmp_path / "client.key"
cert.write_text("cert")
key.write_text("key")
result = _resolve_client_cert("srv", {
"client_cert": [str(cert), str(key), "passphrase"],
})
assert result == (str(cert), str(key), "passphrase")
def test_tilde_expansion(self, tmp_path, monkeypatch):
from tools.mcp_tool import _resolve_client_cert
monkeypatch.setenv("HOME", str(tmp_path))
pem = tmp_path / "client.pem"
pem.write_text("dummy")
result = _resolve_client_cert("srv", {"client_cert": "~/client.pem"})
assert result == str(pem)
def test_missing_file_raises(self, tmp_path):
from tools.mcp_tool import _resolve_client_cert
with pytest.raises(FileNotFoundError, match=r"srv.*client_cert.*not found"):
_resolve_client_cert("srv", {
"client_cert": str(tmp_path / "nope.pem"),
})
def test_missing_key_file_raises(self, tmp_path):
from tools.mcp_tool import _resolve_client_cert
cert = tmp_path / "client.crt"
cert.write_text("cert")
with pytest.raises(FileNotFoundError, match=r"srv.*client_key.*not found"):
_resolve_client_cert("srv", {
"client_cert": str(cert),
"client_key": str(tmp_path / "missing.key"),
})
def test_list_with_bad_length_raises(self, tmp_path):
from tools.mcp_tool import _resolve_client_cert
with pytest.raises(ValueError, match=r"list form must have 2 or 3"):
_resolve_client_cert("srv", {"client_cert": [str(tmp_path / "x")]})
def test_list_plus_client_key_rejected(self, tmp_path):
from tools.mcp_tool import _resolve_client_cert
cert = tmp_path / "client.crt"
key = tmp_path / "client.key"
cert.write_text("cert")
key.write_text("key")
with pytest.raises(ValueError, match=r"either client_cert as a list"):
_resolve_client_cert("srv", {
"client_cert": [str(cert), str(key)],
"client_key": str(key),
})
def test_non_string_path_rejected(self):
from tools.mcp_tool import _resolve_client_cert
with pytest.raises(ValueError, match=r"client_cert must be a non-empty string"):
_resolve_client_cert("srv", {"client_cert": 123})
def test_password_must_be_string(self, tmp_path):
from tools.mcp_tool import _resolve_client_cert
cert = tmp_path / "client.crt"
key = tmp_path / "client.key"
cert.write_text("cert")
key.write_text("key")
with pytest.raises(ValueError, match=r"key passphrase.*must be a string"):
_resolve_client_cert("srv", {
"client_cert": [str(cert), str(key), 42],
})
# ---------------------------------------------------------------------------
# HTTP transport — cert forwarded into httpx.AsyncClient
# ---------------------------------------------------------------------------
class TestHTTPClientCert:
def test_cert_forwarded_to_async_client(self, tmp_path):
"""When client_cert is set, the new-SDK HTTP path passes ``cert=``
into ``httpx.AsyncClient``."""
from tools.mcp_tool import MCPServerTask
cert = tmp_path / "client.pem"
cert.write_text("dummy")
server = MCPServerTask("remote")
captured: dict = {}
class DummyAsyncClient:
def __init__(self, **kwargs):
captured.update(kwargs)
async def __aenter__(self):
return self
async def __aexit__(self, *a):
return False
class DummyTransportCtx:
async def __aenter__(self):
return MagicMock(), MagicMock(), (lambda: None)
async def __aexit__(self, *a):
return False
class DummySession:
def __init__(self, *args, **kwargs):
pass
async def __aenter__(self):
return self
async def __aexit__(self, *a):
return False
async def initialize(self):
return None
async def _discover_tools(self):
self._shutdown_event.set()
async def _drive():
with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \
patch("tools.mcp_tool._MCP_NEW_HTTP", True), \
patch("httpx.AsyncClient", DummyAsyncClient), \
patch("tools.mcp_tool.streamable_http_client",
return_value=DummyTransportCtx()), \
patch("tools.mcp_tool.ClientSession", DummySession), \
patch.object(MCPServerTask, "_discover_tools", _discover_tools):
await server._run_http({
"url": "https://example.com/mcp",
"client_cert": str(cert),
})
asyncio.run(_drive())
assert captured.get("cert") == str(cert)
def test_cert_tuple_forwarded(self, tmp_path):
"""List/tuple form resolves to a tuple in ``cert=``."""
from tools.mcp_tool import MCPServerTask
cert = tmp_path / "client.crt"
key = tmp_path / "client.key"
cert.write_text("cert")
key.write_text("key")
server = MCPServerTask("remote")
captured: dict = {}
class DummyAsyncClient:
def __init__(self, **kwargs):
captured.update(kwargs)
async def __aenter__(self):
return self
async def __aexit__(self, *a):
return False
class DummyTransportCtx:
async def __aenter__(self):
return MagicMock(), MagicMock(), (lambda: None)
async def __aexit__(self, *a):
return False
class DummySession:
def __init__(self, *args, **kwargs):
pass
async def __aenter__(self):
return self
async def __aexit__(self, *a):
return False
async def initialize(self):
return None
async def _discover_tools(self):
self._shutdown_event.set()
async def _drive():
with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \
patch("tools.mcp_tool._MCP_NEW_HTTP", True), \
patch("httpx.AsyncClient", DummyAsyncClient), \
patch("tools.mcp_tool.streamable_http_client",
return_value=DummyTransportCtx()), \
patch("tools.mcp_tool.ClientSession", DummySession), \
patch.object(MCPServerTask, "_discover_tools", _discover_tools):
await server._run_http({
"url": "https://example.com/mcp",
"client_cert": [str(cert), str(key)],
})
asyncio.run(_drive())
assert captured.get("cert") == (str(cert), str(key))
def test_no_cert_means_no_cert_kwarg(self):
"""When client_cert is unset, ``cert`` is not passed to ``httpx.AsyncClient``
(matches SDK defaults)."""
from tools.mcp_tool import MCPServerTask
server = MCPServerTask("remote")
captured: dict = {}
class DummyAsyncClient:
def __init__(self, **kwargs):
captured.update(kwargs)
async def __aenter__(self):
return self
async def __aexit__(self, *a):
return False
class DummyTransportCtx:
async def __aenter__(self):
return MagicMock(), MagicMock(), (lambda: None)
async def __aexit__(self, *a):
return False
class DummySession:
def __init__(self, *args, **kwargs):
pass
async def __aenter__(self):
return self
async def __aexit__(self, *a):
return False
async def initialize(self):
return None
async def _discover_tools(self):
self._shutdown_event.set()
async def _drive():
with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \
patch("tools.mcp_tool._MCP_NEW_HTTP", True), \
patch("httpx.AsyncClient", DummyAsyncClient), \
patch("tools.mcp_tool.streamable_http_client",
return_value=DummyTransportCtx()), \
patch("tools.mcp_tool.ClientSession", DummySession), \
patch.object(MCPServerTask, "_discover_tools", _discover_tools):
await server._run_http({"url": "https://example.com/mcp"})
asyncio.run(_drive())
assert "cert" not in captured
def test_missing_cert_file_surfaces_clear_error(self, tmp_path):
"""A missing cert file fails fast with a server-scoped error message."""
from tools.mcp_tool import MCPServerTask
server = MCPServerTask("remote")
async def _drive():
with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \
patch("tools.mcp_tool._MCP_NEW_HTTP", True):
await server._run_http({
"url": "https://example.com/mcp",
"client_cert": str(tmp_path / "nope.pem"),
})
with pytest.raises(FileNotFoundError, match=r"remote.*client_cert.*not found"):
asyncio.run(_drive())
# ---------------------------------------------------------------------------
# SSE transport — cert + verify routed via httpx_client_factory
# ---------------------------------------------------------------------------
@pytest.fixture
def patch_sse_client():
"""Replace ``sse_client`` with a MagicMock that records its kwargs.
Returns the captured kwargs dict so tests can assert how ``_run_http``
called it.
"""
captured_kwargs: dict = {}
class _FakeStream:
def __init__(self):
self._read = AsyncMock()
self._write = AsyncMock()
async def __aenter__(self):
return (self._read, self._write)
async def __aexit__(self, *a):
return False
def fake_sse_client(**kwargs):
captured_kwargs.clear()
captured_kwargs.update(kwargs)
return _FakeStream()
class _FakeSession:
def __init__(self, *args, **kwargs):
pass
async def __aenter__(self):
mock_session = MagicMock()
mock_session.initialize = AsyncMock()
return mock_session
async def __aexit__(self, *a):
return False
with patch("tools.mcp_tool.sse_client", new=fake_sse_client), \
patch("tools.mcp_tool.ClientSession", new=_FakeSession):
yield captured_kwargs
class TestSSEClientCert:
def test_no_factory_when_defaults(self, patch_sse_client):
"""With no cert and ssl_verify=True (default), the SDK's own factory is
used — we don't inject one."""
from tools.mcp_tool import MCPServerTask
server = MCPServerTask("sse-test")
server._auth_type = ""
server._sampling = None
async def drive():
with patch.object(MCPServerTask, "_wait_for_lifecycle_event",
new=AsyncMock(return_value="shutdown")), \
patch.object(MCPServerTask, "_discover_tools", new=AsyncMock()):
try:
await asyncio.wait_for(
server._run_http({
"url": "https://example.com/mcp/sse",
"transport": "sse",
}),
timeout=2.0,
)
except (asyncio.TimeoutError, StopAsyncIteration, Exception):
pass
asyncio.run(drive())
assert "httpx_client_factory" not in patch_sse_client
def test_factory_injected_when_cert_set(self, patch_sse_client, tmp_path):
"""With client_cert set, an httpx_client_factory is injected that
applies the cert (and follow_redirects=True to match the SDK)."""
from tools.mcp_tool import MCPServerTask
cert = tmp_path / "client.pem"
cert.write_text("dummy")
server = MCPServerTask("sse-test")
server._auth_type = ""
server._sampling = None
async def drive():
with patch.object(MCPServerTask, "_wait_for_lifecycle_event",
new=AsyncMock(return_value="shutdown")), \
patch.object(MCPServerTask, "_discover_tools", new=AsyncMock()):
try:
await asyncio.wait_for(
server._run_http({
"url": "https://example.com/mcp/sse",
"transport": "sse",
"client_cert": str(cert),
}),
timeout=2.0,
)
except (asyncio.TimeoutError, StopAsyncIteration, Exception):
pass
asyncio.run(drive())
factory = patch_sse_client.get("httpx_client_factory")
assert factory is not None, "expected httpx_client_factory to be injected"
# Invoke the factory the way the SDK would; capture the resulting
# httpx.AsyncClient kwargs.
captured_client_kwargs: dict = {}
class DummyAsyncClient:
def __init__(self, **kwargs):
captured_client_kwargs.update(kwargs)
import httpx
with patch.object(httpx, "AsyncClient", DummyAsyncClient):
factory(headers={"x": "y"}, timeout=httpx.Timeout(30.0), auth=None)
assert captured_client_kwargs["cert"] == str(cert)
assert captured_client_kwargs["verify"] is True
assert captured_client_kwargs["follow_redirects"] is True
assert captured_client_kwargs["headers"] == {"x": "y"}
def test_factory_forwards_custom_ca_bundle(self, patch_sse_client, tmp_path):
"""ssl_verify as a path is forwarded to the factory's httpx client."""
from tools.mcp_tool import MCPServerTask
ca_bundle = tmp_path / "ca.pem"
ca_bundle.write_text("dummy")
server = MCPServerTask("sse-test")
server._auth_type = ""
server._sampling = None
async def drive():
with patch.object(MCPServerTask, "_wait_for_lifecycle_event",
new=AsyncMock(return_value="shutdown")), \
patch.object(MCPServerTask, "_discover_tools", new=AsyncMock()):
try:
await asyncio.wait_for(
server._run_http({
"url": "https://example.com/mcp/sse",
"transport": "sse",
"ssl_verify": str(ca_bundle),
}),
timeout=2.0,
)
except (asyncio.TimeoutError, StopAsyncIteration, Exception):
pass
asyncio.run(drive())
factory = patch_sse_client.get("httpx_client_factory")
assert factory is not None
captured_client_kwargs: dict = {}
class DummyAsyncClient:
def __init__(self, **kwargs):
captured_client_kwargs.update(kwargs)
import httpx
with patch.object(httpx, "AsyncClient", DummyAsyncClient):
factory(headers=None, timeout=None, auth=None)
assert captured_client_kwargs["verify"] == str(ca_bundle)
assert "cert" not in captured_client_kwargs