From a27d7e68ccb2cff8d95ac68660a150dc565a413f Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sun, 14 Jun 2026 04:46:54 -0700 Subject: [PATCH] fix(mcp): block suspicious stdio configs before probe (#46112) --- hermes_cli/config.py | 8 +-- hermes_cli/mcp_config.py | 10 +++ tests/hermes_cli/test_mcp_security.py | 93 +++++++++++++++++++++++++++ tools/mcp_tool.py | 53 +++++++++------ 4 files changed, 141 insertions(+), 23 deletions(-) diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 204e29d5faa..7bb0b283035 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -4837,15 +4837,15 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A raw_mcp_servers = config.get("mcp_servers") if isinstance(raw_mcp_servers, dict): try: - from hermes_cli.mcp_security import validate_mcp_server_entry + from hermes_cli.mcp_security import validate_mcp_server_entry as _validate_mcp_server_entry except Exception: - validate_mcp_server_entry = None - if validate_mcp_server_entry: + _validate_mcp_server_entry = None + if _validate_mcp_server_entry: mcp_touched = False for server_name, entry in raw_mcp_servers.items(): if not isinstance(entry, dict): continue - issues = validate_mcp_server_entry(server_name, entry) + issues = _validate_mcp_server_entry(server_name, entry) if not issues: continue entry["enabled"] = False diff --git a/hermes_cli/mcp_config.py b/hermes_cli/mcp_config.py index 94cd961ccc1..cad7ed43873 100644 --- a/hermes_cli/mcp_config.py +++ b/hermes_cli/mcp_config.py @@ -221,6 +221,10 @@ def _probe_single_server( Returns list of ``(tool_name, description)`` tuples. Raises on connection failure. """ + issues = validate_mcp_server_entry(name, config) + if issues: + raise ValueError("; ".join(issues)) + from tools.mcp_tool import ( _ensure_mcp_loop, _run_on_mcp_loop, @@ -352,6 +356,12 @@ def cmd_mcp_add(args): if explicit_env: server_config["env"] = explicit_env + issues = validate_mcp_server_entry(name, server_config) + if issues: + for issue in issues: + _warning(issue) + _warning(f"Server '{name}' was NOT saved due to suspicious configuration.") + return # ── Authentication ──────────────────────────────────────────────── diff --git a/tests/hermes_cli/test_mcp_security.py b/tests/hermes_cli/test_mcp_security.py index 2b0170847c6..a50d7e04ab0 100644 --- a/tests/hermes_cli/test_mcp_security.py +++ b/tests/hermes_cli/test_mcp_security.py @@ -2,6 +2,7 @@ from __future__ import annotations +from argparse import Namespace from pathlib import Path import pytest @@ -59,6 +60,51 @@ def test_save_mcp_server_rejects_dangerous_entry(tmp_path): assert "evil" not in load_config().get("mcp_servers", {}) +def test_mcp_add_rejects_dangerous_entry_before_probe(monkeypatch, capsys): + from hermes_cli.mcp_config import cmd_mcp_add + + probed = False + + def _probe_should_not_run(name, config): + nonlocal probed + probed = True + raise AssertionError("dangerous MCP config reached probe/spawn path") + + monkeypatch.setattr("hermes_cli.mcp_config._probe_single_server", _probe_should_not_run) + + cmd_mcp_add(Namespace( + name="evil", + url=None, + mcp_command="bash", + args=_dangerous_entry()["args"], + auth=None, + preset=None, + env=None, + )) + + out = capsys.readouterr().out + assert probed is False + assert "NOT saved" in out + + +def test_probe_rejects_dangerous_entry_before_connect(monkeypatch): + from hermes_cli.mcp_config import _probe_single_server + + connected = False + + async def _connect_should_not_run(name, config): + nonlocal connected + connected = True + raise AssertionError("dangerous MCP config reached connect/spawn path") + + monkeypatch.setattr("tools.mcp_tool._connect_server", _connect_should_not_run) + + with pytest.raises(ValueError, match="network egress"): + _probe_single_server("evil", _dangerous_entry(), connect_timeout=1) + + assert connected is False + + def test_runtime_loader_skips_dangerous_entry(monkeypatch): from tools.mcp_tool import _load_mcp_config @@ -74,6 +120,53 @@ def test_runtime_loader_skips_dangerous_entry(monkeypatch): assert loaded["clean"]["command"] == "npx" +def test_explicit_registration_skips_dangerous_entry_before_connect(monkeypatch): + import tools.mcp_tool as mcp_tool + + monkeypatch.setattr(mcp_tool, "_MCP_AVAILABLE", True) + monkeypatch.setattr(mcp_tool, "_ensure_mcp_loop", lambda: None) + + connected = [] + + async def _discover_one(name, config): + connected.append(name) + return [] + + def _run_on_loop(coro_or_factory, timeout=30): + import asyncio + import inspect + coro = coro_or_factory() if callable(coro_or_factory) else coro_or_factory + assert inspect.iscoroutine(coro) + return asyncio.run(coro) + + monkeypatch.setattr(mcp_tool, "_discover_and_register_server", _discover_one) + monkeypatch.setattr(mcp_tool, "_run_on_mcp_loop", _run_on_loop) + + with mcp_tool._lock: + saved_servers = dict(mcp_tool._servers) + saved_connecting = set(mcp_tool._server_connecting) + saved_errors = dict(mcp_tool._server_connect_errors) + mcp_tool._servers.clear() + mcp_tool._server_connecting.clear() + mcp_tool._server_connect_errors.clear() + + try: + mcp_tool.register_mcp_servers({ + "evil": _dangerous_entry(), + "clean": {"command": "npx", "args": ["-y", "clean-mcp"]}, + }) + finally: + with mcp_tool._lock: + mcp_tool._servers.clear() + mcp_tool._servers.update(saved_servers) + mcp_tool._server_connecting.clear() + mcp_tool._server_connecting.update(saved_connecting) + mcp_tool._server_connect_errors.clear() + mcp_tool._server_connect_errors.update(saved_errors) + + assert connected == ["clean"] + + def test_migration_disables_existing_dangerous_entry(tmp_path): import yaml diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index c619a600360..8b9200db9e4 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -89,6 +89,7 @@ import shutil import sys import threading import time +from typing import Callable from datetime import datetime from typing import Any, Coroutine, Dict, List, Optional from urllib.parse import urlparse @@ -2673,6 +2674,33 @@ def _interpolate_env_vars(value): return value +def _filter_suspicious_mcp_servers(servers: Dict[str, dict]) -> Dict[str, dict]: + """Drop exfiltration-shaped MCP configs before any stdio spawn path.""" + try: + from hermes_cli.mcp_security import validate_mcp_server_entry as _validate_mcp_server_entry + except Exception: + _validate_mcp_server_entry: Callable[[str, dict[str, Any]], list[str]] | None = None + + if _validate_mcp_server_entry is None: + return servers + + safe_servers = {} + for name, cfg in servers.items(): + if not isinstance(cfg, dict): + safe_servers[name] = cfg + continue + issues = _validate_mcp_server_entry(name, cfg) + if issues: + logger.warning( + "Skipping suspicious MCP server '%s': %s", + name, + "; ".join(issues), + ) + continue + safe_servers[name] = cfg + return safe_servers + + def _load_mcp_config() -> Dict[str, dict]: """Read ``mcp_servers`` from the Hermes config file. @@ -2695,31 +2723,17 @@ def _load_mcp_config() -> Dict[str, dict]: servers = config.get("mcp_servers") if not servers or not isinstance(servers, dict): return {} - try: - from hermes_cli.mcp_security import validate_mcp_server_entry - except Exception: - validate_mcp_server_entry = None # Ensure .env vars are available for interpolation try: from hermes_cli.env_loader import load_hermes_dotenv load_hermes_dotenv() except Exception: pass - safe_servers = {} - for name, cfg in servers.items(): - if not isinstance(cfg, dict): - safe_servers[name] = cfg - continue - if validate_mcp_server_entry: - issues = validate_mcp_server_entry(name, cfg) - if issues: - logger.warning( - "Skipping suspicious MCP server '%s': %s", - name, - "; ".join(issues), - ) - continue - safe_servers[name] = _interpolate_env_vars(cfg) + safe_servers: Dict[str, dict] = {} + for name, cfg in _filter_suspicious_mcp_servers(servers).items(): + interpolated = _interpolate_env_vars(cfg) + if isinstance(interpolated, dict): + safe_servers[name] = interpolated return safe_servers except Exception as exc: logger.debug("Failed to load MCP config: %s", exc) @@ -3667,6 +3681,7 @@ def register_mcp_servers(servers: Dict[str, dict]) -> List[str]: logger.debug("MCP SDK not available -- skipping explicit MCP registration") return [] + servers = _filter_suspicious_mcp_servers(servers) if not servers: logger.debug("No explicit MCP servers provided") return []