mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-30 06:41:51 +00:00
Remove unused imports (F401) and duplicate/shadowed import redefinitions (F811) across the codebase using ruff's safe autofixes. No behavioral changes -- imports only. - ~1400 safe autofixes applied across 644 files (net -1072 lines) - __init__.py re-exports preserved (excluded from F401 removal so public re-export surfaces stay intact) - Re-exports that are imported or monkeypatched by tests but look unused in their defining module are kept with explicit # noqa: F401 (gateway/run.py load_dotenv; run_agent re-exports from agent.message_sanitization, agent.context_compressor, agent.retry_utils, agent.prompt_builder, agent.process_bootstrap, agent.codex_responses_adapter) - Unsafe F841 (unused-variable) fixes deliberately skipped -- those can change behavior when the RHS has side effects - ruff lints remain disabled in pyproject.toml (only PLW1514 is selected); this is a one-time cleanup, not a config change Verification: - python -m compileall: clean - pytest --collect-only: all 27161 tests collect (zero import errors) - core entry points import clean (run_agent, model_tools, cli, toolsets, hermes_state, batch_runner, gateway) - static scan: every name any test imports directly from an edited module still resolves
521 lines
18 KiB
Python
521 lines
18 KiB
Python
"""Tests for mTLS client certificate config on MCP HTTP/SSE transports.
|
|
|
|
Covers:
|
|
|
|
1. ``_resolve_client_cert`` helper — string, tuple, encrypted-key, validation
|
|
errors, missing-file errors.
|
|
|
|
2. HTTP (new SDK ``streamable_http_client``) path forwards ``cert=`` into the
|
|
user-owned ``httpx.AsyncClient``.
|
|
|
|
3. SSE path forwards ``cert`` and ``ssl_verify`` via an ``httpx_client_factory``
|
|
without breaking the OAuth/headers/timeout passthrough.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _resolve_client_cert helper
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestResolveClientCert:
    """Accepted config shapes and validation errors of ``_resolve_client_cert``."""

    @staticmethod
    def _write_cert_and_key(directory):
        """Create placeholder ``client.crt`` / ``client.key`` files in *directory*.

        Returns the ``(cert_path, key_path)`` pair, cert written first.
        """
        cert_file = directory / "client.crt"
        key_file = directory / "client.key"
        cert_file.write_text("cert")
        key_file.write_text("key")
        return cert_file, key_file

    def test_returns_none_when_unset(self):
        from tools.mcp_tool import _resolve_client_cert

        # Neither an empty config nor one without client_cert yields a cert.
        for config in ({}, {"url": "https://x"}):
            assert _resolve_client_cert("srv", config) is None

    def test_string_form_single_pem(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert

        combined = tmp_path / "combined.pem"
        combined.write_text("dummy")

        resolved = _resolve_client_cert("srv", {"client_cert": str(combined)})
        assert resolved == str(combined)

    def test_string_cert_with_separate_key(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert

        cert_file, key_file = self._write_cert_and_key(tmp_path)
        config = {"client_cert": str(cert_file), "client_key": str(key_file)}

        assert _resolve_client_cert("srv", config) == (str(cert_file), str(key_file))

    def test_list_form_two_elements(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert

        cert_file, key_file = self._write_cert_and_key(tmp_path)
        config = {"client_cert": [str(cert_file), str(key_file)]}

        assert _resolve_client_cert("srv", config) == (str(cert_file), str(key_file))

    def test_list_form_with_passphrase(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert

        cert_file, key_file = self._write_cert_and_key(tmp_path)
        config = {"client_cert": [str(cert_file), str(key_file), "passphrase"]}

        expected = (str(cert_file), str(key_file), "passphrase")
        assert _resolve_client_cert("srv", config) == expected

    def test_tilde_expansion(self, tmp_path, monkeypatch):
        from tools.mcp_tool import _resolve_client_cert

        # Point HOME at tmp_path so "~/client.pem" lands on the file below.
        monkeypatch.setenv("HOME", str(tmp_path))
        home_pem = tmp_path / "client.pem"
        home_pem.write_text("dummy")

        assert _resolve_client_cert("srv", {"client_cert": "~/client.pem"}) == str(home_pem)

    def test_missing_file_raises(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert

        absent = tmp_path / "nope.pem"
        with pytest.raises(FileNotFoundError, match=r"srv.*client_cert.*not found"):
            _resolve_client_cert("srv", {"client_cert": str(absent)})

    def test_missing_key_file_raises(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert

        cert_file = tmp_path / "client.crt"
        cert_file.write_text("cert")
        absent_key = tmp_path / "missing.key"

        with pytest.raises(FileNotFoundError, match=r"srv.*client_key.*not found"):
            _resolve_client_cert("srv", {
                "client_cert": str(cert_file),
                "client_key": str(absent_key),
            })

    def test_list_with_bad_length_raises(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert

        single_entry = [str(tmp_path / "x")]
        with pytest.raises(ValueError, match=r"list form must have 2 or 3"):
            _resolve_client_cert("srv", {"client_cert": single_entry})

    def test_list_plus_client_key_rejected(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert

        cert_file, key_file = self._write_cert_and_key(tmp_path)
        # Supplying the key both inside the list and via client_key is ambiguous.
        config = {
            "client_cert": [str(cert_file), str(key_file)],
            "client_key": str(key_file),
        }

        with pytest.raises(ValueError, match=r"either client_cert as a list"):
            _resolve_client_cert("srv", config)

    def test_non_string_path_rejected(self):
        from tools.mcp_tool import _resolve_client_cert

        with pytest.raises(ValueError, match=r"client_cert must be a non-empty string"):
            _resolve_client_cert("srv", {"client_cert": 123})

    def test_password_must_be_string(self, tmp_path):
        from tools.mcp_tool import _resolve_client_cert

        cert_file, key_file = self._write_cert_and_key(tmp_path)
        config = {"client_cert": [str(cert_file), str(key_file), 42]}

        with pytest.raises(ValueError, match=r"key passphrase.*must be a string"):
            _resolve_client_cert("srv", config)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# HTTP transport — cert forwarded into httpx.AsyncClient
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestHTTPClientCert:
    """``_run_http`` on the new-SDK HTTP path forwards ``client_cert`` as
    ``cert=`` to the ``httpx.AsyncClient`` it builds."""

    @staticmethod
    def _capture_async_client_kwargs(server, config: dict) -> dict:
        """Run ``server._run_http(config)`` on the new-SDK HTTP path with
        ``httpx.AsyncClient``, ``streamable_http_client``, ``ClientSession``
        and ``_discover_tools`` replaced by in-process dummies.

        Returns the keyword arguments passed to ``httpx.AsyncClient``, merged
        across every construction. Fails the calling test if
        ``httpx.AsyncClient`` was never constructed, so that "kwarg absent"
        assertions cannot pass vacuously on an empty dict.
        """
        from tools.mcp_tool import MCPServerTask

        captured: dict = {}
        constructions: list = []

        class DummyAsyncClient:
            def __init__(self, **kwargs):
                constructions.append(kwargs)
                captured.update(kwargs)

            async def __aenter__(self):
                return self

            async def __aexit__(self, *a):
                return False

        class DummyTransportCtx:
            async def __aenter__(self):
                return MagicMock(), MagicMock(), (lambda: None)

            async def __aexit__(self, *a):
                return False

        class DummySession:
            def __init__(self, *args, **kwargs):
                pass

            async def __aenter__(self):
                return self

            async def __aexit__(self, *a):
                return False

            async def initialize(self):
                return None

        async def _discover_tools(self):
            # Set the shutdown event right after discovery. Assumption (not
            # visible here): _run_http waits on _shutdown_event, so this ends
            # the run after one connect cycle.
            self._shutdown_event.set()

        async def _drive():
            with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \
                    patch("tools.mcp_tool._MCP_NEW_HTTP", True), \
                    patch("httpx.AsyncClient", DummyAsyncClient), \
                    patch("tools.mcp_tool.streamable_http_client",
                          return_value=DummyTransportCtx()), \
                    patch("tools.mcp_tool.ClientSession", DummySession), \
                    patch.object(MCPServerTask, "_discover_tools", _discover_tools):
                await server._run_http(config)

        asyncio.run(_drive())
        assert constructions, "httpx.AsyncClient was never constructed"
        return captured

    def test_cert_forwarded_to_async_client(self, tmp_path):
        """When client_cert is set, the new-SDK HTTP path passes ``cert=``
        into ``httpx.AsyncClient``."""
        from tools.mcp_tool import MCPServerTask

        cert = tmp_path / "client.pem"
        cert.write_text("dummy")

        captured = self._capture_async_client_kwargs(MCPServerTask("remote"), {
            "url": "https://example.com/mcp",
            "client_cert": str(cert),
        })
        assert captured.get("cert") == str(cert)

    def test_cert_tuple_forwarded(self, tmp_path):
        """List/tuple form resolves to a tuple in ``cert=``."""
        from tools.mcp_tool import MCPServerTask

        cert = tmp_path / "client.crt"
        key = tmp_path / "client.key"
        cert.write_text("cert")
        key.write_text("key")

        captured = self._capture_async_client_kwargs(MCPServerTask("remote"), {
            "url": "https://example.com/mcp",
            "client_cert": [str(cert), str(key)],
        })
        assert captured.get("cert") == (str(cert), str(key))

    def test_no_cert_means_no_cert_kwarg(self):
        """When client_cert is unset, ``cert`` is not passed to ``httpx.AsyncClient``
        (matches SDK defaults)."""
        from tools.mcp_tool import MCPServerTask

        captured = self._capture_async_client_kwargs(
            MCPServerTask("remote"), {"url": "https://example.com/mcp"},
        )
        assert "cert" not in captured

    def test_missing_cert_file_surfaces_clear_error(self, tmp_path):
        """A missing cert file fails fast with a server-scoped error message."""
        from tools.mcp_tool import MCPServerTask

        server = MCPServerTask("remote")

        async def _drive():
            with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \
                    patch("tools.mcp_tool._MCP_NEW_HTTP", True):
                await server._run_http({
                    "url": "https://example.com/mcp",
                    "client_cert": str(tmp_path / "nope.pem"),
                })

        with pytest.raises(FileNotFoundError, match=r"remote.*client_cert.*not found"):
            asyncio.run(_drive())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SSE transport — cert + verify routed via httpx_client_factory
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
def patch_sse_client():
    """Swap ``tools.mcp_tool.sse_client`` and ``ClientSession`` for fakes.

    Yields a dict of the keyword arguments from the most recent
    ``sse_client`` call, so tests can assert how ``_run_http`` invoked it.
    """
    recorded: dict = {}

    class _FakeStream:
        """Async context manager that yields a ``(read, write)`` AsyncMock pair."""

        def __init__(self):
            self._read, self._write = AsyncMock(), AsyncMock()

        async def __aenter__(self):
            return (self._read, self._write)

        async def __aexit__(self, *a):
            return False

    def fake_sse_client(**kwargs):
        # Keep only the latest call's kwargs.
        recorded.clear()
        recorded.update(kwargs)
        return _FakeStream()

    class _FakeSession:
        """Stand-in for ``ClientSession`` whose session awaits ``initialize()``."""

        def __init__(self, *args, **kwargs):
            pass

        async def __aenter__(self):
            return MagicMock(initialize=AsyncMock())

        async def __aexit__(self, *a):
            return False

    with patch("tools.mcp_tool.sse_client", new=fake_sse_client), \
            patch("tools.mcp_tool.ClientSession", new=_FakeSession):
        yield recorded
|
|
|
|
|
|
class TestSSEClientCert:
    """SSE transport: ``cert`` / ``ssl_verify`` reach httpx through an
    injected ``httpx_client_factory``; with defaults, no factory is injected."""

    @staticmethod
    def _drive_sse(config: dict) -> None:
        """Run ``_run_http(config)`` on a fresh ``sse-test`` server, with
        ``_wait_for_lifecycle_event`` returning ``"shutdown"`` and
        ``_discover_tools`` stubbed out.

        Any exception from the run, including the 2s ``wait_for`` timeout, is
        swallowed. Callers inspect what the patched ``sse_client`` recorded,
        not the run's outcome.
        """
        from tools.mcp_tool import MCPServerTask

        server = MCPServerTask("sse-test")
        server._auth_type = ""
        server._sampling = None

        async def drive():
            with patch.object(MCPServerTask, "_wait_for_lifecycle_event",
                              new=AsyncMock(return_value="shutdown")), \
                    patch.object(MCPServerTask, "_discover_tools", new=AsyncMock()):
                try:
                    await asyncio.wait_for(server._run_http(config), timeout=2.0)
                except Exception:
                    # Exception already covers asyncio.TimeoutError and
                    # StopAsyncIteration; the run's result is not under test.
                    pass

        asyncio.run(drive())

    @staticmethod
    def _client_kwargs_from_factory(factory, **factory_kwargs) -> dict:
        """Call *factory* as the SDK would, with ``httpx.AsyncClient`` replaced,
        and return the keyword arguments it passed to the client constructor."""
        import httpx

        captured: dict = {}

        class DummyAsyncClient:
            def __init__(self, **kwargs):
                captured.update(kwargs)

        with patch.object(httpx, "AsyncClient", DummyAsyncClient):
            factory(**factory_kwargs)
        return captured

    def test_no_factory_when_defaults(self, patch_sse_client):
        """With no cert and ssl_verify=True (default), the SDK's own factory is
        used — we don't inject one."""
        self._drive_sse({
            "url": "https://example.com/mcp/sse",
            "transport": "sse",
        })

        # The fixture fills this dict only when sse_client is called. Without
        # this check, the absence assertion below would pass even if
        # _run_http failed before reaching the SSE transport.
        assert patch_sse_client, "sse_client was never called"
        assert "httpx_client_factory" not in patch_sse_client

    def test_factory_injected_when_cert_set(self, patch_sse_client, tmp_path):
        """With client_cert set, an httpx_client_factory is injected that
        applies the cert (and follow_redirects=True to match the SDK)."""
        import httpx

        cert = tmp_path / "client.pem"
        cert.write_text("dummy")

        self._drive_sse({
            "url": "https://example.com/mcp/sse",
            "transport": "sse",
            "client_cert": str(cert),
        })

        factory = patch_sse_client.get("httpx_client_factory")
        assert factory is not None, "expected httpx_client_factory to be injected"

        # Invoke the factory the way the SDK would and inspect the client kwargs.
        client_kwargs = self._client_kwargs_from_factory(
            factory, headers={"x": "y"}, timeout=httpx.Timeout(30.0), auth=None,
        )

        assert client_kwargs["cert"] == str(cert)
        assert client_kwargs["verify"] is True
        assert client_kwargs["follow_redirects"] is True
        assert client_kwargs["headers"] == {"x": "y"}

    def test_factory_forwards_custom_ca_bundle(self, patch_sse_client, tmp_path):
        """ssl_verify as a path is forwarded to the factory's httpx client."""
        ca_bundle = tmp_path / "ca.pem"
        ca_bundle.write_text("dummy")

        self._drive_sse({
            "url": "https://example.com/mcp/sse",
            "transport": "sse",
            "ssl_verify": str(ca_bundle),
        })

        factory = patch_sse_client.get("httpx_client_factory")
        assert factory is not None

        client_kwargs = self._client_kwargs_from_factory(
            factory, headers=None, timeout=None, auth=None,
        )

        assert client_kwargs["verify"] == str(ca_bundle)
        assert "cert" not in client_kwargs
|