feat(mcp): add HTTP transport, reconnection, security hardening

Upgrades the MCP client implementation from PR #291 with:

- HTTP/Streamable HTTP transport: support 'url' key in config for remote
  MCP servers (Notion, Slack, Sentry, Supabase, etc.)
- Automatic reconnection with exponential backoff (1s-60s, 5 retries)
  when a server connection drops unexpectedly
- Environment variable filtering: only pass safe vars (PATH, HOME, etc.)
  plus user-specified env to stdio subprocesses (prevents secret leaks)
- Credential stripping: sanitize error messages before returning to the
  LLM (strips GitHub PATs, OpenAI keys, Bearer tokens, etc.)
- Configurable per-server timeouts: 'timeout' and 'connect_timeout' keys
- Fix shutdown race condition in servers_snapshot variable scoping

Test coverage: 50 tests (up from 30), including new tests for env
filtering, credential sanitization, HTTP config detection, reconnection
logic, and configurable timeouts.

All tests pass (1162 passed, 3 skipped, 0 failed).
This commit is contained in:
teknium1 2026-03-02 18:40:03 -08:00
parent 468b7fdbad
commit 64ff8f065b
2 changed files with 611 additions and 65 deletions

View file

@ -2,9 +2,9 @@
"""
MCP (Model Context Protocol) Client Support
Connects to external MCP servers via stdio transport, discovers their tools,
and registers them into the hermes-agent tool registry so the agent can call
them like any built-in tool.
Connects to external MCP servers via stdio or HTTP/StreamableHTTP transport,
discovers their tools, and registers them into the hermes-agent tool registry
so the agent can call them like any built-in tool.
Configuration is read from ~/.hermes/config.yaml under the ``mcp_servers`` key.
The ``mcp`` Python package is optional -- if not installed, this module is a
@ -17,17 +17,32 @@ Example config::
command: "npx"
args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
env: {}
timeout: 120 # per-tool-call timeout in seconds (default: 120)
connect_timeout: 60 # initial connection timeout (default: 60)
github:
command: "npx"
args: ["-y", "@modelcontextprotocol/server-github"]
env:
GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."
remote_api:
url: "https://my-mcp-server.example.com/mcp"
headers:
Authorization: "Bearer sk-..."
timeout: 180
Features:
- Stdio transport (command + args) and HTTP/StreamableHTTP transport (url)
- Automatic reconnection with exponential backoff (up to 5 retries)
- Environment variable filtering for stdio subprocesses (security)
- Credential stripping in error messages returned to the LLM
- Configurable per-server timeouts for tool calls and connections
- Thread-safe architecture with dedicated background event loop
Architecture:
A dedicated background event loop (_mcp_loop) runs in a daemon thread.
Each MCP server runs as a long-lived asyncio Task on this loop, keeping
its ``async with stdio_client(...)`` context alive. Tool call coroutines
are scheduled onto the loop via ``run_coroutine_threadsafe()``.
its transport context alive. Tool call coroutines are scheduled onto the
loop via ``run_coroutine_threadsafe()``.
On shutdown, each server Task is signalled to exit its ``async with``
block, ensuring the anyio cancel-scope cleanup happens in the *same*
@ -43,6 +58,8 @@ Thread safety:
import asyncio
import json
import logging
import os
import re
import threading
from typing import Any, Dict, List, Optional
@ -53,13 +70,81 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
_MCP_AVAILABLE = False
_MCP_HTTP_AVAILABLE = False
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
_MCP_AVAILABLE = True
try:
from mcp.client.streamable_http import streamablehttp_client
_MCP_HTTP_AVAILABLE = True
except ImportError:
_MCP_HTTP_AVAILABLE = False
except ImportError:
logger.debug("mcp package not installed -- MCP tool support disabled")
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Default per-server timeouts (seconds). Both are overridable per server via
# the 'timeout' and 'connect_timeout' config keys.
_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls
_DEFAULT_DISCOVERY_TIMEOUT = 60 # seconds for server discovery
_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection
# Reconnection policy: exponential backoff (doubling, capped at
# _MAX_BACKOFF_SECONDS), giving up after _MAX_RECONNECT_RETRIES attempts.
_MAX_RECONNECT_RETRIES = 5
_MAX_BACKOFF_SECONDS = 60
# Environment variables that are safe to pass to stdio subprocesses.
# This is the baseline allow-list; XDG_* variables are also passed through
# (see _build_safe_env), plus whatever the user explicitly configures.
_SAFE_ENV_KEYS = frozenset({
    "PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", "SHELL", "TMPDIR",
})
# Regex for credential patterns to strip from error messages before they are
# returned to the LLM. Matches are replaced with "[REDACTED]" by
# _sanitize_error(). Best-effort, not exhaustive: a key format not listed
# here will pass through unredacted.
_CREDENTIAL_PATTERN = re.compile(
    r"(?:"
    r"ghp_[A-Za-z0-9_]{1,255}"          # GitHub PAT
    r"|sk-[A-Za-z0-9_]{1,255}"          # OpenAI-style key
    r"|Bearer\s+\S+"                    # Bearer token
    r"|token=[^\s&,;\"']{1,255}"        # token=...
    r"|key=[^\s&,;\"']{1,255}"          # key=...
    r"|API_KEY=[^\s&,;\"']{1,255}"      # API_KEY=...
    r"|password=[^\s&,;\"']{1,255}"     # password=...
    r"|secret=[^\s&,;\"']{1,255}"       # secret=...
    r")",
    re.IGNORECASE,
)
# ---------------------------------------------------------------------------
# Security helpers
# ---------------------------------------------------------------------------
def _build_safe_env(user_env: Optional[dict]) -> dict:
    """Construct the environment dict handed to an MCP stdio subprocess.

    Starts from an allow-list of baseline variables (_SAFE_ENV_KEYS) plus
    any ``XDG_*`` variables found in the current process environment, then
    layers the user-supplied ``env`` from the server config on top (user
    values win on key collisions).

    This keeps secrets held by the parent process -- API keys, tokens,
    credentials -- from leaking into MCP server subprocesses.
    """
    safe = {
        name: value
        for name, value in os.environ.items()
        if name in _SAFE_ENV_KEYS or name.startswith("XDG_")
    }
    safe.update(user_env or {})
    return safe
def _sanitize_error(text: str) -> str:
    """Redact credential-looking substrings before *text* reaches the LLM.

    Every match of _CREDENTIAL_PATTERN (GitHub PATs, OpenAI-style keys,
    Bearer tokens, token=/key=/password=/secret= pairs, ...) is replaced
    with the literal string ``[REDACTED]``.
    """
    redacted, _n = _CREDENTIAL_PATTERN.subn("[REDACTED]", text)
    return redacted
# ---------------------------------------------------------------------------
# Server task -- each MCP server lives in one long-lived asyncio Task
@ -70,66 +155,152 @@ class MCPServerTask:
The entire connection lifecycle (connect, discover, serve, disconnect)
runs inside one asyncio Task so that anyio cancel-scopes created by
``stdio_client`` are entered and exited in the same Task context.
the transport client are entered and exited in the same Task context.
Supports both stdio and HTTP/StreamableHTTP transports.
"""
__slots__ = (
"name", "session",
"_task", "_ready", "_shutdown_event", "_tools", "_error",
"name", "session", "tool_timeout",
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
)
def __init__(self, name: str):
    """Initialize bookkeeping for the server named *name* (not yet started)."""
    # Public state.
    self.name = name
    self.session: Optional[Any] = None            # live ClientSession once connected
    self.tool_timeout: float = _DEFAULT_TOOL_TIMEOUT  # per-tool-call timeout

    # Internal lifecycle state for the background Task.
    self._task: Optional[asyncio.Task] = None
    self._ready = asyncio.Event()                 # set once connected or failed
    self._shutdown_event = asyncio.Event()        # signals the Task to exit
    self._tools: list = []                        # discovered tool descriptors
    self._error: Optional[Exception] = None       # first connection error, if any
    self._config: dict = {}                       # raw server config dict
async def run(self, config: dict):
"""Long-lived coroutine: connect, discover tools, wait, disconnect."""
def _is_http(self) -> bool:
    """Whether the stored config selects the HTTP transport.

    HTTP/StreamableHTTP is chosen when the server config carries a
    ``url`` key; otherwise the stdio transport is used.
    """
    return "url" in self._config
async def _run_stdio(self, config: dict):
"""Run the server using stdio transport."""
command = config.get("command")
args = config.get("args", [])
env = config.get("env")
user_env = config.get("env")
if not command:
self._error = ValueError(
raise ValueError(
f"MCP server '{self.name}' has no 'command' in config"
)
self._ready.set()
return
safe_env = _build_safe_env(user_env)
server_params = StdioServerParameters(
command=command,
args=args,
env=env if env else None,
env=safe_env if safe_env else None,
)
try:
async with stdio_client(server_params) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
self.session = session
async with stdio_client(server_params) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
self.session = session
await self._discover_tools()
self._ready.set()
await self._shutdown_event.wait()
tools_result = await session.list_tools()
self._tools = (
tools_result.tools
if hasattr(tools_result, "tools")
else []
)
async def _run_http(self, config: dict):
    """Connect over HTTP/StreamableHTTP and serve until shutdown is requested.

    Opens the streamable-http transport for ``config['url']`` (optionally
    passing ``headers`` and ``connect_timeout``), initializes a
    ClientSession, discovers tools, then parks on the shutdown event so
    the async-with contexts stay alive on this Task.

    Raises:
        ImportError: if the installed ``mcp`` package has no
            streamable-http client.
    """
    if not _MCP_HTTP_AVAILABLE:
        raise ImportError(
            f"MCP server '{self.name}' requires HTTP transport but "
            "mcp.client.streamable_http is not available. "
            "Upgrade the mcp package to get HTTP support."
        )

    connect_timeout = float(
        config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
    )
    transport = streamablehttp_client(
        config["url"],
        headers=config.get("headers"),
        timeout=connect_timeout,
    )
    async with transport as (read_stream, write_stream, _get_session_id):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            self.session = session
            await self._discover_tools()
            # Connected and tools discovered -- unblock start().
            self._ready.set()
            # Park here; shutdown() sets the event to tear the contexts down.
            await self._shutdown_event.wait()
async def _discover_tools(self):
    """Populate ``self._tools`` from the live session (no-op when disconnected)."""
    session = self.session
    if session is None:
        return
    tools_result = await session.list_tools()
    # Defensive: fall back to an empty list if the result object carries
    # no 'tools' attribute.
    self._tools = getattr(tools_result, "tools", [])
async def run(self, config: dict):
"""Long-lived coroutine: connect, discover tools, wait, disconnect.
Includes automatic reconnection with exponential backoff if the
connection drops unexpectedly (unless shutdown was requested).
"""
self._config = config
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
retries = 0
backoff = 1.0
while True:
try:
if self._is_http():
await self._run_http(config)
else:
await self._run_stdio(config)
# Normal exit (shutdown requested) -- break out
break
except Exception as exc:
self.session = None
# If this is the first connection attempt, report the error
if not self._ready.is_set():
self._error = exc
self._ready.set()
return
# Block until shutdown is requested -- this keeps the
# async-with contexts alive on THIS Task.
await self._shutdown_event.wait()
except Exception as exc:
self._error = exc
self._ready.set()
finally:
self.session = None
# If shutdown was requested, don't reconnect
if self._shutdown_event.is_set():
logger.debug(
"MCP server '%s' disconnected during shutdown: %s",
self.name, exc,
)
return
retries += 1
if retries > _MAX_RECONNECT_RETRIES:
logger.warning(
"MCP server '%s' failed after %d reconnection attempts, "
"giving up: %s",
self.name, _MAX_RECONNECT_RETRIES, exc,
)
return
logger.warning(
"MCP server '%s' connection lost (attempt %d/%d), "
"reconnecting in %.0fs: %s",
self.name, retries, _MAX_RECONNECT_RETRIES,
backoff, exc,
)
await asyncio.sleep(backoff)
backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS)
# Check again after sleeping
if self._shutdown_event.is_set():
return
finally:
self.session = None
async def start(self, config: dict):
"""Create the background Task and wait until ready (or failed)."""
@ -203,7 +374,10 @@ def _run_on_mcp_loop(coro, timeout: float = 30):
def _load_mcp_config() -> Dict[str, dict]:
"""Read ``mcp_servers`` from the Hermes config file.
Returns a dict of ``{server_name: {command, args, env}}`` or empty dict.
Returns a dict of ``{server_name: server_config}`` or empty dict.
Server config can contain either ``command``/``args``/``env`` for stdio
transport or ``url``/``headers`` for HTTP transport, plus optional
``timeout`` and ``connect_timeout`` overrides.
"""
try:
from hermes_cli.config import load_config
@ -224,11 +398,12 @@ def _load_mcp_config() -> Dict[str, dict]:
async def _connect_server(name: str, config: dict) -> MCPServerTask:
"""Create an MCPServerTask, start it, and return when ready.
The server Task keeps the subprocess alive in the background.
The server Task keeps the connection alive in the background.
Call ``server.shutdown()`` (on the same event loop) to tear it down.
Raises:
ValueError: if ``command`` is missing from *config*.
ValueError: if required config keys are missing.
ImportError: if HTTP transport is needed but not available.
Exception: on connection or initialization failure.
"""
server = MCPServerTask(name)
@ -240,7 +415,7 @@ async def _connect_server(name: str, config: dict) -> MCPServerTask:
# Handler / check-fn factories
# ---------------------------------------------------------------------------
def _make_tool_handler(server_name: str, tool_name: str):
def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
"""Return a sync handler that calls an MCP tool via the background loop.
The handler conforms to the registry's dispatch interface:
@ -263,7 +438,11 @@ def _make_tool_handler(server_name: str, tool_name: str):
for block in (result.content or []):
if hasattr(block, "text"):
error_text += block.text
return json.dumps({"error": error_text or "MCP tool returned an error"})
return json.dumps({
"error": _sanitize_error(
error_text or "MCP tool returned an error"
)
})
# Collect text from content blocks
parts: List[str] = []
@ -273,10 +452,17 @@ def _make_tool_handler(server_name: str, tool_name: str):
return json.dumps({"result": "\n".join(parts) if parts else ""})
try:
return _run_on_mcp_loop(_call(), timeout=120)
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except Exception as exc:
logger.error("MCP tool %s/%s call failed: %s", server_name, tool_name, exc)
return json.dumps({"error": f"MCP call failed: {type(exc).__name__}: {exc}"})
logger.error(
"MCP tool %s/%s call failed: %s",
server_name, tool_name, exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return _handler
@ -339,7 +525,11 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
from tools.registry import registry
from toolsets import create_custom_toolset
server = await _connect_server(name, config)
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
server = await asyncio.wait_for(
_connect_server(name, config),
timeout=connect_timeout,
)
with _lock:
_servers[name] = server
@ -354,7 +544,7 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
name=tool_name_prefixed,
toolset=toolset_name,
schema=schema,
handler=_make_tool_handler(name, mcp_tool.name),
handler=_make_tool_handler(name, mcp_tool.name, server.tool_timeout),
check_fn=_make_check_fn(name),
is_async=False,
description=schema["description"],
@ -369,9 +559,11 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
tools=registered_names,
)
transport_type = "HTTP" if "url" in config else "stdio"
logger.info(
"MCP server '%s': registered %d tool(s): %s",
name, len(registered_names), ", ".join(registered_names),
"MCP server '%s' (%s): registered %d tool(s): %s",
name, transport_type, len(registered_names),
", ".join(registered_names),
)
return registered_names
@ -419,9 +611,12 @@ def discover_mcp_tools() -> List[str]:
registered = await _discover_and_register_server(name, cfg)
all_tools.extend(registered)
except Exception as exc:
logger.warning("Failed to connect to MCP server '%s': %s", name, exc)
logger.warning(
"Failed to connect to MCP server '%s': %s",
name, exc,
)
_run_on_mcp_loop(_discover_all(), timeout=60)
_run_on_mcp_loop(_discover_all(), timeout=_DEFAULT_DISCOVERY_TIMEOUT)
if all_tools:
# Dynamically inject into all hermes-* platform toolsets
@ -444,15 +639,10 @@ def shutdown_mcp_servers():
All servers are shut down in parallel via ``asyncio.gather``.
"""
with _lock:
if not _servers:
# No servers -- just stop the loop. _stop_mcp_loop() also
# acquires _lock, so we must release it first.
pass
else:
servers_snapshot = list(_servers.values())
servers_snapshot = list(_servers.values())
# Fast path: nothing to shut down.
if not _servers:
if not servers_snapshot:
_stop_mcp_loop()
return