#!/usr/bin/env python3
"""
MCP (Model Context Protocol) Client Support

Connects to external MCP servers via stdio or HTTP/StreamableHTTP transport,
discovers their tools, and registers them into the hermes-agent tool registry
so the agent can call them like any built-in tool.

Configuration is read from ~/.hermes/config.yaml under the ``mcp_servers`` key.
The ``mcp`` Python package is optional -- if not installed, this module is a
no-op and logs a debug message.

Example config::

    mcp_servers:
      filesystem:
        command: "npx"
        args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
        env: {}
        timeout: 120          # per-tool-call timeout in seconds (default: 120)
        connect_timeout: 60   # initial connection timeout (default: 60)
      github:
        command: "npx"
        args: ["-y", "@modelcontextprotocol/server-github"]
        env:
          GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."
      remote_api:
        url: "https://my-mcp-server.example.com/mcp"
        headers:
          Authorization: "Bearer sk-..."
        timeout: 180
      analysis:
        command: "npx"
        args: ["-y", "analysis-server"]
        sampling:                    # server-initiated LLM requests
          enabled: true              # default: true
          model: "gemini-3-flash"    # override model (optional)
          max_tokens_cap: 4096       # max tokens per request
          timeout: 30                # LLM call timeout (seconds)
          max_rpm: 10                # max requests per minute
          allowed_models: []         # model whitelist (empty = all)
          max_tool_rounds: 5         # tool loop limit (0 = disable)
          log_level: "info"          # audit verbosity

Features:
    - Stdio transport (command + args) and HTTP/StreamableHTTP transport (url)
    - Automatic reconnection with exponential backoff (up to 5 retries)
    - Environment variable filtering for stdio subprocesses (security)
    - Credential stripping in error messages returned to the LLM
    - Configurable per-server timeouts for tool calls and connections
    - Thread-safe architecture with dedicated background event loop
    - Sampling support: MCP servers can request LLM completions via
      sampling/createMessage (text and tool-use responses)

Architecture:
    A dedicated background event loop (_mcp_loop) runs in a daemon thread.
    Each MCP server runs as a long-lived asyncio Task on this loop, keeping
    its transport context alive. Tool call coroutines are scheduled onto the
    loop via ``run_coroutine_threadsafe()``.

    On shutdown, each server Task is signalled to exit its ``async with``
    block, ensuring the anyio cancel-scope cleanup happens in the *same*
    Task that opened the connection (required by anyio).

Thread safety:
    _servers and _mcp_loop/_mcp_thread are accessed from both the MCP
    background thread and caller threads. All mutations are protected by
    _lock so the code is safe regardless of GIL presence (e.g. Python 3.13+
    free-threading).
"""

import asyncio
import concurrent.futures
import inspect
import json
import logging
import math
import os
import re
import shutil
import sys
import threading
import time
from datetime import datetime
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Stdio subprocess stderr redirection
# ---------------------------------------------------------------------------
#
# The MCP SDK's ``stdio_client(server, errlog=sys.stderr)`` defaults the
# subprocess stderr stream to the parent process's real stderr, i.e. the
# user's TTY.  That means any MCP server we spawn at startup (FastMCP
# banners, slack-mcp-server JSON startup logs, etc.) writes directly onto
# the terminal while prompt_toolkit / Rich is rendering the TUI — which
# corrupts the display and can hang the session.
#
# Instead we redirect every stdio MCP subprocess's stderr into a shared
# per-profile log file (~/.hermes/logs/mcp-stderr.log), tagged with the
# server name so individual servers remain debuggable.
#
# Fallback is os.devnull if opening the log file fails for any reason.

_mcp_stderr_log_fh: Optional[Any] = None
_mcp_stderr_log_lock = threading.Lock()


def _get_mcp_stderr_log() -> Any:
    """Return a shared append-mode file handle for MCP subprocess stderr.

    Opened once per process (guarded by ``_mcp_stderr_log_lock``) and reused
    for every stdio server.  Must have a real OS-level file descriptor
    (``fileno()``) because asyncio's subprocess machinery wires the child's
    stderr directly to that fd.  Falls back to ``os.devnull``, and finally to
    ``sys.stderr``, if opening the log file fails.
    """
    global _mcp_stderr_log_fh
    with _mcp_stderr_log_lock:
        if _mcp_stderr_log_fh is not None:
            return _mcp_stderr_log_fh
        try:
            from hermes_constants import get_hermes_home
            log_dir = get_hermes_home() / "logs"
            log_dir.mkdir(parents=True, exist_ok=True)
            log_path = log_dir / "mcp-stderr.log"
            # Line-buffered so server output lands on disk promptly; errors=
            # "replace" tolerates garbled binary output from misbehaving
            # servers.
            fh = open(log_path, "a", encoding="utf-8", errors="replace", buffering=1)
            # Sanity-check: confirm a real fd is available before we commit.
            fh.fileno()
            _mcp_stderr_log_fh = fh
        except Exception as exc:  # pragma: no cover — best-effort fallback
            logger.debug("Failed to open MCP stderr log, using devnull: %s", exc)
            try:
                _mcp_stderr_log_fh = open(os.devnull, "w", encoding="utf-8")
            except Exception:
                # Last resort: the real stderr.  Not ideal for TUI users but
                # it matches pre-fix behavior.
                _mcp_stderr_log_fh = sys.stderr
        return _mcp_stderr_log_fh


def _write_stderr_log_header(server_name: str) -> None:
    """Write a human-readable session marker before launching a server.

    Gives operators a way to find each server's output in the shared
    ``mcp-stderr.log`` file without needing per-line prefixes (which would
    require a pipe + reader thread and complicate shutdown).

    Best-effort: a failed write is logged at debug level and never raised,
    so it cannot block the server from starting.
    """
    fh = _get_mcp_stderr_log()
    try:
        ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        fh.write(f"\n===== [{ts}] starting MCP server '{server_name}' =====\n")
        fh.flush()
    except Exception as exc:
        logger.debug("Failed to write MCP stderr log header for '%s': %s", server_name, exc)


# ---------------------------------------------------------------------------
# Graceful import -- MCP SDK is an optional dependency
# ---------------------------------------------------------------------------

# Capability flags.  Every flag gets a default here so that referencing it is
# safe even when the ``mcp`` package is missing entirely.
_MCP_AVAILABLE = False
_MCP_HTTP_AVAILABLE = False
_MCP_NEW_HTTP = False
_MCP_SAMPLING_TYPES = False
_MCP_NOTIFICATION_TYPES = False
_MCP_MESSAGE_HANDLER_SUPPORTED = False
# Conservative fallback for SDK builds that don't export LATEST_PROTOCOL_VERSION.
# Streamable HTTP was introduced by 2025-03-26, so this remains valid for the
# HTTP transport path even on older-but-supported SDK versions.
LATEST_PROTOCOL_VERSION = "2025-03-26"
try:
    from mcp import ClientSession, StdioServerParameters
    from mcp.client.stdio import stdio_client
    _MCP_AVAILABLE = True
    try:
        from mcp.client.streamable_http import streamablehttp_client
        _MCP_HTTP_AVAILABLE = True
    except ImportError:
        _MCP_HTTP_AVAILABLE = False
    # Prefer the non-deprecated API (mcp >= 1.24.0); fall back to the
    # deprecated wrapper for older SDK versions.
    try:
        from mcp.client.streamable_http import streamable_http_client
        _MCP_NEW_HTTP = True
    except ImportError:
        _MCP_NEW_HTTP = False
    try:
        from mcp.types import LATEST_PROTOCOL_VERSION
    except ImportError:
        logger.debug("mcp.types.LATEST_PROTOCOL_VERSION not available -- using fallback protocol version")
    # Sampling types -- separated so older SDK versions don't break MCP support
    try:
        from mcp.types import (
            CreateMessageResult,
            CreateMessageResultWithTools,
            ErrorData,
            SamplingCapability,
            SamplingToolsCapability,
            TextContent,
            ToolUseContent,
        )
        _MCP_SAMPLING_TYPES = True
    except ImportError:
        logger.debug("MCP sampling types not available -- sampling disabled")
    # Notification types for dynamic tool discovery (tools/list_changed)
    try:
        from mcp.types import (
            ServerNotification,
            ToolListChangedNotification,
            PromptListChangedNotification,
            ResourceListChangedNotification,
        )
        _MCP_NOTIFICATION_TYPES = True
    except ImportError:
        logger.debug("MCP notification types not available -- dynamic tool discovery disabled")
except ImportError:
    logger.debug("mcp package not installed -- MCP tool support disabled")


def _check_message_handler_support() -> bool:
    """Check if ClientSession accepts ``message_handler`` kwarg.

    Inspects the constructor signature for backward compatibility with older
    MCP SDK versions that don't support notification handlers.  Returns False
    when the SDK is unavailable or the signature cannot be inspected.
    """
    if not _MCP_AVAILABLE:
        return False
    try:
        return "message_handler" in inspect.signature(ClientSession).parameters
    except (TypeError, ValueError):
        return False


_MCP_MESSAGE_HANDLER_SUPPORTED = _check_message_handler_support()
if _MCP_AVAILABLE and not _MCP_MESSAGE_HANDLER_SUPPORTED:
    logger.debug("MCP SDK does not support message_handler -- dynamic tool discovery disabled")


# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

_DEFAULT_TOOL_TIMEOUT = 120      # seconds for tool calls
_DEFAULT_CONNECT_TIMEOUT = 60    # seconds for initial connection per server
_MAX_RECONNECT_RETRIES = 5
_MAX_INITIAL_CONNECT_RETRIES = 3  # retries for the very first connection attempt
_MAX_BACKOFF_SECONDS = 60

# Environment variables that are safe to pass to stdio subprocesses
_SAFE_ENV_KEYS = frozenset({
    "PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", "SHELL", "TMPDIR",
})

# Regex for credential patterns to strip from error messages
_CREDENTIAL_PATTERN = re.compile(
    r"(?:"
    r"ghp_[A-Za-z0-9_]{1,255}"           # GitHub PAT
    r"|sk-[A-Za-z0-9_]{1,255}"           # OpenAI-style key
    r"|Bearer\s+\S+"                     # Bearer token
    r"|token=[^\s&,;\"']{1,255}"         # token=...
    r"|key=[^\s&,;\"']{1,255}"           # key=...
    r"|API_KEY=[^\s&,;\"']{1,255}"       # API_KEY=...
    r"|password=[^\s&,;\"']{1,255}"      # password=...
    r"|secret=[^\s&,;\"']{1,255}"        # secret=...
    r")",
    re.IGNORECASE,
)


# ---------------------------------------------------------------------------
# Security helpers
# ---------------------------------------------------------------------------

def _build_safe_env(user_env: Optional[dict]) -> dict:
    """Build a filtered environment dict for stdio subprocesses.

    Only passes through safe baseline variables (PATH, HOME, etc.) and XDG_*
    variables from the current process environment, plus any variables
    explicitly specified by the user in the server config (user values win).

    This prevents accidentally leaking secrets like API keys, tokens, or
    credentials to MCP server subprocesses.
    """
    env = {}
    for key, value in os.environ.items():
        if key in _SAFE_ENV_KEYS or key.startswith("XDG_"):
            env[key] = value
    if user_env:
        env.update(user_env)
    return env


def _sanitize_error(text: str) -> str:
    """Strip credential-like patterns from error text before returning to LLM.

    Replaces every ``_CREDENTIAL_PATTERN`` match with ``[REDACTED]`` to
    prevent accidental credential exposure in tool error responses.
    """
    return _CREDENTIAL_PATTERN.sub("[REDACTED]", text)


# ---------------------------------------------------------------------------
# MCP tool description content scanning
# ---------------------------------------------------------------------------
# Patterns that indicate potential prompt injection in MCP tool descriptions.
# These are WARNING-level — we log but don't block, since false positives
# would break legitimate MCP servers.

_MCP_INJECTION_PATTERNS = [
    (re.compile(r"ignore\s+(all\s+)?previous\s+instructions", re.I),
     "prompt override attempt ('ignore previous instructions')"),
    (re.compile(r"you\s+are\s+now\s+a", re.I),
     "identity override attempt ('you are now a...')"),
    (re.compile(r"your\s+new\s+(task|role|instructions?)\s+(is|are)", re.I),
     "task override attempt"),
    (re.compile(r"system\s*:\s*", re.I),
     "system prompt injection attempt"),
    (re.compile(r"<\s*(system|human|assistant)\s*>", re.I),
     "role tag injection attempt"),
    (re.compile(r"do\s+not\s+(tell|inform|mention|reveal)", re.I),
     "concealment instruction"),
    (re.compile(r"(curl|wget|fetch)\s+https?://", re.I),
     "network command in description"),
    (re.compile(r"base64\.(b64decode|decodebytes)", re.I),
     "base64 decode reference"),
    (re.compile(r"exec\s*\(|eval\s*\(", re.I),
     "code execution reference"),
    (re.compile(r"import\s+(subprocess|os|shutil|socket)", re.I),
     "dangerous import reference"),
]


def _scan_mcp_description(server_name: str, tool_name: str, description: str) -> List[str]:
    """Scan an MCP tool description for prompt injection patterns.

    Returns a list of finding strings (empty = clean).  Findings are logged
    at WARNING level with the first 200 characters of the description.
    """
    findings = []
    if not description:
        return findings
    for pattern, reason in _MCP_INJECTION_PATTERNS:
        if pattern.search(description):
            findings.append(reason)
    if findings:
        logger.warning(
            "MCP server '%s' tool '%s': suspicious description content — %s. "
            "Description: %.200s",
            server_name, tool_name, "; ".join(findings),
            description,
        )
    return findings


def _prepend_path(env: dict, directory: str) -> dict:
    """Prepend *directory* to env PATH if it is not already present.

    Returns a new dict; *env* is not mutated.  Empty PATH entries are dropped
    when PATH is rebuilt.
    """
    updated = dict(env or {})
    if not directory:
        return updated

    existing = updated.get("PATH", "")
    parts = [part for part in existing.split(os.pathsep) if part]
    if directory not in parts:
        parts = [directory, *parts]
    updated["PATH"] = os.pathsep.join(parts) if parts else directory
    return updated


def _resolve_stdio_command(command: str, env: dict) -> tuple[str, dict]:
    """Resolve a stdio MCP command against the exact subprocess environment.

    This primarily exists to make bare ``npx``/``npm``/``node`` commands work
    reliably even when MCP subprocesses run under a filtered PATH.  Returns
    ``(resolved_command, env_copy)``; when the resolved command has a
    directory component, that directory is prepended to the env copy's PATH.
    """
    resolved_command = os.path.expanduser(str(command).strip())
    resolved_env = dict(env or {})

    if os.sep not in resolved_command:
        # None makes shutil.which fall back to the current process PATH.
        which_hit = shutil.which(resolved_command, path=resolved_env.get("PATH"))
        if which_hit:
            resolved_command = which_hit
        elif resolved_command in {"npx", "npm", "node"}:
            hermes_home = os.path.expanduser(
                os.getenv(
                    "HERMES_HOME", os.path.join(os.path.expanduser("~"), ".hermes")
                )
            )
            candidates = [
                os.path.join(hermes_home, "node", "bin", resolved_command),
                os.path.join(os.path.expanduser("~"), ".local", "bin", resolved_command),
            ]
            for candidate in candidates:
                if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
                    resolved_command = candidate
                    break

    command_dir = os.path.dirname(resolved_command)
    if command_dir:
        resolved_env = _prepend_path(resolved_env, command_dir)

    return resolved_command, resolved_env


def _format_connect_error(exc: BaseException) -> str:
    """Render nested MCP connection errors into an actionable short message.

    Walks exception groups (``.exceptions``) and the ``__cause__`` /
    ``__context__`` chain.  If a ``FileNotFoundError`` is found, the result
    names the missing executable (plus a Node.js hint for ``npx``/``npm``/
    ``node``); otherwise up to three distinct messages are joined with
    ``"; "``.  The result is always passed through :func:`_sanitize_error`.

    Both walks track visited exceptions by identity so chains that share a
    node (``raise X from Y`` sets cause and context to the same object) or
    that were hand-wired into a cycle cannot recurse forever.
    """

    def _find_missing(current: BaseException, seen: set) -> Optional[str]:
        if id(current) in seen:
            return None
        seen.add(id(current))

        nested = getattr(current, "exceptions", None)
        if nested:
            for child in nested:
                missing = _find_missing(child, seen)
                if missing:
                    return missing
            return None

        if isinstance(current, FileNotFoundError):
            if getattr(current, "filename", None):
                return str(current.filename)
            match = re.search(r"No such file or directory: '([^']+)'", str(current))
            if match:
                return match.group(1)

        for attr in ("__cause__", "__context__"):
            nested_exc = getattr(current, attr, None)
            if isinstance(nested_exc, BaseException):
                missing = _find_missing(nested_exc, seen)
                if missing:
                    return missing
        return None

    def _flatten_messages(current: BaseException, seen: set) -> List[str]:
        if id(current) in seen:
            return []
        seen.add(id(current))

        nested = getattr(current, "exceptions", None)
        if nested:
            flattened: List[str] = []
            for child in nested:
                flattened.extend(_flatten_messages(child, seen))
            return flattened

        messages = []
        text = str(current).strip()
        if text:
            messages.append(text)
        for attr in ("__cause__", "__context__"):
            nested_exc = getattr(current, attr, None)
            if isinstance(nested_exc, BaseException):
                messages.extend(_flatten_messages(nested_exc, seen))
        # Fall back to the class name so the caller never gets an empty message.
        return messages or [current.__class__.__name__]

    missing = _find_missing(exc, set())
    if missing:
        message = f"missing executable '{missing}'"
        if os.path.basename(missing) in {"npx", "npm", "node"}:
            message += (
                " (ensure Node.js is installed and PATH includes its bin directory, "
                "or set mcp_servers.<name>.command to an absolute path and include "
                "that directory in mcp_servers.<name>.env.PATH)"
            )
        return _sanitize_error(message)

    # dict.fromkeys de-duplicates while preserving first-seen order.
    deduped = list(dict.fromkeys(_flatten_messages(exc, set())))
    return _sanitize_error("; ".join(deduped[:3]))


# ---------------------------------------------------------------------------
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
# ---------------------------------------------------------------------------

def _safe_numeric(value, default, coerce=int, minimum=1):
    """Coerce a config value to a numeric type, returning *default* on failure.

    Handles string values from YAML (e.g. ``"10"`` instead of ``10``),
    non-finite floats (returns *default*), and values below *minimum*
    (clamped up to *minimum*).
    """
    try:
        result = coerce(value)
        if isinstance(result, float) and not math.isfinite(result):
            return default
        return max(result, minimum)
    except (TypeError, ValueError, OverflowError):
        return default


class SamplingHandler:
    """Handles sampling/createMessage requests for a single MCP server.

    Each MCPServerTask that has sampling enabled creates one SamplingHandler.
    The handler is callable and passed directly to ``ClientSession`` as
    the ``sampling_callback``. All state (rate-limit timestamps, metrics,
    tool-loop counters) lives on the instance -- no module-level globals.

    The callback is async and runs on the MCP background event loop.  The
    sync LLM call is offloaded to a thread via ``asyncio.to_thread()`` so
    it doesn't block the event loop.
    """

    _STOP_REASON_MAP = {"stop": "endTurn", "length": "maxTokens", "tool_calls": "toolUse"}

    def __init__(self, server_name: str, config: dict):
        self.server_name = server_name
        self.max_rpm = _safe_numeric(config.get("max_rpm", 10), 10, int)
        self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
        self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
        self.max_tool_rounds = _safe_numeric(
            config.get("max_tool_rounds", 5), 5, int, minimum=0,
        )
        self.model_override = config.get("model")
        # Normalize the whitelist to a list of strings: a YAML scalar
        # (``allowed_models: gpt-4o``) would otherwise turn the ``in`` check
        # into a substring match, and ``null`` would be stored as None.
        allowed = config.get("allowed_models") or []
        if not isinstance(allowed, (list, tuple, set, frozenset)):
            allowed = [allowed]
        self.allowed_models = [str(model) for model in allowed]
        _log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
        self.audit_level = _log_levels.get(
            str(config.get("log_level", "info")).lower(), logging.INFO,
        )

        # Per-instance state
        self._rate_timestamps: List[float] = []  # time.monotonic() values
        self._tool_loop_count = 0
        self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}

    # -- Rate limiting -------------------------------------------------------

    def _check_rate_limit(self) -> bool:
        """Sliding-window rate limiter. Returns True if request is allowed.

        Uses ``time.monotonic()`` so wall-clock adjustments (e.g. NTP steps
        or manual changes) cannot empty or freeze the 60-second window.
        """
        now = time.monotonic()
        window_start = now - 60
        self._rate_timestamps[:] = [t for t in self._rate_timestamps if t > window_start]
        if len(self._rate_timestamps) >= self.max_rpm:
            return False
        self._rate_timestamps.append(now)
        return True

    # -- Model resolution ----------------------------------------------------

    def _resolve_model(self, preferences) -> Optional[str]:
        """Config override > server hint > None (use default)."""
        if self.model_override:
            return self.model_override
        if preferences and hasattr(preferences, "hints") and preferences.hints:
            for hint in preferences.hints:
                if hasattr(hint, "name") and hint.name:
                    return hint.name
        return None

    # -- Message conversion --------------------------------------------------

    @staticmethod
    def _extract_tool_result_text(block) -> str:
        """Extract text from a ToolResultContent block."""
        if not hasattr(block, "content") or block.content is None:
            return ""
        items = block.content if isinstance(block.content, list) else [block.content]
        return "\n".join(item.text for item in items if hasattr(item, "text"))

    def _convert_messages(self, params) -> List[dict]:
        """Convert MCP SamplingMessages to OpenAI format.

        Uses ``msg.content_as_list`` (SDK helper) so single-block and
        list-of-blocks are handled uniformly.  Blocks are classified by
        duck-typing with ``hasattr``: ``toolUseId`` marks a tool result,
        ``name`` + ``input`` marks a tool use, anything else is content.
        """
        messages: List[dict] = []
        for msg in params.messages:
            blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
                msg.content if isinstance(msg.content, list) else [msg.content]
            )

            # Separate blocks by kind
            tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
            tool_uses = [b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")]
            content_blocks = [b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))]

            # Emit tool result messages (role: tool)
            for tr in tool_results:
                messages.append({
                    "role": "tool",
                    "tool_call_id": tr.toolUseId,
                    "content": self._extract_tool_result_text(tr),
                })

            # Emit assistant tool_calls message
            if tool_uses:
                tc_list = []
                for tu in tool_uses:
                    tc_list.append({
                        "id": getattr(tu, "id", f"call_{len(tc_list)}"),
                        "type": "function",
                        "function": {
                            "name": tu.name,
                            "arguments": json.dumps(tu.input, ensure_ascii=False) if isinstance(tu.input, dict) else str(tu.input),
                        },
                    })
                msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
                # Include any accompanying text
                text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
                if text_parts:
                    msg_dict["content"] = "\n".join(text_parts)
                messages.append(msg_dict)
            elif content_blocks:
                # Pure text/image content
                if len(content_blocks) == 1 and hasattr(content_blocks[0], "text"):
                    messages.append({"role": msg.role, "content": content_blocks[0].text})
                else:
                    parts = []
                    for block in content_blocks:
                        if hasattr(block, "text"):
                            parts.append({"type": "text", "text": block.text})
                        elif hasattr(block, "data") and hasattr(block, "mimeType"):
                            parts.append({
                                "type": "image_url",
                                "image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
                            })
                        else:
                            logger.warning(
                                "Unsupported sampling content block type: %s (skipped)",
                                type(block).__name__,
                            )
                    if parts:
                        messages.append({"role": msg.role, "content": parts})

        return messages

    # -- Error helper --------------------------------------------------------

    @staticmethod
    def _error(message: str, code: int = -1):
        """Return ErrorData (MCP spec) or raise as fallback."""
        if _MCP_SAMPLING_TYPES:
            return ErrorData(code=code, message=message)
        raise Exception(message)

    # -- Response building ---------------------------------------------------

    def _build_tool_use_result(self, choice, response):
        """Build a CreateMessageResultWithTools from an LLM tool_calls response.

        Enforces ``max_tool_rounds``: consecutive tool-use responses beyond
        the limit (or any when the limit is 0) produce an error and reset
        the counter.
        """
        self.metrics["tool_use_count"] += 1

        # Tool loop governance
        if self.max_tool_rounds == 0:
            self._tool_loop_count = 0
            return self._error(
                f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
            )

        self._tool_loop_count += 1
        if self._tool_loop_count > self.max_tool_rounds:
            self._tool_loop_count = 0
            return self._error(
                f"Tool loop limit exceeded for server '{self.server_name}' "
                f"(max {self.max_tool_rounds} rounds)"
            )

        content_blocks = []
        for tc in choice.message.tool_calls:
            args = tc.function.arguments
            if isinstance(args, str):
                try:
                    parsed = json.loads(args)
                except (json.JSONDecodeError, ValueError):
                    logger.warning(
                        "MCP server '%s': malformed tool_calls arguments "
                        "from LLM (wrapping as raw): %.100s",
                        self.server_name, args,
                    )
                    parsed = {"_raw": args}
            else:
                parsed = args if isinstance(args, dict) else {"_raw": str(args)}

            content_blocks.append(ToolUseContent(
                type="tool_use",
                id=tc.id,
                name=tc.function.name,
                input=parsed,
            ))

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
            self.server_name, response.model,
            getattr(getattr(response, "usage", None), "total_tokens", "?"),
            len(content_blocks),
        )

        return CreateMessageResultWithTools(
            role="assistant",
            content=content_blocks,
            model=response.model,
            stopReason="toolUse",
        )

    def _build_text_result(self, choice, response):
        """Build a CreateMessageResult from a normal text response."""
        self._tool_loop_count = 0  # reset on text response
        response_text = choice.message.content or ""

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling response: model=%s, tokens=%s",
            self.server_name, response.model,
            getattr(getattr(response, "usage", None), "total_tokens", "?"),
        )

        return CreateMessageResult(
            role="assistant",
            content=TextContent(type="text", text=_sanitize_error(response_text)),
            model=response.model,
            stopReason=self._STOP_REASON_MAP.get(choice.finish_reason, "endTurn"),
        )

    # -- Session kwargs helper -----------------------------------------------

    def session_kwargs(self) -> dict:
        """Return kwargs to pass to ClientSession for sampling support."""
        return {
            "sampling_callback": self,
            "sampling_capabilities": SamplingCapability(
                tools=SamplingToolsCapability(),
            ),
        }

    # -- Main callback -------------------------------------------------------

    async def __call__(self, context, params):
        """Sampling callback invoked by the MCP SDK.

        Conforms to ``SamplingFnT`` protocol. Returns
        ``CreateMessageResult``, ``CreateMessageResultWithTools``, or
        ``ErrorData``.
        """
        # Rate limit
        if not self._check_rate_limit():
            logger.warning(
                "MCP server '%s' sampling rate limit exceeded (%d/min)",
                self.server_name, self.max_rpm,
            )
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling rate limit exceeded for server '{self.server_name}' "
                f"({self.max_rpm} requests/minute)"
            )

        # Resolve model
        model = self._resolve_model(getattr(params, "modelPreferences", None))

        # Get auxiliary LLM client via centralized router
        from agent.auxiliary_client import call_llm

        # Model whitelist check (we need to resolve model before calling)
        resolved_model = model or self.model_override or ""

        if self.allowed_models and resolved_model and resolved_model not in self.allowed_models:
            logger.warning(
                "MCP server '%s' requested model '%s' not in allowed_models",
                self.server_name, resolved_model,
            )
            self.metrics["errors"] += 1
            return self._error(
                f"Model '{resolved_model}' not allowed for server "
                f"'{self.server_name}'. Allowed: {', '.join(self.allowed_models)}"
            )

        # Convert messages
        messages = self._convert_messages(params)
        if hasattr(params, "systemPrompt") and params.systemPrompt:
            messages.insert(0, {"role": "system", "content": params.systemPrompt})

        # Build LLM call kwargs
        max_tokens = min(params.maxTokens, self.max_tokens_cap)
        call_temperature = None
        if hasattr(params, "temperature") and params.temperature is not None:
            call_temperature = params.temperature

        # Forward server-provided tools
        call_tools = None
        server_tools = getattr(params, "tools", None)
        if server_tools:
            call_tools = [
                {
                    "type": "function",
                    "function": {
                        "name": getattr(t, "name", ""),
                        "description": getattr(t, "description", "") or "",
                        "parameters": _normalize_mcp_input_schema(
                            getattr(t, "inputSchema", None)
                        ),
                    },
                }
                for t in server_tools
            ]

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
            self.server_name, resolved_model, max_tokens, len(messages),
        )

        # Offload sync LLM call to thread (non-blocking)
        def _sync_call():
            return call_llm(
                task="mcp",
                model=resolved_model or None,
                messages=messages,
                temperature=call_temperature,
                max_tokens=max_tokens,
                tools=call_tools,
                timeout=self.timeout,
            )

        try:
            response = await asyncio.wait_for(
                asyncio.to_thread(_sync_call), timeout=self.timeout,
            )
        except asyncio.TimeoutError:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call timed out after {self.timeout}s "
                f"for server '{self.server_name}'"
            )
        except Exception as exc:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
            )

        # Guard against empty choices (content filtering, provider errors)
        if not getattr(response, "choices", None):
            self.metrics["errors"] += 1
            return self._error(
                f"LLM returned empty response (no choices) for server "
                f"'{self.server_name}'"
            )

        # Track metrics
        choice = response.choices[0]
        self.metrics["requests"] += 1
        total_tokens = getattr(getattr(response, "usage", None), "total_tokens", 0)
        if isinstance(total_tokens, int):
            self.metrics["tokens_used"] += total_tokens

        # Dispatch based on response type
        if (
            choice.finish_reason == "tool_calls"
            and hasattr(choice.message, "tool_calls")
            and choice.message.tool_calls
        ):
            return self._build_tool_use_result(choice, response)

        return self._build_text_result(choice, response)


# ---------------------------------------------------------------------------
# Server task -- each MCP server lives in one long-lived asyncio Task
# ---------------------------------------------------------------------------

class MCPServerTask:
    """Manages a single MCP server connection in a dedicated asyncio Task.

    The entire connection lifecycle (connect, discover, serve, disconnect)
    runs inside one asyncio Task so that anyio cancel-scopes created by
    the transport client are entered and exited in the same Task context.

    Supports both stdio and HTTP/StreamableHTTP transports.
    """

    __slots__ = (
        "name", "session", "tool_timeout",
        "_task", "_ready", "_shutdown_event", "_reconnect_event",
        "_tools", "_error", "_config",
        "_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
    )

    def __init__(self, name: str):
        self.name = name
        self.session: Optional[Any] = None
        self.tool_timeout: float = _DEFAULT_TOOL_TIMEOUT
        self._task: Optional[asyncio.Task] = None
        self._ready = asyncio.Event()
        self._shutdown_event = asyncio.Event()
        # Set by tool handlers on auth failure after manager.handle_401()
        # confirms recovery is viable. When set, _run_http / _run_stdio
        # exit their async-with blocks cleanly (no exception), and the
        # outer run() loop re-enters the transport so the MCP session is
        # rebuilt with fresh credentials.
self._reconnect_event = asyncio.Event() self._tools: list = [] self._error: Optional[Exception] = None self._config: dict = {} self._sampling: Optional[SamplingHandler] = None self._registered_tool_names: list[str] = [] self._auth_type: str = "" self._refresh_lock = asyncio.Lock() def _is_http(self) -> bool: """Check if this server uses HTTP transport.""" return "url" in self._config # ----- Dynamic tool discovery (notifications/tools/list_changed) ----- def _make_message_handler(self): """Build a ``message_handler`` callback for ``ClientSession``. Dispatches on notification type. Only ``ToolListChangedNotification`` triggers a refresh; prompt and resource change notifications are logged as stubs for future work. """ async def _handler(message): try: if isinstance(message, Exception): logger.debug("MCP message handler (%s): exception: %s", self.name, message) return if _MCP_NOTIFICATION_TYPES and isinstance(message, ServerNotification): match message.root: case ToolListChangedNotification(): logger.info( "MCP server '%s': received tools/list_changed notification", self.name, ) await self._refresh_tools() case PromptListChangedNotification(): logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name) case ResourceListChangedNotification(): logger.debug("MCP server '%s': resources/list_changed (ignored)", self.name) case _: pass except Exception: logger.exception("Error in MCP message handler for '%s'", self.name) return _handler async def _refresh_tools(self): """Re-fetch tools from the server and update the registry. Called when the server sends ``notifications/tools/list_changed``. The lock prevents overlapping refreshes from rapid-fire notifications. After the initial ``await`` (list_tools), all mutations are synchronous — atomic from the event loop's perspective. """ from tools.registry import registry async with self._refresh_lock: # Capture old tool names for change diff old_tool_names = set(self._registered_tool_names) # 1. 
Fetch current tool list from server tools_result = await self.session.list_tools() new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else [] # 2. Deregister old tools from the central registry for prefixed_name in self._registered_tool_names: registry.deregister(prefixed_name) # 3. Re-register with fresh tool list self._tools = new_mcp_tools self._registered_tool_names = _register_server_tools( self.name, self, self._config ) # 5. Log what changed (user-visible notification) new_tool_names = set(self._registered_tool_names) added = new_tool_names - old_tool_names removed = old_tool_names - new_tool_names changes = [] if added: changes.append(f"added: {', '.join(sorted(added))}") if removed: changes.append(f"removed: {', '.join(sorted(removed))}") if changes: logger.warning( "MCP server '%s': tools changed dynamically — %s. " "Verify these changes are expected.", self.name, "; ".join(changes), ) else: logger.info( "MCP server '%s': dynamically refreshed %d tool(s) (no changes)", self.name, len(self._registered_tool_names), ) async def _wait_for_lifecycle_event(self) -> str: """Block until either _shutdown_event or _reconnect_event fires. Returns: "shutdown" if the server should exit the run loop entirely. "reconnect" if the server should tear down the current MCP session and re-enter the transport (fresh OAuth tokens, new session ID, etc.). The reconnect event is cleared before return so the next cycle starts with a fresh signal. Shutdown takes precedence if both events are set simultaneously. 
""" shutdown_task = asyncio.create_task(self._shutdown_event.wait()) reconnect_task = asyncio.create_task(self._reconnect_event.wait()) try: await asyncio.wait( {shutdown_task, reconnect_task}, return_when=asyncio.FIRST_COMPLETED, ) finally: for t in (shutdown_task, reconnect_task): if not t.done(): t.cancel() try: await t except (asyncio.CancelledError, Exception): pass if self._shutdown_event.is_set(): return "shutdown" self._reconnect_event.clear() return "reconnect" async def _run_stdio(self, config: dict): """Run the server using stdio transport.""" command = config.get("command") args = config.get("args", []) user_env = config.get("env") if not command: raise ValueError( f"MCP server '{self.name}' has no 'command' in config" ) safe_env = _build_safe_env(user_env) command, safe_env = _resolve_stdio_command(command, safe_env) # Check package against OSV malware database before spawning from tools.osv_check import check_package_for_malware malware_error = check_package_for_malware(command, args) if malware_error: raise ValueError( f"MCP server '{self.name}': {malware_error}" ) server_params = StdioServerParameters( command=command, args=args, env=safe_env if safe_env else None, ) sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {} if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED: sampling_kwargs["message_handler"] = self._make_message_handler() # Snapshot child PIDs before spawning so we can track the new one. pids_before = _snapshot_child_pids() # Redirect subprocess stderr into a shared log file so MCP servers # (FastMCP banners, slack-mcp startup JSON, etc.) don't dump onto # the user's TTY and corrupt the TUI. Preserves debuggability via # ~/.hermes/logs/mcp-stderr.log. _write_stderr_log_header(self.name) _errlog = _get_mcp_stderr_log() async with stdio_client(server_params, errlog=_errlog) as (read_stream, write_stream): # Capture the newly spawned subprocess PID for force-kill cleanup. 
new_pids = _snapshot_child_pids() - pids_before if new_pids: with _lock: for _pid in new_pids: _stdio_pids[_pid] = self.name async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session: await session.initialize() self.session = session await self._discover_tools() self._ready.set() # stdio transport does not use OAuth, but we still honor # _reconnect_event (e.g. future manual /mcp refresh) for # consistency with _run_http. await self._wait_for_lifecycle_event() # Context exited cleanly — subprocess was terminated by the SDK. if new_pids: with _lock: for _pid in new_pids: _stdio_pids.pop(_pid, None) async def _run_http(self, config: dict): """Run the server using HTTP/StreamableHTTP transport.""" if not _MCP_HTTP_AVAILABLE: raise ImportError( f"MCP server '{self.name}' requires HTTP transport but " "mcp.client.streamable_http is not available. " "Upgrade the mcp package to get HTTP support." ) url = config["url"] headers = dict(config.get("headers") or {}) # Some MCP servers require MCP-Protocol-Version on the initial # initialize request and reject session-less POSTs otherwise. # Seed it as a client-level default, but treat user overrides as # case-insensitive so conventional casing is preserved. if not any(key.lower() == "mcp-protocol-version" for key in headers): headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) ssl_verify = config.get("ssl_verify", True) # OAuth 2.1 PKCE: route through the central MCPOAuthManager so the # same provider instance is reused across reconnects, pre-flow # disk-watch is active, and config-time CLI code paths share state. # If OAuth setup fails (e.g. non-interactive env without cached # tokens), re-raise so this server is reported as failed without # blocking other MCP servers from connecting. 
_oauth_auth = None if self._auth_type == "oauth": try: from tools.mcp_oauth_manager import get_manager _oauth_auth = get_manager().get_or_build_provider( self.name, url, config.get("oauth"), ) except Exception as exc: logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc) raise sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {} if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED: sampling_kwargs["message_handler"] = self._make_message_handler() if _MCP_NEW_HTTP: # New API (mcp >= 1.24.0): build an explicit httpx.AsyncClient # matching the SDK's own create_mcp_http_client defaults. import httpx client_kwargs: dict = { "follow_redirects": True, "timeout": httpx.Timeout(float(connect_timeout), read=300.0), "verify": ssl_verify, } if headers: client_kwargs["headers"] = headers if _oauth_auth is not None: client_kwargs["auth"] = _oauth_auth # Caller owns the client lifecycle — the SDK skips cleanup when # http_client is provided, so we wrap in async-with. async with httpx.AsyncClient(**client_kwargs) as http_client: async with streamable_http_client(url, http_client=http_client) as ( read_stream, write_stream, _get_session_id, ): async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session: await session.initialize() self.session = session await self._discover_tools() self._ready.set() reason = await self._wait_for_lifecycle_event() if reason == "reconnect": logger.info( "MCP server '%s': reconnect requested — " "tearing down HTTP session", self.name, ) else: # Deprecated API (mcp < 1.24.0): manages httpx client internally. 
_http_kwargs: dict = { "headers": headers, "timeout": float(connect_timeout), "verify": ssl_verify, } if _oauth_auth is not None: _http_kwargs["auth"] = _oauth_auth async with streamablehttp_client(url, **_http_kwargs) as ( read_stream, write_stream, _get_session_id, ): async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session: await session.initialize() self.session = session await self._discover_tools() self._ready.set() reason = await self._wait_for_lifecycle_event() if reason == "reconnect": logger.info( "MCP server '%s': reconnect requested — " "tearing down legacy HTTP session", self.name, ) async def _discover_tools(self): """Discover tools from the connected session.""" if self.session is None: return tools_result = await self.session.list_tools() self._tools = ( tools_result.tools if hasattr(tools_result, "tools") else [] ) async def run(self, config: dict): """Long-lived coroutine: connect, discover tools, wait, disconnect. Includes automatic reconnection with exponential backoff if the connection drops unexpectedly (unless shutdown was requested). """ self._config = config self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT) self._auth_type = (config.get("auth") or "").lower().strip() # Set up sampling handler if enabled and SDK types are available sampling_config = config.get("sampling", {}) if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES: self._sampling = SamplingHandler(self.name, sampling_config) else: self._sampling = None # Validate: warn if both url and command are present if "url" in config and "command" in config: logger.warning( "MCP server '%s' has both 'url' and 'command' in config. " "Using HTTP transport ('url'). Remove 'command' to silence " "this warning.", self.name, ) retries = 0 initial_retries = 0 backoff = 1.0 while True: try: if self._is_http(): await self._run_http(config) else: await self._run_stdio(config) # Transport returned cleanly. 
Two cases: # - _shutdown_event was set: exit the run loop entirely. # - _reconnect_event was set (auth recovery): loop back and # rebuild the MCP session with fresh credentials. Do NOT # touch the retry counters — this is not a failure. if self._shutdown_event.is_set(): break logger.info( "MCP server '%s': reconnecting (OAuth recovery or " "manual refresh)", self.name, ) # Reset the session reference; _run_http/_run_stdio will # repopulate it on successful re-entry. self.session = None # Keep _ready set across reconnects so tool handlers can # still detect a transient in-flight state — it'll be # re-set after the fresh session initializes. continue except Exception as exc: self.session = None # If this is the first connection attempt, retry with backoff # before giving up. A transient DNS/network blip at startup # should not permanently kill the server. # (Ported from Kilo Code's MCP resilience fix.) if not self._ready.is_set(): initial_retries += 1 if initial_retries > _MAX_INITIAL_CONNECT_RETRIES: logger.warning( "MCP server '%s' failed initial connection after " "%d attempts, giving up: %s", self.name, _MAX_INITIAL_CONNECT_RETRIES, exc, ) self._error = exc self._ready.set() return logger.warning( "MCP server '%s' initial connection failed " "(attempt %d/%d), retrying in %.0fs: %s", self.name, initial_retries, _MAX_INITIAL_CONNECT_RETRIES, backoff, exc, ) await asyncio.sleep(backoff) backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS) # Check if shutdown was requested during the sleep if self._shutdown_event.is_set(): self._error = exc self._ready.set() return continue # If shutdown was requested, don't reconnect if self._shutdown_event.is_set(): logger.debug( "MCP server '%s' disconnected during shutdown: %s", self.name, exc, ) return retries += 1 if retries > _MAX_RECONNECT_RETRIES: logger.warning( "MCP server '%s' failed after %d reconnection attempts, " "giving up: %s", self.name, _MAX_RECONNECT_RETRIES, exc, ) return logger.warning( "MCP server '%s' connection 
lost (attempt %d/%d), " "reconnecting in %.0fs: %s", self.name, retries, _MAX_RECONNECT_RETRIES, backoff, exc, ) await asyncio.sleep(backoff) backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS) # Check again after sleeping if self._shutdown_event.is_set(): return finally: self.session = None async def start(self, config: dict): """Create the background Task and wait until ready (or failed).""" self._task = asyncio.ensure_future(self.run(config)) await self._ready.wait() if self._error: raise self._error async def shutdown(self): """Signal the Task to exit and wait for clean resource teardown.""" from tools.registry import registry self._shutdown_event.set() # Defensive: if _wait_for_lifecycle_event is blocking, we need ANY # event to unblock it. _shutdown_event alone is sufficient (the # helper checks shutdown first), but setting reconnect too ensures # there's no race where the helper misses the shutdown flag after # returning "reconnect". self._reconnect_event.set() if self._task and not self._task.done(): try: await asyncio.wait_for(self._task, timeout=10) except asyncio.TimeoutError: logger.warning( "MCP server '%s' shutdown timed out, cancelling task", self.name, ) self._task.cancel() try: await self._task except asyncio.CancelledError: pass for tool_name in list(getattr(self, "_registered_tool_names", [])): registry.deregister(tool_name) self._registered_tool_names = [] self.session = None # --------------------------------------------------------------------------- # Module-level state # --------------------------------------------------------------------------- _servers: Dict[str, MCPServerTask] = {} # Circuit breaker: consecutive error counts per server. After # _CIRCUIT_BREAKER_THRESHOLD consecutive failures, the handler returns # a "server unreachable" message that tells the model to stop retrying, # preventing the 90-iteration burn loop described in #10447. # # State machine: # closed — error count below threshold; all calls go through. 
#   open      — threshold reached; calls short-circuit until the
#               cooldown elapses.
#   half-open — cooldown elapsed; the next call is a probe that
#               actually hits the session. Probe success → closed.
#               Probe failure → reopens (cooldown re-armed).
#
# ``_server_breaker_opened_at`` records the monotonic timestamp when
# the breaker most recently transitioned into the open state. Use the
# ``_bump_server_error`` / ``_reset_server_error`` helpers to mutate
# this state — they keep the count and timestamp in sync.
_server_error_counts: Dict[str, int] = {}
_server_breaker_opened_at: Dict[str, float] = {}
_CIRCUIT_BREAKER_THRESHOLD = 3
_CIRCUIT_BREAKER_COOLDOWN_SEC = 60.0


def _bump_server_error(server_name: str) -> None:
    """Increment the consecutive-failure count for ``server_name``.

    When the count reaches :data:`_CIRCUIT_BREAKER_THRESHOLD` (or is already
    above it), stamp the breaker-open timestamp so the cooldown clock starts
    (or re-starts, for probe failures in the half-open state).
    """
    n = _server_error_counts.get(server_name, 0) + 1
    _server_error_counts[server_name] = n
    if n >= _CIRCUIT_BREAKER_THRESHOLD:
        _server_breaker_opened_at[server_name] = time.monotonic()


def _reset_server_error(server_name: str) -> None:
    """Fully close the breaker for ``server_name``.

    Clears both the failure count and the breaker-open timestamp. Call
    this on any unambiguous success signal (successful tool call,
    successful reconnect, manual /mcp refresh).
    """
    _server_error_counts[server_name] = 0
    _server_breaker_opened_at.pop(server_name, None)


# ---------------------------------------------------------------------------
# Auth-failure detection helpers (Task 6 of MCP OAuth consolidation)
# ---------------------------------------------------------------------------

# Cached tuple of auth-related exception types. Lazy so this module
# imports cleanly when the MCP SDK OAuth module is missing.
_AUTH_ERROR_TYPES: tuple = ()


def _get_auth_error_types() -> tuple:
    """Return a tuple of exception types that indicate MCP OAuth failure.

    Cached after the first call that finds at least one type (an empty
    result is recomputed on the next call). Includes:

    - ``mcp.client.auth.OAuthFlowError`` / ``OAuthTokenError`` — raised by
      the SDK's auth flow when discovery, refresh, or full re-auth fails.
    - ``mcp.client.auth.UnauthorizedError`` (older MCP SDKs) — kept as an
      optional import for forward/backward compatibility.
    - ``tools.mcp_oauth.OAuthNonInteractiveError`` — raised by our callback
      handler when no user is present to complete a browser flow.
    - ``httpx.HTTPStatusError`` — caller must additionally check
      ``status_code == 401`` via :func:`_is_auth_error`.
    """
    global _AUTH_ERROR_TYPES
    if _AUTH_ERROR_TYPES:
        return _AUTH_ERROR_TYPES
    types: list = []
    try:
        from mcp.client.auth import OAuthFlowError, OAuthTokenError
        types.extend([OAuthFlowError, OAuthTokenError])
    except ImportError:
        pass
    try:
        # Older MCP SDK variants exported this
        from mcp.client.auth import UnauthorizedError  # type: ignore
        types.append(UnauthorizedError)
    except ImportError:
        pass
    try:
        from tools.mcp_oauth import OAuthNonInteractiveError
        types.append(OAuthNonInteractiveError)
    except ImportError:
        pass
    try:
        import httpx
        types.append(httpx.HTTPStatusError)
    except ImportError:
        pass
    _AUTH_ERROR_TYPES = tuple(types)
    return _AUTH_ERROR_TYPES


def _is_auth_error(exc: BaseException) -> bool:
    """Return True if ``exc`` indicates an MCP OAuth failure.

    ``httpx.HTTPStatusError`` is only treated as auth-related when
    the response status code is 401. Other HTTP errors fall through
    to the generic error path in the tool handlers.
    """
    types = _get_auth_error_types()
    if not types or not isinstance(exc, types):
        return False
    try:
        import httpx
        if isinstance(exc, httpx.HTTPStatusError):
            return getattr(exc.response, "status_code", None) == 401
    except ImportError:
        pass
    return True


def _handle_auth_error_and_retry(
    server_name: str,
    exc: BaseException,
    retry_call,
    op_description: str,
):
    """Attempt auth recovery and one retry; return None to fall through.

    Called by the 5 MCP tool handlers when ``session.<method>()`` raises an
    auth-related exception. Workflow:

    1. Ask :class:`tools.mcp_oauth_manager.MCPOAuthManager.handle_401` if
       recovery is viable (i.e., disk has fresh tokens, or the SDK can
       refresh in-place).
    2. If yes, set the server's ``_reconnect_event`` so the server task
       tears down the current MCP session and rebuilds it with fresh
       credentials. Wait (up to 15 s) until a *different* session is
       attached and ``_ready`` is set. ``_ready`` stays set across
       reconnects and the old session remains attached until teardown,
       so checking only "session present and ready" would succeed at once
       and retry on the stale session.
    3. Retry the operation once. Return the retry result if it produced a
       non-error JSON payload. Otherwise return the ``needs_reauth`` error
       dict so the model stops hallucinating manual refresh.
    4. Return None if ``exc`` is not an auth error, signalling the caller
       to use the generic error path.

    Args:
        server_name: Name of the MCP server that raised.
        exc: The exception from the failed tool call.
        retry_call: Zero-arg callable that re-runs the tool call, returning
            the same JSON string format as the handler.
        op_description: Human-readable name of the operation (for logs).

    Returns:
        A JSON string if auth recovery was attempted, or None to fall
        through to the caller's generic error path.
    """
    if not _is_auth_error(exc):
        return None

    from tools.mcp_oauth_manager import get_manager
    manager = get_manager()

    async def _recover():
        return await manager.handle_401(server_name, None)

    try:
        recovered = _run_on_mcp_loop(_recover(), timeout=10)
    except Exception as rec_exc:
        logger.warning(
            "MCP OAuth '%s': recovery attempt failed: %s",
            server_name, rec_exc,
        )
        recovered = False

    if recovered:
        # _lock protects both _servers and _mcp_loop.
        with _lock:
            srv = _servers.get(server_name)
            loop = _mcp_loop
        if srv is not None and hasattr(srv, "_reconnect_event"):
            if loop is not None and loop.is_running():
                # Remember the session in use *before* signalling so the wait
                # below only accepts the rebuilt one.
                old_session = srv.session
                loop.call_soon_threadsafe(srv._reconnect_event.set)
                # Bounded so that a stuck reconnect falls through to the
                # error path rather than hanging the caller.
                deadline = time.monotonic() + 15
                while time.monotonic() < deadline:
                    current = srv.session
                    if (
                        current is not None
                        and current is not old_session
                        and srv._ready.is_set()
                    ):
                        break
                    time.sleep(0.25)

        # A successful OAuth recovery is independent evidence that the
        # server is viable again, so close the circuit breaker here —
        # not only on retry success. Without this, a reconnect
        # followed by a failing retry would leave the breaker pinned
        # above threshold forever (the retry-exception branch below
        # bumps the count again). The post-reset retry still goes
        # through _bump_server_error on failure, so a genuinely broken
        # server will re-trip the breaker as normal.
        _reset_server_error(server_name)

        try:
            result = retry_call()
            try:
                parsed = json.loads(result)
                if "error" not in parsed:
                    _reset_server_error(server_name)
                    return result
            except (json.JSONDecodeError, TypeError):
                # Non-JSON output counts as success, matching the handlers.
                _reset_server_error(server_name)
                return result
        except Exception as retry_exc:
            logger.warning(
                "MCP %s/%s retry after auth recovery failed: %s",
                server_name, op_description, retry_exc,
            )

    # No recovery available, or retry also failed: surface a structured
    # needs_reauth error. Bumps the circuit breaker so the model stops
    # retrying the tool.
    _bump_server_error(server_name)
    return json.dumps({
        "error": (
            f"MCP server '{server_name}' requires re-authentication. "
            f"Run `hermes mcp login {server_name}` (or delete the tokens "
            f"file under ~/.hermes/mcp-tokens/ and restart). Do NOT retry "
            f"this tool — ask the user to re-authenticate."
        ),
        "needs_reauth": True,
        "server": server_name,
    }, ensure_ascii=False)


# Dedicated event loop running in a background daemon thread.
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
_mcp_thread: Optional[threading.Thread] = None

# Protects _mcp_loop, _mcp_thread, _servers, and _stdio_pids.
_lock = threading.Lock()

# PIDs of stdio MCP server subprocesses. Tracked so we can force-kill
# them on shutdown if the graceful cleanup (SDK context-manager teardown)
# fails or times out.
PIDs are added after connection and removed on # normal server shutdown. _stdio_pids: Dict[int, str] = {} # pid -> server_name def _snapshot_child_pids() -> set: """Return a set of current child process PIDs. Uses /proc on Linux, falls back to psutil, then empty set. Used by _run_stdio to identify the subprocess spawned by stdio_client. """ my_pid = os.getpid() # Linux: read from /proc try: children_path = f"/proc/{my_pid}/task/{my_pid}/children" with open(children_path) as f: return {int(p) for p in f.read().split() if p.strip()} except (FileNotFoundError, OSError, ValueError): pass # Fallback: psutil try: import psutil return {c.pid for c in psutil.Process(my_pid).children()} except Exception: pass return set() def _mcp_loop_exception_handler(loop, context): """Suppress benign 'Event loop is closed' noise during shutdown. When the MCP event loop is stopped and closed, httpx/httpcore async transports may fire __del__ finalizers that call call_soon() on the dead loop. asyncio catches that RuntimeError and routes it here. We silence it because the connection is being torn down anyway; all other exceptions are forwarded to the default handler. """ exc = context.get("exception") if isinstance(exc, RuntimeError) and "Event loop is closed" in str(exc): return # benign shutdown race — suppress loop.default_exception_handler(context) def _ensure_mcp_loop(): """Start the background event loop thread if not already running.""" global _mcp_loop, _mcp_thread with _lock: if _mcp_loop is not None and _mcp_loop.is_running(): return _mcp_loop = asyncio.new_event_loop() _mcp_loop.set_exception_handler(_mcp_loop_exception_handler) _mcp_thread = threading.Thread( target=_mcp_loop.run_forever, name="mcp-event-loop", daemon=True, ) _mcp_thread.start() def _run_on_mcp_loop(coro, timeout: float = 30): """Schedule a coroutine on the MCP event loop and block until done. 
Poll in short intervals so the calling agent thread can honor user interrupts while the MCP work is still running on the background loop. """ from tools.interrupt import is_interrupted with _lock: loop = _mcp_loop if loop is None or not loop.is_running(): raise RuntimeError("MCP event loop is not running") future = asyncio.run_coroutine_threadsafe(coro, loop) deadline = None if timeout is None else time.monotonic() + timeout while True: if is_interrupted(): future.cancel() raise InterruptedError("User sent a new message") wait_timeout = 0.1 if deadline is not None: remaining = deadline - time.monotonic() if remaining <= 0: return future.result(timeout=0) wait_timeout = min(wait_timeout, remaining) try: return future.result(timeout=wait_timeout) except concurrent.futures.TimeoutError: continue def _interrupted_call_result() -> str: """Standardized JSON error for a user-interrupted MCP tool call.""" return json.dumps({ "error": "MCP call interrupted: user sent a new message" }, ensure_ascii=False) # --------------------------------------------------------------------------- # Config loading # --------------------------------------------------------------------------- def _interpolate_env_vars(value): """Recursively resolve ``${VAR}`` placeholders from ``os.environ``.""" if isinstance(value, str): def _replace(m): return os.environ.get(m.group(1), m.group(0)) return re.sub(r"\$\{([^}]+)\}", _replace, value) if isinstance(value, dict): return {k: _interpolate_env_vars(v) for k, v in value.items()} if isinstance(value, list): return [_interpolate_env_vars(v) for v in value] return value def _load_mcp_config() -> Dict[str, dict]: """Read ``mcp_servers`` from the Hermes config file. Returns a dict of ``{server_name: server_config}`` or empty dict. Server config can contain either ``command``/``args``/``env`` for stdio transport or ``url``/``headers`` for HTTP transport, plus optional ``timeout``, ``connect_timeout``, and ``auth`` overrides. 
``${ENV_VAR}`` placeholders in string values are resolved from ``os.environ`` (which includes ``~/.hermes/.env`` loaded at startup). """ try: from hermes_cli.config import load_config config = load_config() servers = config.get("mcp_servers") if not servers or not isinstance(servers, dict): return {} # Ensure .env vars are available for interpolation try: from hermes_cli.env_loader import load_hermes_dotenv load_hermes_dotenv() except Exception: pass return {name: _interpolate_env_vars(cfg) for name, cfg in servers.items()} except Exception as exc: logger.debug("Failed to load MCP config: %s", exc) return {} # --------------------------------------------------------------------------- # Server connection helper # --------------------------------------------------------------------------- async def _connect_server(name: str, config: dict) -> MCPServerTask: """Create an MCPServerTask, start it, and return when ready. The server Task keeps the connection alive in the background. Call ``server.shutdown()`` (on the same event loop) to tear it down. Raises: ValueError: if required config keys are missing. ImportError: if HTTP transport is needed but not available. Exception: on connection or initialization failure. """ server = MCPServerTask(name) await server.start(config) return server # --------------------------------------------------------------------------- # Handler / check-fn factories # --------------------------------------------------------------------------- def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float): """Return a sync handler that calls an MCP tool via the background loop. The handler conforms to the registry's dispatch interface: ``handler(args_dict, **kwargs) -> str`` """ def _handler(args: dict, **kwargs) -> str: # Circuit breaker: if this server has failed too many times # consecutively, short-circuit with a clear message so the model # stops retrying and uses alternative approaches (#10447). 
# # Once the cooldown elapses, the breaker transitions to # half-open: we let the *next* call through as a probe. On # success the success-path below resets the breaker; on # failure the error paths below bump the count again, which # re-stamps the open-time via _bump_server_error (re-arming # the cooldown). if _server_error_counts.get(server_name, 0) >= _CIRCUIT_BREAKER_THRESHOLD: opened_at = _server_breaker_opened_at.get(server_name, 0.0) age = time.monotonic() - opened_at if age < _CIRCUIT_BREAKER_COOLDOWN_SEC: remaining = max(1, int(_CIRCUIT_BREAKER_COOLDOWN_SEC - age)) return json.dumps({ "error": ( f"MCP server '{server_name}' is unreachable after " f"{_server_error_counts[server_name]} consecutive " f"failures. Auto-retry available in ~{remaining}s. " f"Do NOT retry this tool yet — use alternative " f"approaches or ask the user to check the MCP server." ) }, ensure_ascii=False) # Cooldown elapsed → fall through as a half-open probe. with _lock: server = _servers.get(server_name) if not server or not server.session: _bump_server_error(server_name) return json.dumps({ "error": f"MCP server '{server_name}' is not connected" }, ensure_ascii=False) async def _call(): result = await server.session.call_tool(tool_name, arguments=args) # MCP CallToolResult has .content (list of content blocks) and .isError if result.isError: error_text = "" for block in (result.content or []): if hasattr(block, "text"): error_text += block.text return json.dumps({ "error": _sanitize_error( error_text or "MCP tool returned an error" ) }, ensure_ascii=False) # Collect text from content blocks parts: List[str] = [] for block in (result.content or []): if hasattr(block, "text"): parts.append(block.text) text_result = "\n".join(parts) if parts else "" # Combine content + structuredContent when both are present. # MCP spec: content is model-oriented (text), structuredContent # is machine-oriented (JSON metadata). 
For an AI agent, content # is the primary payload; structuredContent supplements it. structured = getattr(result, "structuredContent", None) if structured is not None: if text_result: return json.dumps({ "result": text_result, "structuredContent": structured, }, ensure_ascii=False) return json.dumps({"result": structured}, ensure_ascii=False) return json.dumps({"result": text_result}, ensure_ascii=False) def _call_once(): return _run_on_mcp_loop(_call(), timeout=tool_timeout) try: result = _call_once() # Check if the MCP tool itself returned an error try: parsed = json.loads(result) if "error" in parsed: _bump_server_error(server_name) else: _reset_server_error(server_name) # success — reset except (json.JSONDecodeError, TypeError): _reset_server_error(server_name) # non-JSON = success return result except InterruptedError: return _interrupted_call_result() except Exception as exc: # Auth-specific recovery path: consult the manager, signal # reconnect if viable, retry once. Returns None to fall # through for non-auth exceptions. 
recovered = _handle_auth_error_and_retry( server_name, exc, _call_once, f"tools/call {tool_name}", ) if recovered is not None: return recovered _bump_server_error(server_name) logger.error( "MCP tool %s/%s call failed: %s", server_name, tool_name, exc, ) return json.dumps({ "error": _sanitize_error( f"MCP call failed: {type(exc).__name__}: {exc}" ) }, ensure_ascii=False) return _handler def _make_list_resources_handler(server_name: str, tool_timeout: float): """Return a sync handler that lists resources from an MCP server.""" def _handler(args: dict, **kwargs) -> str: with _lock: server = _servers.get(server_name) if not server or not server.session: return json.dumps({ "error": f"MCP server '{server_name}' is not connected" }, ensure_ascii=False) async def _call(): result = await server.session.list_resources() resources = [] for r in (result.resources if hasattr(result, "resources") else []): entry = {} if hasattr(r, "uri"): entry["uri"] = str(r.uri) if hasattr(r, "name"): entry["name"] = r.name if hasattr(r, "description") and r.description: entry["description"] = r.description if hasattr(r, "mimeType") and r.mimeType: entry["mimeType"] = r.mimeType resources.append(entry) return json.dumps({"resources": resources}, ensure_ascii=False) def _call_once(): return _run_on_mcp_loop(_call(), timeout=tool_timeout) try: return _call_once() except InterruptedError: return _interrupted_call_result() except Exception as exc: recovered = _handle_auth_error_and_retry( server_name, exc, _call_once, "resources/list", ) if recovered is not None: return recovered logger.error( "MCP %s/list_resources failed: %s", server_name, exc, ) return json.dumps({ "error": _sanitize_error( f"MCP call failed: {type(exc).__name__}: {exc}" ) }, ensure_ascii=False) return _handler def _make_read_resource_handler(server_name: str, tool_timeout: float): """Return a sync handler that reads a resource by URI from an MCP server.""" def _handler(args: dict, **kwargs) -> str: from tools.registry import 
tool_error with _lock: server = _servers.get(server_name) if not server or not server.session: return json.dumps({ "error": f"MCP server '{server_name}' is not connected" }, ensure_ascii=False) uri = args.get("uri") if not uri: return tool_error("Missing required parameter 'uri'") async def _call(): result = await server.session.read_resource(uri) # read_resource returns ReadResourceResult with .contents list parts: List[str] = [] contents = result.contents if hasattr(result, "contents") else [] for block in contents: if hasattr(block, "text"): parts.append(block.text) elif hasattr(block, "blob"): parts.append(f"[binary data, {len(block.blob)} bytes]") return json.dumps({"result": "\n".join(parts) if parts else ""}, ensure_ascii=False) def _call_once(): return _run_on_mcp_loop(_call(), timeout=tool_timeout) try: return _call_once() except InterruptedError: return _interrupted_call_result() except Exception as exc: recovered = _handle_auth_error_and_retry( server_name, exc, _call_once, "resources/read", ) if recovered is not None: return recovered logger.error( "MCP %s/read_resource failed: %s", server_name, exc, ) return json.dumps({ "error": _sanitize_error( f"MCP call failed: {type(exc).__name__}: {exc}" ) }, ensure_ascii=False) return _handler def _make_list_prompts_handler(server_name: str, tool_timeout: float): """Return a sync handler that lists prompts from an MCP server.""" def _handler(args: dict, **kwargs) -> str: with _lock: server = _servers.get(server_name) if not server or not server.session: return json.dumps({ "error": f"MCP server '{server_name}' is not connected" }, ensure_ascii=False) async def _call(): result = await server.session.list_prompts() prompts = [] for p in (result.prompts if hasattr(result, "prompts") else []): entry = {} if hasattr(p, "name"): entry["name"] = p.name if hasattr(p, "description") and p.description: entry["description"] = p.description if hasattr(p, "arguments") and p.arguments: entry["arguments"] = [ { "name": a.name, 
**({"description": a.description} if hasattr(a, "description") and a.description else {}), **({"required": a.required} if hasattr(a, "required") else {}), } for a in p.arguments ] prompts.append(entry) return json.dumps({"prompts": prompts}, ensure_ascii=False) def _call_once(): return _run_on_mcp_loop(_call(), timeout=tool_timeout) try: return _call_once() except InterruptedError: return _interrupted_call_result() except Exception as exc: recovered = _handle_auth_error_and_retry( server_name, exc, _call_once, "prompts/list", ) if recovered is not None: return recovered logger.error( "MCP %s/list_prompts failed: %s", server_name, exc, ) return json.dumps({ "error": _sanitize_error( f"MCP call failed: {type(exc).__name__}: {exc}" ) }, ensure_ascii=False) return _handler def _make_get_prompt_handler(server_name: str, tool_timeout: float): """Return a sync handler that gets a prompt by name from an MCP server.""" def _handler(args: dict, **kwargs) -> str: from tools.registry import tool_error with _lock: server = _servers.get(server_name) if not server or not server.session: return json.dumps({ "error": f"MCP server '{server_name}' is not connected" }, ensure_ascii=False) name = args.get("name") if not name: return tool_error("Missing required parameter 'name'") arguments = args.get("arguments", {}) async def _call(): result = await server.session.get_prompt(name, arguments=arguments) # GetPromptResult has .messages list messages = [] for msg in (result.messages if hasattr(result, "messages") else []): entry = {} if hasattr(msg, "role"): entry["role"] = msg.role if hasattr(msg, "content"): content = msg.content if hasattr(content, "text"): entry["content"] = content.text elif isinstance(content, str): entry["content"] = content else: entry["content"] = str(content) messages.append(entry) resp = {"messages": messages} if hasattr(result, "description") and result.description: resp["description"] = result.description return json.dumps(resp, ensure_ascii=False) def 
_call_once(): return _run_on_mcp_loop(_call(), timeout=tool_timeout) try: return _call_once() except InterruptedError: return _interrupted_call_result() except Exception as exc: recovered = _handle_auth_error_and_retry( server_name, exc, _call_once, "prompts/get", ) if recovered is not None: return recovered logger.error( "MCP %s/get_prompt failed: %s", server_name, exc, ) return json.dumps({ "error": _sanitize_error( f"MCP call failed: {type(exc).__name__}: {exc}" ) }, ensure_ascii=False) return _handler def _make_check_fn(server_name: str): """Return a check function that verifies the MCP connection is alive.""" def _check() -> bool: with _lock: server = _servers.get(server_name) return server is not None and server.session is not None return _check # --------------------------------------------------------------------------- # Discovery & registration # --------------------------------------------------------------------------- def _normalize_mcp_input_schema(schema: dict | None) -> dict: """Normalize MCP input schemas for LLM tool-calling compatibility. MCP servers can emit plain JSON Schema with ``definitions`` / ``#/definitions/...`` references. Kimi / Moonshot rejects that form and requires local refs to point into ``#/$defs/...`` instead. Normalize the common draft-07 shape here so MCP tool schemas remain portable across OpenAI-compatible providers. Additional MCP-server robustness repairs applied recursively: * Missing or ``null`` ``type`` on an object-shaped node is coerced to ``"object"`` (some servers omit it). See PR #4897. * When an ``object`` node lacks ``properties``, an empty ``properties`` dict is added so ``required`` entries don't dangle. * ``required`` arrays are pruned to only names that exist in ``properties``; otherwise Google AI Studio / Gemini 400s with ``property is not defined``. See PR #4651. All repairs are provider-agnostic and ideally produce a schema valid on OpenAI, Anthropic, Gemini, and Moonshot in one pass. 
""" if not schema: return {"type": "object", "properties": {}} def _rewrite_local_refs(node): if isinstance(node, dict): normalized = {} for key, value in node.items(): out_key = "$defs" if key == "definitions" else key normalized[out_key] = _rewrite_local_refs(value) ref = normalized.get("$ref") if isinstance(ref, str) and ref.startswith("#/definitions/"): normalized["$ref"] = "#/$defs/" + ref[len("#/definitions/"):] return normalized if isinstance(node, list): return [_rewrite_local_refs(item) for item in node] return node def _repair_object_shape(node): """Recursively repair object-shaped nodes: fill type, prune required.""" if isinstance(node, list): return [_repair_object_shape(item) for item in node] if not isinstance(node, dict): return node repaired = {k: _repair_object_shape(v) for k, v in node.items()} # Coerce missing / null type when the shape is clearly an object # (has properties or required but no type). if not repaired.get("type") and ( "properties" in repaired or "required" in repaired ): repaired["type"] = "object" if repaired.get("type") == "object": # Ensure properties exists so required can reference it safely if "properties" not in repaired or not isinstance( repaired.get("properties"), dict ): repaired["properties"] = {} if "properties" not in repaired else repaired["properties"] if not isinstance(repaired.get("properties"), dict): repaired["properties"] = {} # Prune required to only include names that exist in properties required = repaired.get("required") if isinstance(required, list): props = repaired.get("properties") or {} valid = [r for r in required if isinstance(r, str) and r in props] if len(valid) != len(required): if valid: repaired["required"] = valid else: repaired.pop("required", None) return repaired normalized = _rewrite_local_refs(schema) normalized = _repair_object_shape(normalized) # Ensure top-level is a well-formed object schema if not isinstance(normalized, dict): return {"type": "object", "properties": {}} if 
normalized.get("type") == "object" and "properties" not in normalized: normalized = {**normalized, "properties": {}} return normalized def sanitize_mcp_name_component(value: str) -> str: """Return an MCP name component safe for tool and prefix generation. Preserves Hermes's historical behavior of converting hyphens to underscores, and also replaces any other character outside ``[A-Za-z0-9_]`` with ``_`` so generated tool names are compatible with provider validation rules. """ return re.sub(r"[^A-Za-z0-9_]", "_", str(value or "")) def _convert_mcp_schema(server_name: str, mcp_tool) -> dict: """Convert an MCP tool listing to the Hermes registry schema format. Args: server_name: The logical server name for prefixing. mcp_tool: An MCP ``Tool`` object with ``.name``, ``.description``, and ``.inputSchema``. Returns: A dict suitable for ``registry.register(schema=...)``. """ safe_tool_name = sanitize_mcp_name_component(mcp_tool.name) safe_server_name = sanitize_mcp_name_component(server_name) prefixed_name = f"mcp_{safe_server_name}_{safe_tool_name}" return { "name": prefixed_name, "description": mcp_tool.description or f"MCP tool {mcp_tool.name} from {server_name}", "parameters": _normalize_mcp_input_schema(getattr(mcp_tool, "inputSchema", None)), } def _build_utility_schemas(server_name: str) -> List[dict]: """Build schemas for the MCP utility tools (resources & prompts). Returns a list of (schema, handler_factory_name) tuples encoded as dicts with keys: schema, handler_key. 
""" safe_name = sanitize_mcp_name_component(server_name) return [ { "schema": { "name": f"mcp_{safe_name}_list_resources", "description": f"List available resources from MCP server '{server_name}'", "parameters": { "type": "object", "properties": {}, }, }, "handler_key": "list_resources", }, { "schema": { "name": f"mcp_{safe_name}_read_resource", "description": f"Read a resource by URI from MCP server '{server_name}'", "parameters": { "type": "object", "properties": { "uri": { "type": "string", "description": "URI of the resource to read", }, }, "required": ["uri"], }, }, "handler_key": "read_resource", }, { "schema": { "name": f"mcp_{safe_name}_list_prompts", "description": f"List available prompts from MCP server '{server_name}'", "parameters": { "type": "object", "properties": {}, }, }, "handler_key": "list_prompts", }, { "schema": { "name": f"mcp_{safe_name}_get_prompt", "description": f"Get a prompt by name from MCP server '{server_name}'", "parameters": { "type": "object", "properties": { "name": { "type": "string", "description": "Name of the prompt to retrieve", }, "arguments": { "type": "object", "description": "Optional arguments to pass to the prompt", "properties": {}, "additionalProperties": True, }, }, "required": ["name"], }, }, "handler_key": "get_prompt", }, ] def _normalize_name_filter(value: Any, label: str) -> set[str]: """Normalize include/exclude config to a set of tool names.""" if value is None: return set() if isinstance(value, str): return {value} if isinstance(value, (list, tuple, set)): return {str(item) for item in value} logger.warning("MCP config %s must be a string or list of strings; ignoring %r", label, value) return set() def _parse_boolish(value: Any, default: bool = True) -> bool: """Parse a bool-like config value with safe fallback.""" if value is None: return default if isinstance(value, bool): return value if isinstance(value, str): lowered = value.strip().lower() if lowered in {"true", "1", "yes", "on"}: return True if 
lowered in {"false", "0", "no", "off"}: return False logger.warning("MCP config expected a boolean-ish value, got %r; using default=%s", value, default) return default _UTILITY_CAPABILITY_METHODS = { "list_resources": "list_resources", "read_resource": "read_resource", "list_prompts": "list_prompts", "get_prompt": "get_prompt", } def _select_utility_schemas(server_name: str, server: MCPServerTask, config: dict) -> List[dict]: """Select utility schemas based on config and server capabilities.""" tools_filter = config.get("tools") or {} resources_enabled = _parse_boolish(tools_filter.get("resources"), default=True) prompts_enabled = _parse_boolish(tools_filter.get("prompts"), default=True) selected: List[dict] = [] for entry in _build_utility_schemas(server_name): handler_key = entry["handler_key"] if handler_key in {"list_resources", "read_resource"} and not resources_enabled: logger.debug("MCP server '%s': skipping utility '%s' (resources disabled)", server_name, handler_key) continue if handler_key in {"list_prompts", "get_prompt"} and not prompts_enabled: logger.debug("MCP server '%s': skipping utility '%s' (prompts disabled)", server_name, handler_key) continue required_method = _UTILITY_CAPABILITY_METHODS[handler_key] if not hasattr(server.session, required_method): logger.debug( "MCP server '%s': skipping utility '%s' (session lacks %s)", server_name, handler_key, required_method, ) continue selected.append(entry) return selected def _existing_tool_names() -> List[str]: """Return tool names for all currently connected servers.""" names: List[str] = [] for _sname, server in _servers.items(): if hasattr(server, "_registered_tool_names"): names.extend(server._registered_tool_names) continue for mcp_tool in server._tools: schema = _convert_mcp_schema(server.name, mcp_tool) names.append(schema["name"]) return names def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> List[str]: """Register tools from an already-connected server into the 
registry. Handles include/exclude filtering and utility tools. Toolset resolution for ``mcp-{server}`` and raw server-name aliases is derived from the live registry, rather than mutating ``toolsets.TOOLSETS`` at runtime. Used by both initial discovery and dynamic refresh (list_changed). Returns: List of registered prefixed tool names. """ from tools.registry import registry registered_names: List[str] = [] toolset_name = f"mcp-{name}" # Selective tool loading: honour include/exclude lists from config. # Rules (matching issue #690 spec): # tools.include — whitelist: only these tool names are registered # tools.exclude — blacklist: all tools EXCEPT these are registered # include takes precedence over exclude # Neither set → register all tools (backward-compatible default) tools_filter = config.get("tools") or {} include_set = _normalize_name_filter(tools_filter.get("include"), f"mcp_servers.{name}.tools.include") exclude_set = _normalize_name_filter(tools_filter.get("exclude"), f"mcp_servers.{name}.tools.exclude") def _should_register(tool_name: str) -> bool: if include_set: return tool_name in include_set if exclude_set: return tool_name not in exclude_set return True for mcp_tool in server._tools: if not _should_register(mcp_tool.name): logger.debug("MCP server '%s': skipping tool '%s' (filtered by config)", name, mcp_tool.name) continue # Scan tool description for prompt injection patterns _scan_mcp_description(name, mcp_tool.name, mcp_tool.description or "") schema = _convert_mcp_schema(name, mcp_tool) tool_name_prefixed = schema["name"] # Guard against collisions with built-in (non-MCP) tools. 
existing_toolset = registry.get_toolset_for_tool(tool_name_prefixed) if existing_toolset and not existing_toolset.startswith("mcp-"): logger.warning( "MCP server '%s': tool '%s' (→ '%s') collides with built-in " "tool in toolset '%s' — skipping to preserve built-in", name, mcp_tool.name, tool_name_prefixed, existing_toolset, ) continue registry.register( name=tool_name_prefixed, toolset=toolset_name, schema=schema, handler=_make_tool_handler(name, mcp_tool.name, server.tool_timeout), check_fn=_make_check_fn(name), is_async=False, description=schema["description"], ) registered_names.append(tool_name_prefixed) # Register MCP Resources & Prompts utility tools, filtered by config and # only when the server actually supports the corresponding capability. _handler_factories = { "list_resources": _make_list_resources_handler, "read_resource": _make_read_resource_handler, "list_prompts": _make_list_prompts_handler, "get_prompt": _make_get_prompt_handler, } check_fn = _make_check_fn(name) for entry in _select_utility_schemas(name, server, config): schema = entry["schema"] handler_key = entry["handler_key"] handler = _handler_factories[handler_key](name, server.tool_timeout) util_name = schema["name"] # Same collision guard for utility tools. existing_toolset = registry.get_toolset_for_tool(util_name) if existing_toolset and not existing_toolset.startswith("mcp-"): logger.warning( "MCP server '%s': utility tool '%s' collides with built-in " "tool in toolset '%s' — skipping to preserve built-in", name, util_name, existing_toolset, ) continue registry.register( name=util_name, toolset=toolset_name, schema=schema, handler=handler, check_fn=check_fn, is_async=False, description=schema["description"], ) registered_names.append(util_name) if registered_names: registry.register_toolset_alias(name, toolset_name) return registered_names async def _discover_and_register_server(name: str, config: dict) -> List[str]: """Connect to a single MCP server, discover tools, and register them. 
Returns list of registered tool names. """ connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) server = await asyncio.wait_for( _connect_server(name, config), timeout=connect_timeout, ) with _lock: _servers[name] = server registered_names = _register_server_tools(name, server, config) server._registered_tool_names = list(registered_names) transport_type = "HTTP" if "url" in config else "stdio" logger.info( "MCP server '%s' (%s): registered %d tool(s): %s", name, transport_type, len(registered_names), ", ".join(registered_names), ) return registered_names # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- def register_mcp_servers(servers: Dict[str, dict]) -> List[str]: """Connect to explicit MCP servers and register their tools. Idempotent for already-connected server names. Servers with ``enabled: false`` are skipped without disconnecting existing sessions. Args: servers: Mapping of ``{server_name: server_config}``. Returns: List of all currently registered MCP tool names. 
""" if not _MCP_AVAILABLE: logger.debug("MCP SDK not available -- skipping explicit MCP registration") return [] if not servers: logger.debug("No explicit MCP servers provided") return [] # Only attempt servers that aren't already connected and are enabled # (enabled: false skips the server entirely without removing its config) with _lock: new_servers = { k: v for k, v in servers.items() if k not in _servers and _parse_boolish(v.get("enabled", True), default=True) } if not new_servers: return _existing_tool_names() # Start the background event loop for MCP connections _ensure_mcp_loop() async def _discover_one(name: str, cfg: dict) -> List[str]: """Connect to a single server and return its registered tool names.""" return await _discover_and_register_server(name, cfg) async def _discover_all(): server_names = list(new_servers.keys()) # Connect to all servers in PARALLEL results = await asyncio.gather( *(_discover_one(name, cfg) for name, cfg in new_servers.items()), return_exceptions=True, ) for name, result in zip(server_names, results): if isinstance(result, Exception): command = new_servers.get(name, {}).get("command") logger.warning( "Failed to connect to MCP server '%s'%s: %s", name, f" (command={command})" if command else "", _format_connect_error(result), ) # Per-server timeouts are handled inside _discover_and_register_server. # The outer timeout is generous: 120s total for parallel discovery. _run_on_mcp_loop(_discover_all(), timeout=120) # Log a summary so ACP callers get visibility into what was registered. 
with _lock: connected = [n for n in new_servers if n in _servers] new_tool_count = sum( len(getattr(_servers[n], "_registered_tool_names", [])) for n in connected ) failed = len(new_servers) - len(connected) if new_tool_count or failed: summary = f"MCP: registered {new_tool_count} tool(s) from {len(connected)} server(s)" if failed: summary += f" ({failed} failed)" logger.info(summary) return _existing_tool_names() def discover_mcp_tools() -> List[str]: """Entry point: load config, connect to MCP servers, register tools. Called from ``model_tools`` after ``discover_builtin_tools()``. Safe to call even when the ``mcp`` package is not installed (returns empty list). Idempotent for already-connected servers. If some servers failed on a previous call, only the missing ones are retried. Returns: List of all registered MCP tool names. """ if not _MCP_AVAILABLE: logger.debug("MCP SDK not available -- skipping MCP tool discovery") return [] servers = _load_mcp_config() if not servers: logger.debug("No MCP servers configured") return [] with _lock: new_server_names = [ name for name, cfg in servers.items() if name not in _servers and _parse_boolish(cfg.get("enabled", True), default=True) ] tool_names = register_mcp_servers(servers) if not new_server_names: return tool_names with _lock: connected_server_names = [name for name in new_server_names if name in _servers] new_tool_count = sum( len(getattr(_servers[name], "_registered_tool_names", [])) for name in connected_server_names ) failed_count = len(new_server_names) - len(connected_server_names) if new_tool_count or failed_count: summary = f" MCP: {new_tool_count} tool(s) from {len(connected_server_names)} server(s)" if failed_count: summary += f" ({failed_count} failed)" logger.info(summary) return tool_names def get_mcp_status() -> List[dict]: """Return status of all configured MCP servers for banner display. Returns a list of dicts with keys: name, transport, tools, connected. 
Includes both successfully connected servers and configured-but-failed ones. """ result: List[dict] = [] # Get configured servers from config configured = _load_mcp_config() if not configured: return result with _lock: active_servers = dict(_servers) for name, cfg in configured.items(): transport = "http" if "url" in cfg else "stdio" server = active_servers.get(name) if server and server.session is not None: entry = { "name": name, "transport": transport, "tools": len(server._registered_tool_names) if hasattr(server, "_registered_tool_names") else len(server._tools), "connected": True, } if server._sampling: entry["sampling"] = dict(server._sampling.metrics) result.append(entry) else: result.append({ "name": name, "transport": transport, "tools": 0, "connected": False, }) return result def probe_mcp_server_tools() -> Dict[str, List[tuple]]: """Temporarily connect to configured MCP servers and list their tools. Designed for ``hermes tools`` interactive configuration — connects to each enabled server, grabs tool names and descriptions, then disconnects. Does NOT register tools in the Hermes registry. Returns: Dict mapping server name to list of (tool_name, description) tuples. Servers that fail to connect are omitted from the result. 
""" if not _MCP_AVAILABLE: return {} servers_config = _load_mcp_config() if not servers_config: return {} enabled = { k: v for k, v in servers_config.items() if _parse_boolish(v.get("enabled", True), default=True) } if not enabled: return {} _ensure_mcp_loop() result: Dict[str, List[tuple]] = {} probed_servers: List[MCPServerTask] = [] async def _probe_all(): names = list(enabled.keys()) coros = [] for name, cfg in enabled.items(): ct = cfg.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) coros.append(asyncio.wait_for(_connect_server(name, cfg), timeout=ct)) outcomes = await asyncio.gather(*coros, return_exceptions=True) for name, outcome in zip(names, outcomes): if isinstance(outcome, Exception): logger.debug("Probe: failed to connect to '%s': %s", name, outcome) continue probed_servers.append(outcome) tools = [] for t in outcome._tools: desc = getattr(t, "description", "") or "" tools.append((t.name, desc)) result[name] = tools # Shut down all probed connections await asyncio.gather( *(s.shutdown() for s in probed_servers), return_exceptions=True, ) try: _run_on_mcp_loop(_probe_all(), timeout=120) except Exception as exc: logger.debug("MCP probe failed: %s", exc) finally: _stop_mcp_loop() return result def shutdown_mcp_servers(): """Close all MCP server connections and stop the background loop. Each server Task is signalled to exit its ``async with`` block so that the anyio cancel-scope cleanup happens in the same Task that opened it. All servers are shut down in parallel via ``asyncio.gather``. """ with _lock: servers_snapshot = list(_servers.values()) # Fast path: nothing to shut down. 
if not servers_snapshot: _stop_mcp_loop() return async def _shutdown(): results = await asyncio.gather( *(server.shutdown() for server in servers_snapshot), return_exceptions=True, ) for server, result in zip(servers_snapshot, results): if isinstance(result, Exception): logger.debug( "Error closing MCP server '%s': %s", server.name, result, ) with _lock: _servers.clear() with _lock: loop = _mcp_loop if loop is not None and loop.is_running(): try: future = asyncio.run_coroutine_threadsafe(_shutdown(), loop) future.result(timeout=15) except Exception as exc: logger.debug("Error during MCP shutdown: %s", exc) _stop_mcp_loop() def _kill_orphaned_mcp_children() -> None: """Graceful shutdown of MCP stdio subprocesses that survived loop cleanup. Sends SIGTERM first, waits 2 seconds, then escalates to SIGKILL. This prevents shared-resource collisions when multiple hermes processes run on the same host (each has its own _stdio_pids dict). Only kills PIDs tracked in ``_stdio_pids`` — never arbitrary children. """ import signal as _signal import time as _time with _lock: pids = dict(_stdio_pids) _stdio_pids.clear() # Fast path: no tracked stdio PIDs to reap. Skip the SIGTERM/sleep/SIGKILL # dance entirely — otherwise every MCP-free shutdown pays a 2s sleep tax. 
if not pids: return # Phase 1: SIGTERM (graceful) for pid, server_name in pids.items(): try: os.kill(pid, _signal.SIGTERM) logger.debug("Sent SIGTERM to orphaned MCP process %d (%s)", pid, server_name) except (ProcessLookupError, PermissionError, OSError): pass # Phase 2: Wait for graceful exit _time.sleep(2) # Phase 3: SIGKILL any survivors _sigkill = getattr(_signal, "SIGKILL", _signal.SIGTERM) for pid, server_name in pids.items(): try: os.kill(pid, 0) # Check if still alive os.kill(pid, _sigkill) logger.warning( "Force-killed MCP process %d (%s) after SIGTERM timeout", pid, server_name, ) except (ProcessLookupError, PermissionError, OSError): pass # Good — exited after SIGTERM def _stop_mcp_loop(): """Stop the background event loop and join its thread.""" global _mcp_loop, _mcp_thread with _lock: loop = _mcp_loop thread = _mcp_thread _mcp_loop = None _mcp_thread = None if loop is not None: loop.call_soon_threadsafe(loop.stop) if thread is not None: thread.join(timeout=5) try: loop.close() except Exception: pass # After closing the loop, any stdio subprocesses that survived the # graceful shutdown are now orphaned. Force-kill them. _kill_orphaned_mcp_children()