mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Fixes a broader class of 'tools.function.parameters is not a valid moonshot flavored json schema' errors on Nous / OpenRouter aggregators routing to moonshotai/kimi-k2.6 with MCP tools loaded. ## Moonshot sanitizer (agent/moonshot_schema.py, new) Model-name-routed (not base-URL-routed) so Nous / OpenRouter users are covered alongside api.moonshot.ai. Applied in ChatCompletionsTransport.build_kwargs when is_moonshot_model(model). Two repairs: 1. Fill missing 'type' on every property / items / anyOf-child schema node (structural walk — only schema-position dicts are touched, not container maps like properties/$defs). 2. Strip 'type' at anyOf parents; Moonshot rejects it. ## MCP normalizer hardened (tools/mcp_tool.py) Draft-07 $ref rewrite from PR #14802 now also does: - coerce missing / null 'type' on object-shaped nodes (salvages #4897) - prune 'required' arrays to names that exist in 'properties' (salvages #4651; Gemini 400s on dangling required) - apply recursively, not just top-level These repairs are provider-agnostic so the same MCP schema is valid on OpenAI, Anthropic, Gemini, and Moonshot in one pass. ## Crash fix: safe getattr for Tool.inputSchema _convert_mcp_schema now uses getattr(t, 'inputSchema', None) so MCP servers whose Tool objects omit the attribute entirely no longer abort registration (salvages #3882). ## Validation - tests/agent/test_moonshot_schema.py: 27 new tests (model detection, missing-type fill, anyOf-parent strip, non-mutation, real-world MCP shape) - tests/tools/test_mcp_tool.py: 7 new tests (missing / null type, required pruning, nested repair, safe getattr) - tests/agent/transports/test_chat_completions.py: 2 new integration tests (Moonshot route sanitizes, non-Moonshot route doesn't) - Targeted suite: 49 passed - E2E via execute_code with a realistic MCP tool carrying all three Moonshot rejection modes + dangling required + draft-07 refs: sanitizer produces a schema valid on Moonshot and Gemini
2759 lines
106 KiB
Python
2759 lines
106 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
MCP (Model Context Protocol) Client Support
|
|
|
|
Connects to external MCP servers via stdio or HTTP/StreamableHTTP transport,
|
|
discovers their tools, and registers them into the hermes-agent tool registry
|
|
so the agent can call them like any built-in tool.
|
|
|
|
Configuration is read from ~/.hermes/config.yaml under the ``mcp_servers`` key.
|
|
The ``mcp`` Python package is optional -- if not installed, this module is a
|
|
no-op and logs a debug message.
|
|
|
|
Example config::
|
|
|
|
mcp_servers:
|
|
filesystem:
|
|
command: "npx"
|
|
args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
|
|
env: {}
|
|
timeout: 120 # per-tool-call timeout in seconds (default: 120)
|
|
connect_timeout: 60 # initial connection timeout (default: 60)
|
|
github:
|
|
command: "npx"
|
|
args: ["-y", "@modelcontextprotocol/server-github"]
|
|
env:
|
|
GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."
|
|
remote_api:
|
|
url: "https://my-mcp-server.example.com/mcp"
|
|
headers:
|
|
Authorization: "Bearer sk-..."
|
|
timeout: 180
|
|
analysis:
|
|
command: "npx"
|
|
args: ["-y", "analysis-server"]
|
|
sampling: # server-initiated LLM requests
|
|
enabled: true # default: true
|
|
model: "gemini-3-flash" # override model (optional)
|
|
max_tokens_cap: 4096 # max tokens per request
|
|
timeout: 30 # LLM call timeout (seconds)
|
|
max_rpm: 10 # max requests per minute
|
|
allowed_models: [] # model whitelist (empty = all)
|
|
max_tool_rounds: 5 # tool loop limit (0 = disable)
|
|
log_level: "info" # audit verbosity
|
|
|
|
Features:
|
|
- Stdio transport (command + args) and HTTP/StreamableHTTP transport (url)
|
|
- Automatic reconnection with exponential backoff (up to 5 retries)
|
|
- Environment variable filtering for stdio subprocesses (security)
|
|
- Credential stripping in error messages returned to the LLM
|
|
- Configurable per-server timeouts for tool calls and connections
|
|
- Thread-safe architecture with dedicated background event loop
|
|
- Sampling support: MCP servers can request LLM completions via
|
|
sampling/createMessage (text and tool-use responses)
|
|
|
|
Architecture:
|
|
A dedicated background event loop (_mcp_loop) runs in a daemon thread.
|
|
Each MCP server runs as a long-lived asyncio Task on this loop, keeping
|
|
its transport context alive. Tool call coroutines are scheduled onto the
|
|
loop via ``run_coroutine_threadsafe()``.
|
|
|
|
On shutdown, each server Task is signalled to exit its ``async with``
|
|
block, ensuring the anyio cancel-scope cleanup happens in the *same*
|
|
Task that opened the connection (required by anyio).
|
|
|
|
Thread safety:
|
|
_servers and _mcp_loop/_mcp_thread are accessed from both the MCP
|
|
background thread and caller threads. All mutations are protected by
|
|
_lock so the code is safe regardless of GIL presence (e.g. Python 3.13+
|
|
free-threading).
|
|
"""
|
|
|
|
import asyncio
|
|
import concurrent.futures
|
|
import inspect
|
|
import json
|
|
import logging
|
|
import math
|
|
import os
|
|
import re
|
|
import shutil
|
|
import threading
|
|
import time
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Graceful import -- MCP SDK is an optional dependency
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Feature flags describing which optional parts of the MCP SDK imported
# cleanly. Every flag is pre-initialized to False so code that consults
# them never hits a NameError, no matter which import below fails.
_MCP_AVAILABLE = False
_MCP_HTTP_AVAILABLE = False
# Fix: _MCP_NEW_HTTP was previously bound only inside the try block below,
# so it stayed undefined (NameError on later reads) whenever the `mcp`
# package itself was missing. Pre-initialize it alongside the other flags.
_MCP_NEW_HTTP = False
_MCP_SAMPLING_TYPES = False
_MCP_NOTIFICATION_TYPES = False
_MCP_MESSAGE_HANDLER_SUPPORTED = False
try:
    from mcp import ClientSession, StdioServerParameters
    from mcp.client.stdio import stdio_client

    _MCP_AVAILABLE = True
    try:
        from mcp.client.streamable_http import streamablehttp_client

        _MCP_HTTP_AVAILABLE = True
    except ImportError:
        _MCP_HTTP_AVAILABLE = False
    # Prefer the non-deprecated API (mcp >= 1.24.0); fall back to the
    # deprecated wrapper for older SDK versions.
    try:
        from mcp.client.streamable_http import streamable_http_client

        _MCP_NEW_HTTP = True
    except ImportError:
        _MCP_NEW_HTTP = False
    # Sampling types -- separated so older SDK versions don't break MCP support
    try:
        from mcp.types import (
            CreateMessageResult,
            CreateMessageResultWithTools,
            ErrorData,
            SamplingCapability,
            SamplingToolsCapability,
            TextContent,
            ToolUseContent,
        )

        _MCP_SAMPLING_TYPES = True
    except ImportError:
        logger.debug("MCP sampling types not available -- sampling disabled")
    # Notification types for dynamic tool discovery (tools/list_changed)
    try:
        from mcp.types import (
            ServerNotification,
            ToolListChangedNotification,
            PromptListChangedNotification,
            ResourceListChangedNotification,
        )

        _MCP_NOTIFICATION_TYPES = True
    except ImportError:
        logger.debug("MCP notification types not available -- dynamic tool discovery disabled")
except ImportError:
    logger.debug("mcp package not installed -- MCP tool support disabled")
|
|
|
|
|
def _check_message_handler_support() -> bool:
    """Check if ClientSession accepts ``message_handler`` kwarg.

    Inspects the constructor signature for backward compatibility with older
    MCP SDK versions that don't support notification handlers.
    """
    if not _MCP_AVAILABLE:
        return False
    try:
        ctor_params = inspect.signature(ClientSession).parameters
    except (TypeError, ValueError):
        # Builtins / C-implemented callables may not expose a signature.
        return False
    return "message_handler" in ctor_params
|
|
|
|
|
|
# Probe the SDK once at import time; when the kwarg is missing we silently
# degrade (no dynamic tool discovery) instead of raising at session creation.
_MCP_MESSAGE_HANDLER_SUPPORTED = _check_message_handler_support()
if _MCP_AVAILABLE and not _MCP_MESSAGE_HANDLER_SUPPORTED:
    logger.debug("MCP SDK does not support message_handler -- dynamic tool discovery disabled")
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Constants
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Default budgets (seconds) and retry limits for MCP server interactions.
_DEFAULT_TOOL_TIMEOUT = 120  # seconds for tool calls
_DEFAULT_CONNECT_TIMEOUT = 60  # seconds for initial connection per server
_MAX_RECONNECT_RETRIES = 5
_MAX_INITIAL_CONNECT_RETRIES = 3  # retries for the very first connection attempt
_MAX_BACKOFF_SECONDS = 60  # cap for exponential backoff between retries

# Environment variables that are safe to pass to stdio subprocesses
_SAFE_ENV_KEYS = frozenset({
    "PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", "SHELL", "TMPDIR",
})

# Regex for credential patterns to strip from error messages.
# Single case-insensitive alternation, applied by _sanitize_error.
_CREDENTIAL_PATTERN = re.compile(
    r"(?:"
    r"ghp_[A-Za-z0-9_]{1,255}"       # GitHub PAT
    r"|sk-[A-Za-z0-9_]{1,255}"       # OpenAI-style key
    r"|Bearer\s+\S+"                 # Bearer token
    r"|token=[^\s&,;\"']{1,255}"     # token=...
    r"|key=[^\s&,;\"']{1,255}"       # key=...
    r"|API_KEY=[^\s&,;\"']{1,255}"   # API_KEY=...
    r"|password=[^\s&,;\"']{1,255}"  # password=...
    r"|secret=[^\s&,;\"']{1,255}"    # secret=...
    r")",
    re.IGNORECASE,
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Security helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _build_safe_env(user_env: Optional[dict]) -> dict:
    """Build a filtered environment dict for stdio subprocesses.

    Starts from a whitelist of baseline variables (PATH, HOME, etc.) plus
    all XDG_* variables taken from the current process environment, then
    layers any variables the user explicitly configured for the server on
    top.

    This prevents accidentally leaking secrets like API keys, tokens, or
    credentials to MCP server subprocesses.
    """
    safe = {
        name: val
        for name, val in os.environ.items()
        if name in _SAFE_ENV_KEYS or name.startswith("XDG_")
    }
    safe.update(user_env or {})
    return safe
|
|
|
|
|
|
def _sanitize_error(text: str) -> str:
    """Strip credential-like patterns from error text before returning to LLM.

    Replaces tokens, keys, and other secrets with [REDACTED] to prevent
    accidental credential exposure in tool error responses.
    """
    redacted = _CREDENTIAL_PATTERN.sub("[REDACTED]", text)
    return redacted
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MCP tool description content scanning
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Patterns that indicate potential prompt injection in MCP tool descriptions.
|
|
# These are WARNING-level — we log but don't block, since false positives
|
|
# would break legitimate MCP servers.
|
|
# Each entry: (compiled pattern, human-readable finding logged by
# _scan_mcp_description when the pattern matches a tool description).
_MCP_INJECTION_PATTERNS = [
    (re.compile(r"ignore\s+(all\s+)?previous\s+instructions", re.I),
     "prompt override attempt ('ignore previous instructions')"),
    (re.compile(r"you\s+are\s+now\s+a", re.I),
     "identity override attempt ('you are now a...')"),
    (re.compile(r"your\s+new\s+(task|role|instructions?)\s+(is|are)", re.I),
     "task override attempt"),
    (re.compile(r"system\s*:\s*", re.I),
     "system prompt injection attempt"),
    (re.compile(r"<\s*(system|human|assistant)\s*>", re.I),
     "role tag injection attempt"),
    (re.compile(r"do\s+not\s+(tell|inform|mention|reveal)", re.I),
     "concealment instruction"),
    (re.compile(r"(curl|wget|fetch)\s+https?://", re.I),
     "network command in description"),
    (re.compile(r"base64\.(b64decode|decodebytes)", re.I),
     "base64 decode reference"),
    (re.compile(r"exec\s*\(|eval\s*\(", re.I),
     "code execution reference"),
    (re.compile(r"import\s+(subprocess|os|shutil|socket)", re.I),
     "dangerous import reference"),
]
|
|
|
|
|
|
def _scan_mcp_description(server_name: str, tool_name: str, description: str) -> List[str]:
    """Scan an MCP tool description for prompt injection patterns.

    Returns a list of finding strings (empty = clean).
    """
    if not description:
        return []
    findings = [
        reason
        for pattern, reason in _MCP_INJECTION_PATTERNS
        if pattern.search(description)
    ]
    if findings:
        # Warn but don't block -- false positives would break legit servers.
        logger.warning(
            "MCP server '%s' tool '%s': suspicious description content — %s. "
            "Description: %.200s",
            server_name, tool_name, "; ".join(findings),
            description,
        )
    return findings
|
|
|
|
|
|
def _prepend_path(env: dict, directory: str) -> dict:
|
|
"""Prepend *directory* to env PATH if it is not already present."""
|
|
updated = dict(env or {})
|
|
if not directory:
|
|
return updated
|
|
|
|
existing = updated.get("PATH", "")
|
|
parts = [part for part in existing.split(os.pathsep) if part]
|
|
if directory not in parts:
|
|
parts = [directory, *parts]
|
|
updated["PATH"] = os.pathsep.join(parts) if parts else directory
|
|
return updated
|
|
|
|
|
|
def _resolve_stdio_command(command: str, env: dict) -> tuple[str, dict]:
    """Resolve a stdio MCP command against the exact subprocess environment.

    This primarily exists to make bare ``npx``/``npm``/``node`` commands work
    reliably even when MCP subprocesses run under a filtered PATH.
    """
    cmd = os.path.expanduser(str(command).strip())
    final_env = dict(env or {})

    # Bare command name: search the subprocess PATH, not the caller's.
    if os.sep not in cmd:
        search_path = final_env["PATH"] if "PATH" in final_env else None
        located = shutil.which(cmd, path=search_path)
        if located:
            cmd = located
        elif cmd in {"npx", "npm", "node"}:
            # Well-known Node.js install locations used by the hermes setup.
            hermes_home = os.path.expanduser(
                os.getenv(
                    "HERMES_HOME", os.path.join(os.path.expanduser("~"), ".hermes")
                )
            )
            for candidate in (
                os.path.join(hermes_home, "node", "bin", cmd),
                os.path.join(os.path.expanduser("~"), ".local", "bin", cmd),
            ):
                if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
                    cmd = candidate
                    break

    # Make sure the command's own directory is reachable from the subprocess.
    bin_dir = os.path.dirname(cmd)
    if bin_dir:
        final_env = _prepend_path(final_env, bin_dir)

    return cmd, final_env
|
|
|
|
|
|
def _format_connect_error(exc: BaseException) -> str:
    """Render nested MCP connection errors into an actionable short message."""

    def _find_missing(current: BaseException) -> Optional[str]:
        # Locate a missing-executable path anywhere in the exception tree.
        # ExceptionGroups (anyio wraps transport failures) are walked first.
        nested = getattr(current, "exceptions", None)
        if nested:
            for child in nested:
                missing = _find_missing(child)
                if missing:
                    return missing
            return None
        if isinstance(current, FileNotFoundError):
            if getattr(current, "filename", None):
                return str(current.filename)
            # Some platforms only put the path in the message text.
            match = re.search(r"No such file or directory: '([^']+)'", str(current))
            if match:
                return match.group(1)
        # Follow both explicit (`raise ... from`) and implicit chaining.
        for attr in ("__cause__", "__context__"):
            nested_exc = getattr(current, attr, None)
            if isinstance(nested_exc, BaseException):
                missing = _find_missing(nested_exc)
                if missing:
                    return missing
        return None

    def _flatten_messages(current: BaseException) -> List[str]:
        # Collect every non-empty message in the tree, depth-first.
        nested = getattr(current, "exceptions", None)
        if nested:
            flattened: List[str] = []
            for child in nested:
                flattened.extend(_flatten_messages(child))
            return flattened
        messages = []
        text = str(current).strip()
        if text:
            messages.append(text)
        for attr in ("__cause__", "__context__"):
            nested_exc = getattr(current, attr, None)
            if isinstance(nested_exc, BaseException):
                messages.extend(_flatten_messages(nested_exc))
        # Fall back to the class name when no exception carried text.
        return messages or [current.__class__.__name__]

    # Missing executables get a targeted hint (Node.js setup is the common case).
    missing = _find_missing(exc)
    if missing:
        message = f"missing executable '{missing}'"
        if os.path.basename(missing) in {"npx", "npm", "node"}:
            message += (
                " (ensure Node.js is installed and PATH includes its bin directory, "
                "or set mcp_servers.<name>.command to an absolute path and include "
                "that directory in mcp_servers.<name>.env.PATH)"
            )
        return _sanitize_error(message)

    # Otherwise: up to three distinct messages, credentials stripped.
    deduped: List[str] = []
    for item in _flatten_messages(exc):
        if item not in deduped:
            deduped.append(item)
    return _sanitize_error("; ".join(deduped[:3]))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _safe_numeric(value, default, coerce=int, minimum=1):
|
|
"""Coerce a config value to a numeric type, returning *default* on failure.
|
|
|
|
Handles string values from YAML (e.g. ``"10"`` instead of ``10``),
|
|
non-finite floats, and values below *minimum*.
|
|
"""
|
|
try:
|
|
result = coerce(value)
|
|
if isinstance(result, float) and not math.isfinite(result):
|
|
return default
|
|
return max(result, minimum)
|
|
except (TypeError, ValueError, OverflowError):
|
|
return default
|
|
|
|
|
|
class SamplingHandler:
    """Handles sampling/createMessage requests for a single MCP server.

    Each MCPServerTask that has sampling enabled creates one SamplingHandler.
    The handler is callable and passed directly to ``ClientSession`` as
    the ``sampling_callback``. All state (rate-limit timestamps, metrics,
    tool-loop counters) lives on the instance -- no module-level globals.

    The callback is async and runs on the MCP background event loop. The
    sync LLM call is offloaded to a thread via ``asyncio.to_thread()`` so
    it doesn't block the event loop.
    """

    # OpenAI finish_reason -> MCP stopReason
    _STOP_REASON_MAP = {"stop": "endTurn", "length": "maxTokens", "tool_calls": "toolUse"}

    def __init__(self, server_name: str, config: dict):
        self.server_name = server_name
        # All numeric settings go through _safe_numeric so YAML strings /
        # garbage values fall back to sane defaults.
        self.max_rpm = _safe_numeric(config.get("max_rpm", 10), 10, int)
        self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
        self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
        # minimum=0 because 0 is meaningful: tool loops disabled.
        self.max_tool_rounds = _safe_numeric(
            config.get("max_tool_rounds", 5), 5, int, minimum=0,
        )
        self.model_override = config.get("model")
        self.allowed_models = config.get("allowed_models", [])

        # Unknown log_level strings quietly fall back to INFO.
        _log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
        self.audit_level = _log_levels.get(
            str(config.get("log_level", "info")).lower(), logging.INFO,
        )

        # Per-instance state
        self._rate_timestamps: List[float] = []
        self._tool_loop_count = 0
        self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}

    # -- Rate limiting -------------------------------------------------------

    def _check_rate_limit(self) -> bool:
        """Sliding-window rate limiter. Returns True if request is allowed."""
        now = time.time()
        window = now - 60
        # Drop timestamps that fell out of the 60-second window (in place).
        self._rate_timestamps[:] = [t for t in self._rate_timestamps if t > window]
        if len(self._rate_timestamps) >= self.max_rpm:
            return False
        self._rate_timestamps.append(now)
        return True

    # -- Model resolution ----------------------------------------------------

    def _resolve_model(self, preferences) -> Optional[str]:
        """Config override > server hint > None (use default)."""
        if self.model_override:
            return self.model_override
        if preferences and hasattr(preferences, "hints") and preferences.hints:
            for hint in preferences.hints:
                if hasattr(hint, "name") and hint.name:
                    return hint.name
        return None

    # -- Message conversion --------------------------------------------------

    @staticmethod
    def _extract_tool_result_text(block) -> str:
        """Extract text from a ToolResultContent block."""
        if not hasattr(block, "content") or block.content is None:
            return ""
        items = block.content if isinstance(block.content, list) else [block.content]
        return "\n".join(item.text for item in items if hasattr(item, "text"))

    def _convert_messages(self, params) -> List[dict]:
        """Convert MCP SamplingMessages to OpenAI format.

        Uses ``msg.content_as_list`` (SDK helper) so single-block and
        list-of-blocks are handled uniformly. Dispatches per block type
        with ``isinstance`` on real SDK types when available, falling back
        to duck-typing via ``hasattr`` for compatibility.
        """
        messages: List[dict] = []
        for msg in params.messages:
            blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
                msg.content if isinstance(msg.content, list) else [msg.content]
            )

            # Separate blocks by kind (duck-typed on distinguishing attrs:
            # toolUseId -> tool result; name+input -> tool use; rest -> content)
            tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
            tool_uses = [b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")]
            content_blocks = [b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))]

            # Emit tool result messages (role: tool)
            for tr in tool_results:
                messages.append({
                    "role": "tool",
                    "tool_call_id": tr.toolUseId,
                    "content": self._extract_tool_result_text(tr),
                })

            # Emit assistant tool_calls message
            if tool_uses:
                tc_list = []
                for tu in tool_uses:
                    tc_list.append({
                        "id": getattr(tu, "id", f"call_{len(tc_list)}"),
                        "type": "function",
                        "function": {
                            "name": tu.name,
                            "arguments": json.dumps(tu.input, ensure_ascii=False) if isinstance(tu.input, dict) else str(tu.input),
                        },
                    })
                msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
                # Include any accompanying text
                text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
                if text_parts:
                    msg_dict["content"] = "\n".join(text_parts)
                messages.append(msg_dict)
            elif content_blocks:
                # Pure text/image content
                if len(content_blocks) == 1 and hasattr(content_blocks[0], "text"):
                    messages.append({"role": msg.role, "content": content_blocks[0].text})
                else:
                    # Multi-part content: OpenAI "parts" list form.
                    parts = []
                    for block in content_blocks:
                        if hasattr(block, "text"):
                            parts.append({"type": "text", "text": block.text})
                        elif hasattr(block, "data") and hasattr(block, "mimeType"):
                            parts.append({
                                "type": "image_url",
                                "image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
                            })
                        else:
                            logger.warning(
                                "Unsupported sampling content block type: %s (skipped)",
                                type(block).__name__,
                            )
                    if parts:
                        messages.append({"role": msg.role, "content": parts})

        return messages

    # -- Error helper --------------------------------------------------------

    @staticmethod
    def _error(message: str, code: int = -1):
        """Return ErrorData (MCP spec) or raise as fallback."""
        if _MCP_SAMPLING_TYPES:
            return ErrorData(code=code, message=message)
        raise Exception(message)

    # -- Response building ---------------------------------------------------

    def _build_tool_use_result(self, choice, response):
        """Build a CreateMessageResultWithTools from an LLM tool_calls response."""
        self.metrics["tool_use_count"] += 1

        # Tool loop governance
        if self.max_tool_rounds == 0:
            self._tool_loop_count = 0
            return self._error(
                f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
            )

        self._tool_loop_count += 1
        if self._tool_loop_count > self.max_tool_rounds:
            self._tool_loop_count = 0
            return self._error(
                f"Tool loop limit exceeded for server '{self.server_name}' "
                f"(max {self.max_tool_rounds} rounds)"
            )

        content_blocks = []
        for tc in choice.message.tool_calls:
            args = tc.function.arguments
            if isinstance(args, str):
                try:
                    parsed = json.loads(args)
                except (json.JSONDecodeError, ValueError):
                    # Keep going: wrap unparseable arguments instead of failing
                    # the whole sampling round.
                    logger.warning(
                        "MCP server '%s': malformed tool_calls arguments "
                        "from LLM (wrapping as raw): %.100s",
                        self.server_name, args,
                    )
                    parsed = {"_raw": args}
            else:
                parsed = args if isinstance(args, dict) else {"_raw": str(args)}

            content_blocks.append(ToolUseContent(
                type="tool_use",
                id=tc.id,
                name=tc.function.name,
                input=parsed,
            ))

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
            self.server_name, response.model,
            getattr(getattr(response, "usage", None), "total_tokens", "?"),
            len(content_blocks),
        )

        return CreateMessageResultWithTools(
            role="assistant",
            content=content_blocks,
            model=response.model,
            stopReason="toolUse",
        )

    def _build_text_result(self, choice, response):
        """Build a CreateMessageResult from a normal text response."""
        self._tool_loop_count = 0  # reset on text response
        response_text = choice.message.content or ""

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling response: model=%s, tokens=%s",
            self.server_name, response.model,
            getattr(getattr(response, "usage", None), "total_tokens", "?"),
        )

        return CreateMessageResult(
            role="assistant",
            content=TextContent(type="text", text=_sanitize_error(response_text)),
            model=response.model,
            stopReason=self._STOP_REASON_MAP.get(choice.finish_reason, "endTurn"),
        )

    # -- Session kwargs helper -----------------------------------------------

    def session_kwargs(self) -> dict:
        """Return kwargs to pass to ClientSession for sampling support."""
        return {
            "sampling_callback": self,
            "sampling_capabilities": SamplingCapability(
                tools=SamplingToolsCapability(),
            ),
        }

    # -- Main callback -------------------------------------------------------

    async def __call__(self, context, params):
        """Sampling callback invoked by the MCP SDK.

        Conforms to ``SamplingFnT`` protocol. Returns
        ``CreateMessageResult``, ``CreateMessageResultWithTools``, or
        ``ErrorData``.
        """
        # Rate limit
        if not self._check_rate_limit():
            logger.warning(
                "MCP server '%s' sampling rate limit exceeded (%d/min)",
                self.server_name, self.max_rpm,
            )
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling rate limit exceeded for server '{self.server_name}' "
                f"({self.max_rpm} requests/minute)"
            )

        # Resolve model
        model = self._resolve_model(getattr(params, "modelPreferences", None))

        # Get auxiliary LLM client via centralized router
        from agent.auxiliary_client import call_llm

        # Model whitelist check (we need to resolve model before calling)
        resolved_model = model or self.model_override or ""

        if self.allowed_models and resolved_model and resolved_model not in self.allowed_models:
            logger.warning(
                "MCP server '%s' requested model '%s' not in allowed_models",
                self.server_name, resolved_model,
            )
            self.metrics["errors"] += 1
            return self._error(
                f"Model '{resolved_model}' not allowed for server "
                f"'{self.server_name}'. Allowed: {', '.join(self.allowed_models)}"
            )

        # Convert messages
        messages = self._convert_messages(params)
        if hasattr(params, "systemPrompt") and params.systemPrompt:
            messages.insert(0, {"role": "system", "content": params.systemPrompt})

        # Build LLM call kwargs; cap tokens at the configured ceiling.
        max_tokens = min(params.maxTokens, self.max_tokens_cap)
        call_temperature = None
        if hasattr(params, "temperature") and params.temperature is not None:
            call_temperature = params.temperature

        # Forward server-provided tools
        call_tools = None
        server_tools = getattr(params, "tools", None)
        if server_tools:
            call_tools = [
                {
                    "type": "function",
                    "function": {
                        "name": getattr(t, "name", ""),
                        "description": getattr(t, "description", "") or "",
                        "parameters": _normalize_mcp_input_schema(
                            getattr(t, "inputSchema", None)
                        ),
                    },
                }
                for t in server_tools
            ]

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
            self.server_name, resolved_model, max_tokens, len(messages),
        )

        # Offload sync LLM call to thread (non-blocking)
        def _sync_call():
            return call_llm(
                task="mcp",
                model=resolved_model or None,
                messages=messages,
                temperature=call_temperature,
                max_tokens=max_tokens,
                tools=call_tools,
                timeout=self.timeout,
            )

        try:
            # Belt and braces: wait_for enforces the timeout even if the
            # underlying client ignores its own timeout parameter.
            response = await asyncio.wait_for(
                asyncio.to_thread(_sync_call), timeout=self.timeout,
            )
        except asyncio.TimeoutError:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call timed out after {self.timeout}s "
                f"for server '{self.server_name}'"
            )
        except Exception as exc:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
            )

        # Guard against empty choices (content filtering, provider errors)
        if not getattr(response, "choices", None):
            self.metrics["errors"] += 1
            return self._error(
                f"LLM returned empty response (no choices) for server "
                f"'{self.server_name}'"
            )

        # Track metrics
        choice = response.choices[0]
        self.metrics["requests"] += 1
        total_tokens = getattr(getattr(response, "usage", None), "total_tokens", 0)
        if isinstance(total_tokens, int):
            self.metrics["tokens_used"] += total_tokens

        # Dispatch based on response type
        if (
            choice.finish_reason == "tool_calls"
            and hasattr(choice.message, "tool_calls")
            and choice.message.tool_calls
        ):
            return self._build_tool_use_result(choice, response)

        return self._build_text_result(choice, response)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Server task -- each MCP server lives in one long-lived asyncio Task
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class MCPServerTask:
|
|
"""Manages a single MCP server connection in a dedicated asyncio Task.
|
|
|
|
The entire connection lifecycle (connect, discover, serve, disconnect)
|
|
runs inside one asyncio Task so that anyio cancel-scopes created by
|
|
the transport client are entered and exited in the same Task context.
|
|
|
|
Supports both stdio and HTTP/StreamableHTTP transports.
|
|
"""
|
|
|
|
__slots__ = (
|
|
"name", "session", "tool_timeout",
|
|
"_task", "_ready", "_shutdown_event", "_reconnect_event",
|
|
"_tools", "_error", "_config",
|
|
"_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
|
|
)
|
|
|
|
def __init__(self, name: str):
    """Initialize task state for server *name*; no connection is made here."""
    self.name = name
    # Live ClientSession once connected; None while disconnected.
    self.session: Optional[Any] = None
    self.tool_timeout: float = _DEFAULT_TOOL_TIMEOUT
    # The long-lived asyncio Task running the connection lifecycle.
    self._task: Optional[asyncio.Task] = None
    # Signalled once the first connect + tool discovery completes (or fails).
    self._ready = asyncio.Event()
    self._shutdown_event = asyncio.Event()
    # Set by tool handlers on auth failure after manager.handle_401()
    # confirms recovery is viable. When set, _run_http / _run_stdio
    # exit their async-with blocks cleanly (no exception), and the
    # outer run() loop re-enters the transport so the MCP session is
    # rebuilt with fresh credentials.
    self._reconnect_event = asyncio.Event()
    # Most recent tool list fetched from the server.
    self._tools: list = []
    # Last connection error, if any, for surfacing to callers.
    self._error: Optional[Exception] = None
    self._config: dict = {}
    # Present only when sampling is enabled for this server.
    self._sampling: Optional[SamplingHandler] = None
    # Prefixed names currently registered in the central tool registry.
    self._registered_tool_names: list[str] = []
    self._auth_type: str = ""
    # Serializes _refresh_tools runs triggered by rapid notifications.
    self._refresh_lock = asyncio.Lock()
|
|
|
|
def _is_http(self) -> bool:
|
|
"""Check if this server uses HTTP transport."""
|
|
return "url" in self._config
|
|
|
|
# ----- Dynamic tool discovery (notifications/tools/list_changed) -----
|
|
|
|
def _make_message_handler(self):
    """Build a ``message_handler`` callback for ``ClientSession``.

    Dispatches on notification type. Only ``ToolListChangedNotification``
    triggers a refresh; prompt and resource change notifications are
    logged as stubs for future work.
    """
    async def _handler(message):
        try:
            if isinstance(message, Exception):
                logger.debug("MCP message handler (%s): exception: %s", self.name, message)
                return
            if not (_MCP_NOTIFICATION_TYPES and isinstance(message, ServerNotification)):
                return
            payload = message.root
            if isinstance(payload, ToolListChangedNotification):
                logger.info(
                    "MCP server '%s': received tools/list_changed notification",
                    self.name,
                )
                await self._refresh_tools()
            elif isinstance(payload, PromptListChangedNotification):
                logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name)
            elif isinstance(payload, ResourceListChangedNotification):
                logger.debug("MCP server '%s': resources/list_changed (ignored)", self.name)
            # Any other notification type is silently ignored.
        except Exception:
            # Never let a handler error propagate into the SDK's read loop.
            logger.exception("Error in MCP message handler for '%s'", self.name)
    return _handler
|
|
|
|
    async def _refresh_tools(self):
        """Re-fetch tools from the server and update the registry.

        Called when the server sends ``notifications/tools/list_changed``.
        The lock prevents overlapping refreshes from rapid-fire notifications.
        After the initial ``await`` (list_tools), all mutations are synchronous
        — atomic from the event loop's perspective.
        """
        # Local import avoids a circular dependency at module load time.
        from tools.registry import registry

        async with self._refresh_lock:
            # Capture old tool names for change diff
            old_tool_names = set(self._registered_tool_names)

            # 1. Fetch current tool list from server
            tools_result = await self.session.list_tools()
            new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else []

            # 2. Deregister old tools from the central registry
            for prefixed_name in self._registered_tool_names:
                registry.deregister(prefixed_name)

            # 3. Re-register with fresh tool list
            self._tools = new_mcp_tools
            self._registered_tool_names = _register_server_tools(
                self.name, self, self._config
            )

            # 4. Log what changed (user-visible notification)
            new_tool_names = set(self._registered_tool_names)
            added = new_tool_names - old_tool_names
            removed = old_tool_names - new_tool_names
            changes = []
            if added:
                changes.append(f"added: {', '.join(sorted(added))}")
            if removed:
                changes.append(f"removed: {', '.join(sorted(removed))}")
            if changes:
                # warning level: a server silently swapping tools can be a
                # supply-chain red flag, so make it visible to the user.
                logger.warning(
                    "MCP server '%s': tools changed dynamically — %s. "
                    "Verify these changes are expected.",
                    self.name, "; ".join(changes),
                )
            else:
                logger.info(
                    "MCP server '%s': dynamically refreshed %d tool(s) (no changes)",
                    self.name, len(self._registered_tool_names),
                )
|
|
|
|
async def _wait_for_lifecycle_event(self) -> str:
|
|
"""Block until either _shutdown_event or _reconnect_event fires.
|
|
|
|
Returns:
|
|
"shutdown" if the server should exit the run loop entirely.
|
|
"reconnect" if the server should tear down the current MCP
|
|
session and re-enter the transport (fresh OAuth
|
|
tokens, new session ID, etc.). The reconnect event
|
|
is cleared before return so the next cycle starts
|
|
with a fresh signal.
|
|
|
|
Shutdown takes precedence if both events are set simultaneously.
|
|
"""
|
|
shutdown_task = asyncio.create_task(self._shutdown_event.wait())
|
|
reconnect_task = asyncio.create_task(self._reconnect_event.wait())
|
|
try:
|
|
await asyncio.wait(
|
|
{shutdown_task, reconnect_task},
|
|
return_when=asyncio.FIRST_COMPLETED,
|
|
)
|
|
finally:
|
|
for t in (shutdown_task, reconnect_task):
|
|
if not t.done():
|
|
t.cancel()
|
|
try:
|
|
await t
|
|
except (asyncio.CancelledError, Exception):
|
|
pass
|
|
|
|
if self._shutdown_event.is_set():
|
|
return "shutdown"
|
|
self._reconnect_event.clear()
|
|
return "reconnect"
|
|
|
|
    async def _run_stdio(self, config: dict):
        """Run the server using stdio transport.

        Spawns the configured command as a subprocess (after env scrubbing,
        command resolution, and an OSV malware check), initializes an MCP
        session over its stdin/stdout, discovers tools, then parks until a
        lifecycle event (shutdown or reconnect) fires.

        Raises:
            ValueError: if 'command' is missing, or the package is flagged
                by the OSV malware check.
        """
        command = config.get("command")
        args = config.get("args", [])
        user_env = config.get("env")

        if not command:
            raise ValueError(
                f"MCP server '{self.name}' has no 'command' in config"
            )

        # Scrub the inherited environment to a safe subset, then resolve
        # the command (e.g. to an absolute path / platform wrapper).
        safe_env = _build_safe_env(user_env)
        command, safe_env = _resolve_stdio_command(command, safe_env)

        # Check package against OSV malware database before spawning
        from tools.osv_check import check_package_for_malware
        malware_error = check_package_for_malware(command, args)
        if malware_error:
            raise ValueError(
                f"MCP server '{self.name}': {malware_error}"
            )

        server_params = StdioServerParameters(
            command=command,
            args=args,
            env=safe_env if safe_env else None,
        )

        sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
        if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED:
            sampling_kwargs["message_handler"] = self._make_message_handler()

        # Snapshot child PIDs before spawning so we can track the new one.
        pids_before = _snapshot_child_pids()
        async with stdio_client(server_params) as (read_stream, write_stream):
            # Capture the newly spawned subprocess PID for force-kill cleanup.
            new_pids = _snapshot_child_pids() - pids_before
            if new_pids:
                with _lock:
                    for _pid in new_pids:
                        _stdio_pids[_pid] = self.name
            async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
                await session.initialize()
                self.session = session
                await self._discover_tools()
                self._ready.set()
                # stdio transport does not use OAuth, but we still honor
                # _reconnect_event (e.g. future manual /mcp refresh) for
                # consistency with _run_http.
                await self._wait_for_lifecycle_event()
        # Context exited cleanly — subprocess was terminated by the SDK.
        if new_pids:
            with _lock:
                for _pid in new_pids:
                    _stdio_pids.pop(_pid, None)
|
|
|
|
    async def _run_http(self, config: dict):
        """Run the server using HTTP/StreamableHTTP transport.

        Supports both the new (mcp >= 1.24.0, caller-owned httpx client)
        and the deprecated (mcp < 1.24.0, SDK-managed client) streamable
        HTTP APIs. Initializes the session, discovers tools, then parks
        until a lifecycle event (shutdown or reconnect) fires.

        Raises:
            ImportError: if the installed mcp package lacks HTTP support.
            Exception: re-raised OAuth setup failures and transport errors.
        """
        if not _MCP_HTTP_AVAILABLE:
            raise ImportError(
                f"MCP server '{self.name}' requires HTTP transport but "
                "mcp.client.streamable_http is not available. "
                "Upgrade the mcp package to get HTTP support."
            )

        url = config["url"]
        headers = dict(config.get("headers") or {})
        connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
        ssl_verify = config.get("ssl_verify", True)

        # OAuth 2.1 PKCE: route through the central MCPOAuthManager so the
        # same provider instance is reused across reconnects, pre-flow
        # disk-watch is active, and config-time CLI code paths share state.
        # If OAuth setup fails (e.g. non-interactive env without cached
        # tokens), re-raise so this server is reported as failed without
        # blocking other MCP servers from connecting.
        _oauth_auth = None
        if self._auth_type == "oauth":
            try:
                from tools.mcp_oauth_manager import get_manager
                _oauth_auth = get_manager().get_or_build_provider(
                    self.name, url, config.get("oauth"),
                )
            except Exception as exc:
                logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
                raise

        sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
        if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED:
            sampling_kwargs["message_handler"] = self._make_message_handler()

        if _MCP_NEW_HTTP:
            # New API (mcp >= 1.24.0): build an explicit httpx.AsyncClient
            # matching the SDK's own create_mcp_http_client defaults.
            import httpx

            client_kwargs: dict = {
                "follow_redirects": True,
                "timeout": httpx.Timeout(float(connect_timeout), read=300.0),
                "verify": ssl_verify,
            }
            if headers:
                client_kwargs["headers"] = headers
            if _oauth_auth is not None:
                client_kwargs["auth"] = _oauth_auth

            # Caller owns the client lifecycle — the SDK skips cleanup when
            # http_client is provided, so we wrap in async-with.
            async with httpx.AsyncClient(**client_kwargs) as http_client:
                async with streamable_http_client(url, http_client=http_client) as (
                    read_stream, write_stream, _get_session_id,
                ):
                    async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
                        await session.initialize()
                        self.session = session
                        await self._discover_tools()
                        self._ready.set()
                        reason = await self._wait_for_lifecycle_event()
                        if reason == "reconnect":
                            logger.info(
                                "MCP server '%s': reconnect requested — "
                                "tearing down HTTP session", self.name,
                            )
        else:
            # Deprecated API (mcp < 1.24.0): manages httpx client internally.
            _http_kwargs: dict = {
                "headers": headers,
                "timeout": float(connect_timeout),
                "verify": ssl_verify,
            }
            if _oauth_auth is not None:
                _http_kwargs["auth"] = _oauth_auth
            async with streamablehttp_client(url, **_http_kwargs) as (
                read_stream, write_stream, _get_session_id,
            ):
                async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
                    await session.initialize()
                    self.session = session
                    await self._discover_tools()
                    self._ready.set()
                    reason = await self._wait_for_lifecycle_event()
                    if reason == "reconnect":
                        logger.info(
                            "MCP server '%s': reconnect requested — "
                            "tearing down legacy HTTP session", self.name,
                        )
|
|
|
|
async def _discover_tools(self):
|
|
"""Discover tools from the connected session."""
|
|
if self.session is None:
|
|
return
|
|
tools_result = await self.session.list_tools()
|
|
self._tools = (
|
|
tools_result.tools
|
|
if hasattr(tools_result, "tools")
|
|
else []
|
|
)
|
|
|
|
async def run(self, config: dict):
|
|
"""Long-lived coroutine: connect, discover tools, wait, disconnect.
|
|
|
|
Includes automatic reconnection with exponential backoff if the
|
|
connection drops unexpectedly (unless shutdown was requested).
|
|
"""
|
|
self._config = config
|
|
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
|
|
self._auth_type = (config.get("auth") or "").lower().strip()
|
|
|
|
# Set up sampling handler if enabled and SDK types are available
|
|
sampling_config = config.get("sampling", {})
|
|
if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
|
|
self._sampling = SamplingHandler(self.name, sampling_config)
|
|
else:
|
|
self._sampling = None
|
|
|
|
# Validate: warn if both url and command are present
|
|
if "url" in config and "command" in config:
|
|
logger.warning(
|
|
"MCP server '%s' has both 'url' and 'command' in config. "
|
|
"Using HTTP transport ('url'). Remove 'command' to silence "
|
|
"this warning.",
|
|
self.name,
|
|
)
|
|
retries = 0
|
|
initial_retries = 0
|
|
backoff = 1.0
|
|
|
|
while True:
|
|
try:
|
|
if self._is_http():
|
|
await self._run_http(config)
|
|
else:
|
|
await self._run_stdio(config)
|
|
# Transport returned cleanly. Two cases:
|
|
# - _shutdown_event was set: exit the run loop entirely.
|
|
# - _reconnect_event was set (auth recovery): loop back and
|
|
# rebuild the MCP session with fresh credentials. Do NOT
|
|
# touch the retry counters — this is not a failure.
|
|
if self._shutdown_event.is_set():
|
|
break
|
|
logger.info(
|
|
"MCP server '%s': reconnecting (OAuth recovery or "
|
|
"manual refresh)",
|
|
self.name,
|
|
)
|
|
# Reset the session reference; _run_http/_run_stdio will
|
|
# repopulate it on successful re-entry.
|
|
self.session = None
|
|
# Keep _ready set across reconnects so tool handlers can
|
|
# still detect a transient in-flight state — it'll be
|
|
# re-set after the fresh session initializes.
|
|
continue
|
|
except Exception as exc:
|
|
self.session = None
|
|
|
|
# If this is the first connection attempt, retry with backoff
|
|
# before giving up. A transient DNS/network blip at startup
|
|
# should not permanently kill the server.
|
|
# (Ported from Kilo Code's MCP resilience fix.)
|
|
if not self._ready.is_set():
|
|
initial_retries += 1
|
|
if initial_retries > _MAX_INITIAL_CONNECT_RETRIES:
|
|
logger.warning(
|
|
"MCP server '%s' failed initial connection after "
|
|
"%d attempts, giving up: %s",
|
|
self.name, _MAX_INITIAL_CONNECT_RETRIES, exc,
|
|
)
|
|
self._error = exc
|
|
self._ready.set()
|
|
return
|
|
|
|
logger.warning(
|
|
"MCP server '%s' initial connection failed "
|
|
"(attempt %d/%d), retrying in %.0fs: %s",
|
|
self.name, initial_retries,
|
|
_MAX_INITIAL_CONNECT_RETRIES, backoff, exc,
|
|
)
|
|
await asyncio.sleep(backoff)
|
|
backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS)
|
|
|
|
# Check if shutdown was requested during the sleep
|
|
if self._shutdown_event.is_set():
|
|
self._error = exc
|
|
self._ready.set()
|
|
return
|
|
continue
|
|
|
|
# If shutdown was requested, don't reconnect
|
|
if self._shutdown_event.is_set():
|
|
logger.debug(
|
|
"MCP server '%s' disconnected during shutdown: %s",
|
|
self.name, exc,
|
|
)
|
|
return
|
|
|
|
retries += 1
|
|
if retries > _MAX_RECONNECT_RETRIES:
|
|
logger.warning(
|
|
"MCP server '%s' failed after %d reconnection attempts, "
|
|
"giving up: %s",
|
|
self.name, _MAX_RECONNECT_RETRIES, exc,
|
|
)
|
|
return
|
|
|
|
logger.warning(
|
|
"MCP server '%s' connection lost (attempt %d/%d), "
|
|
"reconnecting in %.0fs: %s",
|
|
self.name, retries, _MAX_RECONNECT_RETRIES,
|
|
backoff, exc,
|
|
)
|
|
await asyncio.sleep(backoff)
|
|
backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS)
|
|
|
|
# Check again after sleeping
|
|
if self._shutdown_event.is_set():
|
|
return
|
|
finally:
|
|
self.session = None
|
|
|
|
async def start(self, config: dict):
|
|
"""Create the background Task and wait until ready (or failed)."""
|
|
self._task = asyncio.ensure_future(self.run(config))
|
|
await self._ready.wait()
|
|
if self._error:
|
|
raise self._error
|
|
|
|
    async def shutdown(self):
        """Signal the Task to exit and wait for clean resource teardown.

        Waits up to 10 seconds for a graceful exit, then cancels the task
        outright. Always deregisters this server's tools from the central
        registry and clears the session reference.
        """
        # Local import avoids a circular dependency at module load time.
        from tools.registry import registry

        self._shutdown_event.set()
        # Defensive: if _wait_for_lifecycle_event is blocking, we need ANY
        # event to unblock it. _shutdown_event alone is sufficient (the
        # helper checks shutdown first), but setting reconnect too ensures
        # there's no race where the helper misses the shutdown flag after
        # returning "reconnect".
        self._reconnect_event.set()
        if self._task and not self._task.done():
            try:
                await asyncio.wait_for(self._task, timeout=10)
            except asyncio.TimeoutError:
                logger.warning(
                    "MCP server '%s' shutdown timed out, cancelling task",
                    self.name,
                )
                self._task.cancel()
                try:
                    await self._task
                except asyncio.CancelledError:
                    pass
        # getattr guard: shutdown may be called on a partially-initialized
        # instance (e.g. __init__ interrupted before the attribute existed).
        for tool_name in list(getattr(self, "_registered_tool_names", [])):
            registry.deregister(tool_name)
        self._registered_tool_names = []
        self.session = None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Module-level state
# ---------------------------------------------------------------------------

# Live server tasks keyed by server name. Guarded by _lock (defined below)
# for cross-thread access from tool handlers.
_servers: Dict[str, MCPServerTask] = {}

# Circuit breaker: consecutive error counts per server. After
# _CIRCUIT_BREAKER_THRESHOLD consecutive failures, the handler returns
# a "server unreachable" message that tells the model to stop retrying,
# preventing the 90-iteration burn loop described in #10447.
#
# State machine:
#   closed    — error count below threshold; all calls go through.
#   open      — threshold reached; calls short-circuit until the
#               cooldown elapses.
#   half-open — cooldown elapsed; the next call is a probe that
#               actually hits the session. Probe success → closed.
#               Probe failure → reopens (cooldown re-armed).
#
# ``_server_breaker_opened_at`` records the monotonic timestamp when
# the breaker most recently transitioned into the open state. Use the
# ``_bump_server_error`` / ``_reset_server_error`` helpers to mutate
# this state — they keep the count and timestamp in sync.
_server_error_counts: Dict[str, int] = {}
_server_breaker_opened_at: Dict[str, float] = {}
_CIRCUIT_BREAKER_THRESHOLD = 3
_CIRCUIT_BREAKER_COOLDOWN_SEC = 60.0
|
|
|
|
|
|
def _bump_server_error(server_name: str) -> None:
    """Increment the consecutive-failure count for ``server_name``.

    When the count crosses :data:`_CIRCUIT_BREAKER_THRESHOLD`, stamp the
    breaker-open timestamp so the cooldown clock starts (or re-starts,
    for probe failures in the half-open state).
    """
    count = _server_error_counts.get(server_name, 0) + 1
    _server_error_counts[server_name] = count
    if count >= _CIRCUIT_BREAKER_THRESHOLD:
        # Stamping (or re-stamping) the open time re-arms the cooldown.
        _server_breaker_opened_at[server_name] = time.monotonic()
|
|
|
|
|
|
def _reset_server_error(server_name: str) -> None:
    """Fully close the breaker for ``server_name``.

    Clears both the failure count and the breaker-open timestamp. Call
    this on any unambiguous success signal (successful tool call,
    successful reconnect, manual /mcp refresh).
    """
    _server_breaker_opened_at.pop(server_name, None)
    _server_error_counts[server_name] = 0
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Auth-failure detection helpers (Task 6 of MCP OAuth consolidation)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Cached tuple of auth-related exception types. Populated lazily by
# _get_auth_error_types() so this module imports cleanly when the MCP SDK
# OAuth module is missing.
_AUTH_ERROR_TYPES: tuple = ()
|
|
|
|
|
|
def _get_auth_error_types() -> tuple:
    """Return a tuple of exception types that indicate MCP OAuth failure.

    Cached after first call. Includes:
      - ``mcp.client.auth.OAuthFlowError`` / ``OAuthTokenError`` — raised by
        the SDK's auth flow when discovery, refresh, or full re-auth fails.
      - ``mcp.client.auth.UnauthorizedError`` (older MCP SDKs) — kept as an
        optional import for forward/backward compatibility.
      - ``tools.mcp_oauth.OAuthNonInteractiveError`` — raised by our callback
        handler when no user is present to complete a browser flow.
      - ``httpx.HTTPStatusError`` — caller must additionally check
        ``status_code == 401`` via :func:`_is_auth_error`.
    """
    global _AUTH_ERROR_TYPES
    if _AUTH_ERROR_TYPES:
        return _AUTH_ERROR_TYPES

    collected: list = []
    try:
        from mcp.client.auth import OAuthFlowError, OAuthTokenError
    except ImportError:
        pass
    else:
        collected.extend([OAuthFlowError, OAuthTokenError])
    try:
        # Older MCP SDK variants exported this
        from mcp.client.auth import UnauthorizedError  # type: ignore
    except ImportError:
        pass
    else:
        collected.append(UnauthorizedError)
    try:
        from tools.mcp_oauth import OAuthNonInteractiveError
    except ImportError:
        pass
    else:
        collected.append(OAuthNonInteractiveError)
    try:
        import httpx
    except ImportError:
        pass
    else:
        collected.append(httpx.HTTPStatusError)

    _AUTH_ERROR_TYPES = tuple(collected)
    return _AUTH_ERROR_TYPES
|
|
|
|
|
|
def _is_auth_error(exc: BaseException) -> bool:
    """Return True if ``exc`` indicates an MCP OAuth failure.

    ``httpx.HTTPStatusError`` is only treated as auth-related when the
    response status code is 401. Other HTTP errors fall through to the
    generic error path in the tool handlers.
    """
    known_types = _get_auth_error_types()
    if not known_types:
        return False
    if not isinstance(exc, known_types):
        return False
    try:
        import httpx
    except ImportError:
        # httpx absent — exc matched one of the non-HTTP auth types.
        return True
    if isinstance(exc, httpx.HTTPStatusError):
        return getattr(exc.response, "status_code", None) == 401
    return True
|
|
|
|
|
|
def _handle_auth_error_and_retry(
    server_name: str,
    exc: BaseException,
    retry_call,
    op_description: str,
):
    """Attempt auth recovery and one retry; return None to fall through.

    Called by the 5 MCP tool handlers when ``session.<op>()`` raises an
    auth-related exception. Workflow:

    1. Ask :class:`tools.mcp_oauth_manager.MCPOAuthManager.handle_401` if
       recovery is viable (i.e., disk has fresh tokens, or the SDK can
       refresh in-place).
    2. If yes, set the server's ``_reconnect_event`` so the server task
       tears down the current MCP session and rebuilds it with fresh
       credentials. Wait briefly for ``_ready`` to re-fire.
    3. Retry the operation once. Return the retry result if it produced
       a non-error JSON payload. Otherwise return the ``needs_reauth``
       error dict so the model stops hallucinating manual refresh.
    4. Return None if ``exc`` is not an auth error, signalling the
       caller to use the generic error path.

    Args:
        server_name: Name of the MCP server that raised.
        exc: The exception from the failed tool call.
        retry_call: Zero-arg callable that re-runs the tool call, returning
            the same JSON string format as the handler.
        op_description: Human-readable name of the operation (for logs).

    Returns:
        A JSON string if auth recovery was attempted, or None to fall
        through to the caller's generic error path.
    """
    if not _is_auth_error(exc):
        return None

    from tools.mcp_oauth_manager import get_manager
    manager = get_manager()

    async def _recover():
        return await manager.handle_401(server_name, None)

    try:
        recovered = _run_on_mcp_loop(_recover(), timeout=10)
    except Exception as rec_exc:
        logger.warning(
            "MCP OAuth '%s': recovery attempt failed: %s",
            server_name, rec_exc,
        )
        recovered = False

    if recovered:
        with _lock:
            srv = _servers.get(server_name)
        if srv is not None and hasattr(srv, "_reconnect_event"):
            loop = _mcp_loop
            if loop is not None and loop.is_running():
                loop.call_soon_threadsafe(srv._reconnect_event.set)
                # Wait briefly for the session to come back ready. Bounded
                # so that a stuck reconnect falls through to the error
                # path rather than hanging the caller.
                deadline = time.monotonic() + 15
                while time.monotonic() < deadline:
                    if srv.session is not None and srv._ready.is_set():
                        break
                    time.sleep(0.25)

        # A successful OAuth recovery is independent evidence that the
        # server is viable again, so close the circuit breaker here —
        # not only on retry success. Without this, a reconnect
        # followed by a failing retry would leave the breaker pinned
        # above threshold forever (the retry-exception branch below
        # bumps the count again). The post-reset retry still goes
        # through _bump_server_error on failure, so a genuinely broken
        # server will re-trip the breaker as normal.
        _reset_server_error(server_name)

        try:
            result = retry_call()
            try:
                parsed = json.loads(result)
                if "error" not in parsed:
                    _reset_server_error(server_name)
                    return result
            except (json.JSONDecodeError, TypeError):
                # Non-JSON payload counts as success.
                _reset_server_error(server_name)
                return result
        except Exception as retry_exc:
            logger.warning(
                "MCP %s/%s retry after auth recovery failed: %s",
                server_name, op_description, retry_exc,
            )

    # No recovery available, or retry also failed: surface a structured
    # needs_reauth error. Bumps the circuit breaker so the model stops
    # retrying the tool.
    _bump_server_error(server_name)
    return json.dumps({
        "error": (
            f"MCP server '{server_name}' requires re-authentication. "
            f"Run `hermes mcp login {server_name}` (or delete the tokens "
            f"file under ~/.hermes/mcp-tokens/ and restart). Do NOT retry "
            f"this tool — ask the user to re-authenticate."
        ),
        "needs_reauth": True,
        "server": server_name,
    }, ensure_ascii=False)
|
|
|
|
# Dedicated event loop running in a background daemon thread. All MCP
# transport work is scheduled onto this loop via _run_on_mcp_loop().
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
_mcp_thread: Optional[threading.Thread] = None

# Protects _mcp_loop, _mcp_thread, _servers, and _stdio_pids.
_lock = threading.Lock()

# PIDs of stdio MCP server subprocesses. Tracked so we can force-kill
# them on shutdown if the graceful cleanup (SDK context-manager teardown)
# fails or times out. PIDs are added after connection and removed on
# normal server shutdown.
_stdio_pids: Dict[int, str] = {}  # pid -> server_name
|
|
|
|
|
|
def _snapshot_child_pids() -> set:
|
|
"""Return a set of current child process PIDs.
|
|
|
|
Uses /proc on Linux, falls back to psutil, then empty set.
|
|
Used by _run_stdio to identify the subprocess spawned by stdio_client.
|
|
"""
|
|
my_pid = os.getpid()
|
|
|
|
# Linux: read from /proc
|
|
try:
|
|
children_path = f"/proc/{my_pid}/task/{my_pid}/children"
|
|
with open(children_path) as f:
|
|
return {int(p) for p in f.read().split() if p.strip()}
|
|
except (FileNotFoundError, OSError, ValueError):
|
|
pass
|
|
|
|
# Fallback: psutil
|
|
try:
|
|
import psutil
|
|
return {c.pid for c in psutil.Process(my_pid).children()}
|
|
except Exception:
|
|
pass
|
|
|
|
return set()
|
|
|
|
|
|
def _mcp_loop_exception_handler(loop, context):
|
|
"""Suppress benign 'Event loop is closed' noise during shutdown.
|
|
|
|
When the MCP event loop is stopped and closed, httpx/httpcore async
|
|
transports may fire __del__ finalizers that call call_soon() on the
|
|
dead loop. asyncio catches that RuntimeError and routes it here.
|
|
We silence it because the connection is being torn down anyway; all
|
|
other exceptions are forwarded to the default handler.
|
|
"""
|
|
exc = context.get("exception")
|
|
if isinstance(exc, RuntimeError) and "Event loop is closed" in str(exc):
|
|
return # benign shutdown race — suppress
|
|
loop.default_exception_handler(context)
|
|
|
|
|
|
def _ensure_mcp_loop():
    """Start the background event loop thread if not already running.

    Idempotent: holds _lock so concurrent callers cannot create two
    loops. A loop that exists but has stopped (is_running() False) is
    replaced with a fresh one.
    """
    global _mcp_loop, _mcp_thread
    with _lock:
        if _mcp_loop is not None and _mcp_loop.is_running():
            return
        _mcp_loop = asyncio.new_event_loop()
        _mcp_loop.set_exception_handler(_mcp_loop_exception_handler)
        # Daemon thread: never blocks interpreter exit.
        _mcp_thread = threading.Thread(
            target=_mcp_loop.run_forever,
            name="mcp-event-loop",
            daemon=True,
        )
        _mcp_thread.start()
|
|
|
|
|
|
def _run_on_mcp_loop(coro, timeout: float = 30):
    """Schedule a coroutine on the MCP event loop and block until done.

    Poll in short intervals so the calling agent thread can honor user
    interrupts while the MCP work is still running on the background loop.

    Args:
        coro: Coroutine to run on the background MCP loop.
        timeout: Overall deadline in seconds; None waits forever.

    Raises:
        RuntimeError: if the MCP loop is not running.
        InterruptedError: if the user interrupted while waiting (the
            pending future is cancelled first).
        concurrent.futures.TimeoutError: if the deadline expires.
    """
    # Local import avoids a circular dependency at module load time.
    from tools.interrupt import is_interrupted

    with _lock:
        loop = _mcp_loop
    if loop is None or not loop.is_running():
        raise RuntimeError("MCP event loop is not running")
    future = asyncio.run_coroutine_threadsafe(coro, loop)
    deadline = None if timeout is None else time.monotonic() + timeout

    while True:
        if is_interrupted():
            future.cancel()
            raise InterruptedError("User sent a new message")

        # Poll in <=0.1s slices so interrupts stay responsive.
        wait_timeout = 0.1
        if deadline is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                # Deadline hit: final non-blocking check; raises
                # concurrent.futures.TimeoutError if still pending.
                return future.result(timeout=0)
            wait_timeout = min(wait_timeout, remaining)

        try:
            return future.result(timeout=wait_timeout)
        except concurrent.futures.TimeoutError:
            continue
|
|
|
|
|
|
def _interrupted_call_result() -> str:
|
|
"""Standardized JSON error for a user-interrupted MCP tool call."""
|
|
return json.dumps({
|
|
"error": "MCP call interrupted: user sent a new message"
|
|
}, ensure_ascii=False)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config loading
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _interpolate_env_vars(value):
|
|
"""Recursively resolve ``${VAR}`` placeholders from ``os.environ``."""
|
|
if isinstance(value, str):
|
|
def _replace(m):
|
|
return os.environ.get(m.group(1), m.group(0))
|
|
return re.sub(r"\$\{([^}]+)\}", _replace, value)
|
|
if isinstance(value, dict):
|
|
return {k: _interpolate_env_vars(v) for k, v in value.items()}
|
|
if isinstance(value, list):
|
|
return [_interpolate_env_vars(v) for v in value]
|
|
return value
|
|
|
|
|
|
def _load_mcp_config() -> Dict[str, dict]:
    """Read ``mcp_servers`` from the Hermes config file.

    Returns a dict of ``{server_name: server_config}`` or empty dict.
    Server config can contain either ``command``/``args``/``env`` for stdio
    transport or ``url``/``headers`` for HTTP transport, plus optional
    ``timeout``, ``connect_timeout``, and ``auth`` overrides.

    ``${ENV_VAR}`` placeholders in string values are resolved from
    ``os.environ`` (which includes ``~/.hermes/.env`` loaded at startup).

    Best-effort: any failure (missing config module, malformed file)
    returns {} rather than raising, so MCP simply stays disabled.
    """
    try:
        from hermes_cli.config import load_config
        config = load_config()
        servers = config.get("mcp_servers")
        if not servers or not isinstance(servers, dict):
            return {}
        # Ensure .env vars are available for interpolation
        try:
            from hermes_cli.env_loader import load_hermes_dotenv
            load_hermes_dotenv()
        except Exception:
            pass
        return {name: _interpolate_env_vars(cfg) for name, cfg in servers.items()}
    except Exception as exc:
        logger.debug("Failed to load MCP config: %s", exc)
        return {}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Server connection helper
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def _connect_server(name: str, config: dict) -> MCPServerTask:
    """Create an MCPServerTask, start it, and return when ready.

    The server Task keeps the connection alive in the background.
    Call ``server.shutdown()`` (on the same event loop) to tear it down.

    Raises:
        ValueError: if required config keys are missing.
        ImportError: if HTTP transport is needed but not available.
        Exception: on connection or initialization failure.
    """
    task = MCPServerTask(name)
    await task.start(config)
    return task
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Handler / check-fn factories
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
    """Return a sync handler that calls an MCP tool via the background loop.

    The handler conforms to the registry's dispatch interface:
    ``handler(args_dict, **kwargs) -> str``

    Args:
        server_name: Logical server name; keys the live session lookup and
            the per-server circuit-breaker counters.
        tool_name: Remote MCP tool name passed to ``tools/call``.
        tool_timeout: Per-call timeout in seconds for the MCP event loop.

    Returns:
        A sync callable returning a JSON string — ``{"result": ...}`` on
        success, ``{"error": ...}`` on any failure (never raises to the
        registry dispatcher except via the paths below).
    """

    def _handler(args: dict, **kwargs) -> str:
        # Circuit breaker: if this server has failed too many times
        # consecutively, short-circuit with a clear message so the model
        # stops retrying and uses alternative approaches (#10447).
        #
        # Once the cooldown elapses, the breaker transitions to
        # half-open: we let the *next* call through as a probe. On
        # success the success-path below resets the breaker; on
        # failure the error paths below bump the count again, which
        # re-stamps the open-time via _bump_server_error (re-arming
        # the cooldown).
        if _server_error_counts.get(server_name, 0) >= _CIRCUIT_BREAKER_THRESHOLD:
            opened_at = _server_breaker_opened_at.get(server_name, 0.0)
            age = time.monotonic() - opened_at
            if age < _CIRCUIT_BREAKER_COOLDOWN_SEC:
                remaining = max(1, int(_CIRCUIT_BREAKER_COOLDOWN_SEC - age))
                return json.dumps({
                    "error": (
                        f"MCP server '{server_name}' is unreachable after "
                        f"{_server_error_counts[server_name]} consecutive "
                        f"failures. Auto-retry available in ~{remaining}s. "
                        f"Do NOT retry this tool yet — use alternative "
                        f"approaches or ask the user to check the MCP server."
                    )
                }, ensure_ascii=False)
            # Cooldown elapsed → fall through as a half-open probe.

        # Snapshot the server under the lock; the async call itself runs
        # outside the lock on the MCP loop.
        with _lock:
            server = _servers.get(server_name)
            if not server or not server.session:
                _bump_server_error(server_name)
                return json.dumps({
                    "error": f"MCP server '{server_name}' is not connected"
                }, ensure_ascii=False)

        async def _call():
            result = await server.session.call_tool(tool_name, arguments=args)
            # MCP CallToolResult has .content (list of content blocks) and .isError
            if result.isError:
                error_text = ""
                for block in (result.content or []):
                    if hasattr(block, "text"):
                        error_text += block.text
                return json.dumps({
                    "error": _sanitize_error(
                        error_text or "MCP tool returned an error"
                    )
                }, ensure_ascii=False)

            # Collect text from content blocks
            parts: List[str] = []
            for block in (result.content or []):
                if hasattr(block, "text"):
                    parts.append(block.text)
            text_result = "\n".join(parts) if parts else ""

            # Combine content + structuredContent when both are present.
            # MCP spec: content is model-oriented (text), structuredContent
            # is machine-oriented (JSON metadata). For an AI agent, content
            # is the primary payload; structuredContent supplements it.
            structured = getattr(result, "structuredContent", None)
            if structured is not None:
                if text_result:
                    return json.dumps({
                        "result": text_result,
                        "structuredContent": structured,
                    }, ensure_ascii=False)
                return json.dumps({"result": structured}, ensure_ascii=False)
            return json.dumps({"result": text_result}, ensure_ascii=False)

        def _call_once():
            # Bridge the coroutine onto the dedicated MCP event loop thread.
            return _run_on_mcp_loop(_call(), timeout=tool_timeout)

        try:
            result = _call_once()
            # Check if the MCP tool itself returned an error; tool-level
            # errors also count against the circuit breaker.
            try:
                parsed = json.loads(result)
                if "error" in parsed:
                    _bump_server_error(server_name)
                else:
                    _reset_server_error(server_name)  # success — reset
            except (json.JSONDecodeError, TypeError):
                _reset_server_error(server_name)  # non-JSON = success
            return result
        except InterruptedError:
            return _interrupted_call_result()
        except Exception as exc:
            # Auth-specific recovery path: consult the manager, signal
            # reconnect if viable, retry once. Returns None to fall
            # through for non-auth exceptions.
            recovered = _handle_auth_error_and_retry(
                server_name, exc, _call_once,
                f"tools/call {tool_name}",
            )
            if recovered is not None:
                return recovered

            _bump_server_error(server_name)
            logger.error(
                "MCP tool %s/%s call failed: %s",
                server_name, tool_name, exc,
            )
            return json.dumps({
                "error": _sanitize_error(
                    f"MCP call failed: {type(exc).__name__}: {exc}"
                )
            }, ensure_ascii=False)

    return _handler
|
|
|
|
|
|
def _make_list_resources_handler(server_name: str, tool_timeout: float):
    """Return a sync handler that lists resources from an MCP server.

    The handler yields a JSON string of the form ``{"resources": [...]}``
    on success or ``{"error": ...}`` on failure.
    """

    def _handler(args: dict, **kwargs) -> str:
        with _lock:
            server = _servers.get(server_name)
            if not server or not server.session:
                return json.dumps(
                    {"error": f"MCP server '{server_name}' is not connected"},
                    ensure_ascii=False,
                )

        # Sentinel distinguishing "attribute absent" from "attribute is None".
        _missing = object()

        async def _call():
            listing = await server.session.list_resources()
            collected = []
            for res in getattr(listing, "resources", []):
                item = {}
                uri = getattr(res, "uri", _missing)
                if uri is not _missing:
                    item["uri"] = str(uri)
                rname = getattr(res, "name", _missing)
                if rname is not _missing:
                    item["name"] = rname
                if getattr(res, "description", None):
                    item["description"] = res.description
                if getattr(res, "mimeType", None):
                    item["mimeType"] = res.mimeType
                collected.append(item)
            return json.dumps({"resources": collected}, ensure_ascii=False)

        def _call_once():
            return _run_on_mcp_loop(_call(), timeout=tool_timeout)

        try:
            return _call_once()
        except InterruptedError:
            return _interrupted_call_result()
        except Exception as exc:
            recovered = _handle_auth_error_and_retry(
                server_name, exc, _call_once, "resources/list",
            )
            if recovered is not None:
                return recovered
            logger.error(
                "MCP %s/list_resources failed: %s", server_name, exc,
            )
            return json.dumps(
                {
                    "error": _sanitize_error(
                        f"MCP call failed: {type(exc).__name__}: {exc}"
                    )
                },
                ensure_ascii=False,
            )

    return _handler
|
|
|
|
|
|
def _make_read_resource_handler(server_name: str, tool_timeout: float):
    """Return a sync handler that reads a resource by URI from an MCP server.

    Produces ``{"result": <joined text>}``; binary blocks are summarised as
    a size placeholder instead of being returned raw.
    """

    def _handler(args: dict, **kwargs) -> str:
        from tools.registry import tool_error

        with _lock:
            server = _servers.get(server_name)
            if not server or not server.session:
                return json.dumps(
                    {"error": f"MCP server '{server_name}' is not connected"},
                    ensure_ascii=False,
                )

        uri = args.get("uri")
        if not uri:
            return tool_error("Missing required parameter 'uri'")

        async def _call():
            response = await server.session.read_resource(uri)
            # ReadResourceResult carries a .contents list of text/blob blocks.
            pieces: List[str] = []
            for block in getattr(response, "contents", []):
                if hasattr(block, "text"):
                    pieces.append(block.text)
                elif hasattr(block, "blob"):
                    pieces.append(f"[binary data, {len(block.blob)} bytes]")
            return json.dumps(
                {"result": "\n".join(pieces) if pieces else ""},
                ensure_ascii=False,
            )

        def _call_once():
            return _run_on_mcp_loop(_call(), timeout=tool_timeout)

        try:
            return _call_once()
        except InterruptedError:
            return _interrupted_call_result()
        except Exception as exc:
            recovered = _handle_auth_error_and_retry(
                server_name, exc, _call_once, "resources/read",
            )
            if recovered is not None:
                return recovered
            logger.error(
                "MCP %s/read_resource failed: %s", server_name, exc,
            )
            return json.dumps(
                {
                    "error": _sanitize_error(
                        f"MCP call failed: {type(exc).__name__}: {exc}"
                    )
                },
                ensure_ascii=False,
            )

    return _handler
|
|
|
|
|
|
def _make_list_prompts_handler(server_name: str, tool_timeout: float):
    """Return a sync handler that lists prompts from an MCP server.

    Produces ``{"prompts": [...]}`` where each entry carries the prompt
    name plus optional description and argument descriptors.
    """

    def _handler(args: dict, **kwargs) -> str:
        with _lock:
            server = _servers.get(server_name)
            if not server or not server.session:
                return json.dumps(
                    {"error": f"MCP server '{server_name}' is not connected"},
                    ensure_ascii=False,
                )

        # Sentinel distinguishing "attribute absent" from "attribute is None".
        _missing = object()

        def _describe_argument(arg):
            # Key order matters for output stability: name, description, required.
            described = {"name": arg.name}
            if getattr(arg, "description", None):
                described["description"] = arg.description
            required = getattr(arg, "required", _missing)
            if required is not _missing:
                described["required"] = required
            return described

        async def _call():
            listing = await server.session.list_prompts()
            collected = []
            for prompt in getattr(listing, "prompts", []):
                item = {}
                pname = getattr(prompt, "name", _missing)
                if pname is not _missing:
                    item["name"] = pname
                if getattr(prompt, "description", None):
                    item["description"] = prompt.description
                if getattr(prompt, "arguments", None):
                    item["arguments"] = [
                        _describe_argument(a) for a in prompt.arguments
                    ]
                collected.append(item)
            return json.dumps({"prompts": collected}, ensure_ascii=False)

        def _call_once():
            return _run_on_mcp_loop(_call(), timeout=tool_timeout)

        try:
            return _call_once()
        except InterruptedError:
            return _interrupted_call_result()
        except Exception as exc:
            recovered = _handle_auth_error_and_retry(
                server_name, exc, _call_once, "prompts/list",
            )
            if recovered is not None:
                return recovered
            logger.error(
                "MCP %s/list_prompts failed: %s", server_name, exc,
            )
            return json.dumps(
                {
                    "error": _sanitize_error(
                        f"MCP call failed: {type(exc).__name__}: {exc}"
                    )
                },
                ensure_ascii=False,
            )

    return _handler
|
|
|
|
|
|
def _make_get_prompt_handler(server_name: str, tool_timeout: float):
    """Return a sync handler that gets a prompt by name from an MCP server.

    Produces ``{"messages": [...]}`` (plus an optional ``description``
    key) on success or ``{"error": ...}`` on failure.
    """

    def _handler(args: dict, **kwargs) -> str:
        from tools.registry import tool_error

        with _lock:
            server = _servers.get(server_name)
            if not server or not server.session:
                return json.dumps(
                    {"error": f"MCP server '{server_name}' is not connected"},
                    ensure_ascii=False,
                )

        name = args.get("name")
        if not name:
            return tool_error("Missing required parameter 'name'")
        arguments = args.get("arguments", {})

        def _coerce_content(content):
            # Prefer a .text attribute, then plain strings, then str() fallback.
            if hasattr(content, "text"):
                return content.text
            if isinstance(content, str):
                return content
            return str(content)

        async def _call():
            response = await server.session.get_prompt(name, arguments=arguments)
            # GetPromptResult carries a .messages list.
            rendered = []
            for msg in getattr(response, "messages", []):
                item = {}
                if hasattr(msg, "role"):
                    item["role"] = msg.role
                if hasattr(msg, "content"):
                    item["content"] = _coerce_content(msg.content)
                rendered.append(item)
            payload = {"messages": rendered}
            if getattr(response, "description", None):
                payload["description"] = response.description
            return json.dumps(payload, ensure_ascii=False)

        def _call_once():
            return _run_on_mcp_loop(_call(), timeout=tool_timeout)

        try:
            return _call_once()
        except InterruptedError:
            return _interrupted_call_result()
        except Exception as exc:
            recovered = _handle_auth_error_and_retry(
                server_name, exc, _call_once, "prompts/get",
            )
            if recovered is not None:
                return recovered
            logger.error(
                "MCP %s/get_prompt failed: %s", server_name, exc,
            )
            return json.dumps(
                {
                    "error": _sanitize_error(
                        f"MCP call failed: {type(exc).__name__}: {exc}"
                    )
                },
                ensure_ascii=False,
            )

    return _handler
|
|
|
|
|
|
def _make_check_fn(server_name: str):
    """Return a check function that verifies the MCP connection is alive."""

    def _check() -> bool:
        with _lock:
            entry = _servers.get(server_name)
            if entry is None:
                return False
            return entry.session is not None

    return _check
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Discovery & registration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _normalize_mcp_input_schema(schema: dict | None) -> dict:
|
|
"""Normalize MCP input schemas for LLM tool-calling compatibility.
|
|
|
|
MCP servers can emit plain JSON Schema with ``definitions`` /
|
|
``#/definitions/...`` references. Kimi / Moonshot rejects that form and
|
|
requires local refs to point into ``#/$defs/...`` instead. Normalize the
|
|
common draft-07 shape here so MCP tool schemas remain portable across
|
|
OpenAI-compatible providers.
|
|
|
|
Additional MCP-server robustness repairs applied recursively:
|
|
|
|
* Missing or ``null`` ``type`` on an object-shaped node is coerced to
|
|
``"object"`` (some servers omit it). See PR #4897.
|
|
* When an ``object`` node lacks ``properties``, an empty ``properties``
|
|
dict is added so ``required`` entries don't dangle.
|
|
* ``required`` arrays are pruned to only names that exist in
|
|
``properties``; otherwise Google AI Studio / Gemini 400s with
|
|
``property is not defined``. See PR #4651.
|
|
|
|
All repairs are provider-agnostic and ideally produce a schema valid on
|
|
OpenAI, Anthropic, Gemini, and Moonshot in one pass.
|
|
"""
|
|
if not schema:
|
|
return {"type": "object", "properties": {}}
|
|
|
|
def _rewrite_local_refs(node):
|
|
if isinstance(node, dict):
|
|
normalized = {}
|
|
for key, value in node.items():
|
|
out_key = "$defs" if key == "definitions" else key
|
|
normalized[out_key] = _rewrite_local_refs(value)
|
|
ref = normalized.get("$ref")
|
|
if isinstance(ref, str) and ref.startswith("#/definitions/"):
|
|
normalized["$ref"] = "#/$defs/" + ref[len("#/definitions/"):]
|
|
return normalized
|
|
if isinstance(node, list):
|
|
return [_rewrite_local_refs(item) for item in node]
|
|
return node
|
|
|
|
def _repair_object_shape(node):
|
|
"""Recursively repair object-shaped nodes: fill type, prune required."""
|
|
if isinstance(node, list):
|
|
return [_repair_object_shape(item) for item in node]
|
|
if not isinstance(node, dict):
|
|
return node
|
|
|
|
repaired = {k: _repair_object_shape(v) for k, v in node.items()}
|
|
|
|
# Coerce missing / null type when the shape is clearly an object
|
|
# (has properties or required but no type).
|
|
if not repaired.get("type") and (
|
|
"properties" in repaired or "required" in repaired
|
|
):
|
|
repaired["type"] = "object"
|
|
|
|
if repaired.get("type") == "object":
|
|
# Ensure properties exists so required can reference it safely
|
|
if "properties" not in repaired or not isinstance(
|
|
repaired.get("properties"), dict
|
|
):
|
|
repaired["properties"] = {} if "properties" not in repaired else repaired["properties"]
|
|
if not isinstance(repaired.get("properties"), dict):
|
|
repaired["properties"] = {}
|
|
|
|
# Prune required to only include names that exist in properties
|
|
required = repaired.get("required")
|
|
if isinstance(required, list):
|
|
props = repaired.get("properties") or {}
|
|
valid = [r for r in required if isinstance(r, str) and r in props]
|
|
if len(valid) != len(required):
|
|
if valid:
|
|
repaired["required"] = valid
|
|
else:
|
|
repaired.pop("required", None)
|
|
|
|
return repaired
|
|
|
|
normalized = _rewrite_local_refs(schema)
|
|
normalized = _repair_object_shape(normalized)
|
|
|
|
# Ensure top-level is a well-formed object schema
|
|
if not isinstance(normalized, dict):
|
|
return {"type": "object", "properties": {}}
|
|
if normalized.get("type") == "object" and "properties" not in normalized:
|
|
normalized = {**normalized, "properties": {}}
|
|
|
|
return normalized
|
|
|
|
|
|
def sanitize_mcp_name_component(value: str) -> str:
    """Return an MCP name component safe for tool and prefix generation.

    Preserves Hermes's historical behavior of converting hyphens to
    underscores, and also replaces any other character outside
    ``[A-Za-z0-9_]`` with ``_`` so generated tool names are compatible with
    provider validation rules. ``None`` and other falsy inputs yield ``""``.
    """
    text = str(value) if value else ""
    return re.sub(r"[^A-Za-z0-9_]", "_", text)
|
|
|
|
|
|
def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
    """Convert an MCP tool listing to the Hermes registry schema format.

    Args:
        server_name: The logical server name for prefixing.
        mcp_tool: An MCP ``Tool`` object with ``.name``, ``.description``,
            and (optionally) ``.inputSchema``.

    Returns:
        A dict suitable for ``registry.register(schema=...)``.
    """
    server_part = sanitize_mcp_name_component(server_name)
    tool_part = sanitize_mcp_name_component(mcp_tool.name)
    # Some MCP servers omit inputSchema entirely; getattr keeps registration
    # alive instead of raising AttributeError.
    raw_schema = getattr(mcp_tool, "inputSchema", None)
    description = mcp_tool.description or f"MCP tool {mcp_tool.name} from {server_name}"
    return {
        "name": f"mcp_{server_part}_{tool_part}",
        "description": description,
        "parameters": _normalize_mcp_input_schema(raw_schema),
    }
|
|
|
|
|
|
def _build_utility_schemas(server_name: str) -> List[dict]:
    """Build schemas for the MCP utility tools (resources & prompts).

    Returns a list of dicts with keys ``schema`` (a registry schema) and
    ``handler_key`` (which handler factory serves that utility).
    """
    safe_name = sanitize_mcp_name_component(server_name)

    def _entry(suffix: str, description: str, parameters: dict) -> dict:
        # handler_key deliberately equals the name suffix for every utility.
        return {
            "schema": {
                "name": f"mcp_{safe_name}_{suffix}",
                "description": description,
                "parameters": parameters,
            },
            "handler_key": suffix,
        }

    return [
        _entry(
            "list_resources",
            f"List available resources from MCP server '{server_name}'",
            {"type": "object", "properties": {}},
        ),
        _entry(
            "read_resource",
            f"Read a resource by URI from MCP server '{server_name}'",
            {
                "type": "object",
                "properties": {
                    "uri": {
                        "type": "string",
                        "description": "URI of the resource to read",
                    },
                },
                "required": ["uri"],
            },
        ),
        _entry(
            "list_prompts",
            f"List available prompts from MCP server '{server_name}'",
            {"type": "object", "properties": {}},
        ),
        _entry(
            "get_prompt",
            f"Get a prompt by name from MCP server '{server_name}'",
            {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string",
                        "description": "Name of the prompt to retrieve",
                    },
                    "arguments": {
                        "type": "object",
                        "description": "Optional arguments to pass to the prompt",
                    },
                },
                "required": ["name"],
            },
        ),
    ]
|
|
|
|
|
|
def _normalize_name_filter(value: Any, label: str) -> set[str]:
|
|
"""Normalize include/exclude config to a set of tool names."""
|
|
if value is None:
|
|
return set()
|
|
if isinstance(value, str):
|
|
return {value}
|
|
if isinstance(value, (list, tuple, set)):
|
|
return {str(item) for item in value}
|
|
logger.warning("MCP config %s must be a string or list of strings; ignoring %r", label, value)
|
|
return set()
|
|
|
|
|
|
def _parse_boolish(value: Any, default: bool = True) -> bool:
|
|
"""Parse a bool-like config value with safe fallback."""
|
|
if value is None:
|
|
return default
|
|
if isinstance(value, bool):
|
|
return value
|
|
if isinstance(value, str):
|
|
lowered = value.strip().lower()
|
|
if lowered in {"true", "1", "yes", "on"}:
|
|
return True
|
|
if lowered in {"false", "0", "no", "off"}:
|
|
return False
|
|
logger.warning("MCP config expected a boolean-ish value, got %r; using default=%s", value, default)
|
|
return default
|
|
|
|
|
|
# Maps each utility tool's handler key to the session method it requires.
# _select_utility_schemas skips any utility whose method is absent on the
# connected server's session object.
_UTILITY_CAPABILITY_METHODS = {
    "list_resources": "list_resources",
    "read_resource": "read_resource",
    "list_prompts": "list_prompts",
    "get_prompt": "get_prompt",
}
|
|
|
|
|
|
def _select_utility_schemas(server_name: str, server: MCPServerTask, config: dict) -> List[dict]:
    """Select utility schemas based on config and server capabilities.

    Filters the full utility list by the ``tools.resources`` /
    ``tools.prompts`` config switches, then by whether the live session
    actually exposes the required method.
    """
    filters = config.get("tools") or {}
    allow_resources = _parse_boolish(filters.get("resources"), default=True)
    allow_prompts = _parse_boolish(filters.get("prompts"), default=True)

    resource_keys = {"list_resources", "read_resource"}
    prompt_keys = {"list_prompts", "get_prompt"}

    chosen: List[dict] = []
    for entry in _build_utility_schemas(server_name):
        key = entry["handler_key"]
        if key in resource_keys and not allow_resources:
            logger.debug("MCP server '%s': skipping utility '%s' (resources disabled)", server_name, key)
            continue
        if key in prompt_keys and not allow_prompts:
            logger.debug("MCP server '%s': skipping utility '%s' (prompts disabled)", server_name, key)
            continue

        method = _UTILITY_CAPABILITY_METHODS[key]
        if not hasattr(server.session, method):
            logger.debug(
                "MCP server '%s': skipping utility '%s' (session lacks %s)",
                server_name,
                key,
                method,
            )
            continue
        chosen.append(entry)
    return chosen
|
|
|
|
|
|
def _existing_tool_names() -> List[str]:
    """Return tool names for all currently connected servers."""
    collected: List[str] = []
    for server in _servers.values():
        # Prefer the exact names recorded at registration time; fall back
        # to re-deriving them from the discovered tool list.
        if hasattr(server, "_registered_tool_names"):
            collected.extend(server._registered_tool_names)
            continue
        for tool in server._tools:
            collected.append(_convert_mcp_schema(server.name, tool)["name"])
    return collected
|
|
|
|
|
|
def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> List[str]:
    """Register tools from an already-connected server into the registry.

    Handles include/exclude filtering and utility tools. Toolset resolution
    for ``mcp-{server}`` and raw server-name aliases is derived from the live
    registry, rather than mutating ``toolsets.TOOLSETS`` at runtime.

    Used by both initial discovery and dynamic refresh (list_changed).

    Args:
        name: Logical server name (used for prefixing, toolset naming,
            and log messages).
        server: A connected ``MCPServerTask`` whose ``_tools`` have been
            discovered.
        config: The server's raw config dict; its ``tools`` sub-dict is
            read for include/exclude filters and utility toggles.

    Returns:
        List of registered prefixed tool names.
    """
    from tools.registry import registry

    registered_names: List[str] = []
    toolset_name = f"mcp-{name}"

    # Selective tool loading: honour include/exclude lists from config.
    # Rules (matching issue #690 spec):
    #   tools.include — whitelist: only these tool names are registered
    #   tools.exclude — blacklist: all tools EXCEPT these are registered
    #   include takes precedence over exclude
    #   Neither set → register all tools (backward-compatible default)
    tools_filter = config.get("tools") or {}
    include_set = _normalize_name_filter(tools_filter.get("include"), f"mcp_servers.{name}.tools.include")
    exclude_set = _normalize_name_filter(tools_filter.get("exclude"), f"mcp_servers.{name}.tools.exclude")

    def _should_register(tool_name: str) -> bool:
        if include_set:
            return tool_name in include_set
        if exclude_set:
            return tool_name not in exclude_set
        return True

    for mcp_tool in server._tools:
        if not _should_register(mcp_tool.name):
            logger.debug("MCP server '%s': skipping tool '%s' (filtered by config)", name, mcp_tool.name)
            continue

        # Scan tool description for prompt injection patterns
        _scan_mcp_description(name, mcp_tool.name, mcp_tool.description or "")

        schema = _convert_mcp_schema(name, mcp_tool)
        tool_name_prefixed = schema["name"]

        # Guard against collisions with built-in (non-MCP) tools.
        existing_toolset = registry.get_toolset_for_tool(tool_name_prefixed)
        if existing_toolset and not existing_toolset.startswith("mcp-"):
            logger.warning(
                "MCP server '%s': tool '%s' (→ '%s') collides with built-in "
                "tool in toolset '%s' — skipping to preserve built-in",
                name, mcp_tool.name, tool_name_prefixed, existing_toolset,
            )
            continue

        registry.register(
            name=tool_name_prefixed,
            toolset=toolset_name,
            schema=schema,
            handler=_make_tool_handler(name, mcp_tool.name, server.tool_timeout),
            check_fn=_make_check_fn(name),
            is_async=False,
            description=schema["description"],
        )
        registered_names.append(tool_name_prefixed)

    # Register MCP Resources & Prompts utility tools, filtered by config and
    # only when the server actually supports the corresponding capability.
    _handler_factories = {
        "list_resources": _make_list_resources_handler,
        "read_resource": _make_read_resource_handler,
        "list_prompts": _make_list_prompts_handler,
        "get_prompt": _make_get_prompt_handler,
    }
    check_fn = _make_check_fn(name)
    for entry in _select_utility_schemas(name, server, config):
        schema = entry["schema"]
        handler_key = entry["handler_key"]
        handler = _handler_factories[handler_key](name, server.tool_timeout)
        util_name = schema["name"]

        # Same collision guard for utility tools.
        existing_toolset = registry.get_toolset_for_tool(util_name)
        if existing_toolset and not existing_toolset.startswith("mcp-"):
            logger.warning(
                "MCP server '%s': utility tool '%s' collides with built-in "
                "tool in toolset '%s' — skipping to preserve built-in",
                name, util_name, existing_toolset,
            )
            continue

        registry.register(
            name=util_name,
            toolset=toolset_name,
            schema=schema,
            handler=handler,
            check_fn=check_fn,
            is_async=False,
            description=schema["description"],
        )
        registered_names.append(util_name)

    # Expose the raw server name as an alias for the mcp-{name} toolset,
    # but only when something was actually registered.
    if registered_names:
        registry.register_toolset_alias(name, toolset_name)

    return registered_names
|
|
|
|
|
|
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
    """Connect to a single MCP server, discover tools, and register them.

    Returns list of registered tool names.
    """
    timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
    server = await asyncio.wait_for(_connect_server(name, config), timeout=timeout)

    with _lock:
        _servers[name] = server

    registered = _register_server_tools(name, server, config)
    # Remember exactly what was registered so status/refresh can report it.
    server._registered_tool_names = list(registered)

    kind = "HTTP" if "url" in config else "stdio"
    logger.info(
        "MCP server '%s' (%s): registered %d tool(s): %s",
        name, kind, len(registered),
        ", ".join(registered),
    )
    return registered
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Public API
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def register_mcp_servers(servers: Dict[str, dict]) -> List[str]:
    """Connect to explicit MCP servers and register their tools.

    Idempotent for already-connected server names. Servers with
    ``enabled: false`` are skipped without disconnecting existing sessions.

    Args:
        servers: Mapping of ``{server_name: server_config}``.

    Returns:
        List of all currently registered MCP tool names.
    """
    if not _MCP_AVAILABLE:
        logger.debug("MCP SDK not available -- skipping explicit MCP registration")
        return []
    if not servers:
        logger.debug("No explicit MCP servers provided")
        return []

    # Candidates: not yet connected AND enabled. enabled: false merely skips
    # the server; any existing session is left untouched.
    with _lock:
        pending = {
            sname: cfg
            for sname, cfg in servers.items()
            if sname not in _servers
            and _parse_boolish(cfg.get("enabled", True), default=True)
        }

    if not pending:
        return _existing_tool_names()

    # The background loop hosts all MCP connections/sessions.
    _ensure_mcp_loop()

    async def _discover_all():
        names = list(pending)
        # Connect to every pending server in PARALLEL; failures are isolated
        # per server via return_exceptions.
        outcomes = await asyncio.gather(
            *(_discover_and_register_server(sname, cfg) for sname, cfg in pending.items()),
            return_exceptions=True,
        )
        for sname, outcome in zip(names, outcomes):
            if isinstance(outcome, Exception):
                command = pending.get(sname, {}).get("command")
                logger.warning(
                    "Failed to connect to MCP server '%s'%s: %s",
                    sname,
                    f" (command={command})" if command else "",
                    _format_connect_error(outcome),
                )

    # Per-server timeouts are handled inside _discover_and_register_server.
    # The outer timeout is generous: 120s total for parallel discovery.
    _run_on_mcp_loop(_discover_all(), timeout=120)

    # Log a summary so ACP callers get visibility into what was registered.
    with _lock:
        connected = [sname for sname in pending if sname in _servers]
        new_tool_count = sum(
            len(getattr(_servers[sname], "_registered_tool_names", []))
            for sname in connected
        )
    failed = len(pending) - len(connected)
    if new_tool_count or failed:
        summary = f"MCP: registered {new_tool_count} tool(s) from {len(connected)} server(s)"
        if failed:
            summary += f" ({failed} failed)"
        logger.info(summary)

    return _existing_tool_names()
|
|
|
|
|
|
def discover_mcp_tools() -> List[str]:
    """Entry point: load config, connect to MCP servers, register tools.

    Called from ``model_tools`` after ``discover_builtin_tools()``. Safe to
    call even when the ``mcp`` package is not installed (returns empty list).

    Idempotent for already-connected servers. If some servers failed on a
    previous call, only the missing ones are retried.

    Returns:
        List of all registered MCP tool names.
    """
    if not _MCP_AVAILABLE:
        logger.debug("MCP SDK not available -- skipping MCP tool discovery")
        return []

    servers = _load_mcp_config()
    if not servers:
        logger.debug("No MCP servers configured")
        return []

    # Snapshot which enabled servers are genuinely new before delegating,
    # so the summary below reflects only this call's additions.
    with _lock:
        new_names = [
            sname
            for sname, cfg in servers.items()
            if sname not in _servers
            and _parse_boolish(cfg.get("enabled", True), default=True)
        ]

    tool_names = register_mcp_servers(servers)
    if not new_names:
        return tool_names

    with _lock:
        connected = [sname for sname in new_names if sname in _servers]
        new_tool_count = sum(
            len(getattr(_servers[sname], "_registered_tool_names", []))
            for sname in connected
        )

    failed_count = len(new_names) - len(connected)
    if new_tool_count or failed_count:
        # NOTE(review): leading space preserved from the original message —
        # presumably banner alignment; confirm before removing.
        summary = f" MCP: {new_tool_count} tool(s) from {len(connected)} server(s)"
        if failed_count:
            summary += f" ({failed_count} failed)"
        logger.info(summary)

    return tool_names
|
|
|
|
|
|
def get_mcp_status() -> List[dict]:
    """Return status of all configured MCP servers for banner display.

    Returns a list of dicts with keys: name, transport, tools, connected.
    Includes both successfully connected servers and configured-but-failed
    ones (the latter with ``tools=0, connected=False``).
    """
    statuses: List[dict] = []

    # Configured servers drive the listing; live state only fills details.
    configured = _load_mcp_config()
    if not configured:
        return statuses

    with _lock:
        snapshot = dict(_servers)

    for name, cfg in configured.items():
        transport = "http" if "url" in cfg else "stdio"
        server = snapshot.get(name)

        if not server or server.session is None:
            statuses.append({
                "name": name,
                "transport": transport,
                "tools": 0,
                "connected": False,
            })
            continue

        if hasattr(server, "_registered_tool_names"):
            tool_count = len(server._registered_tool_names)
        else:
            tool_count = len(server._tools)
        entry = {
            "name": name,
            "transport": transport,
            "tools": tool_count,
            "connected": True,
        }
        if server._sampling:
            entry["sampling"] = dict(server._sampling.metrics)
        statuses.append(entry)

    return statuses
|
|
|
|
|
|
def probe_mcp_server_tools() -> Dict[str, List[tuple]]:
    """Temporarily connect to configured MCP servers and list their tools.

    Designed for ``hermes tools`` interactive configuration — connects to each
    enabled server, grabs tool names and descriptions, then disconnects.
    Does NOT register tools in the Hermes registry.

    The probe is fully transient: every probed connection is shut down at
    the end of the probe coroutine, and the background MCP loop itself is
    stopped in the ``finally`` below (even on success).

    Returns:
        Dict mapping server name to list of (tool_name, description) tuples.
        Servers that fail to connect are omitted from the result.
    """
    if not _MCP_AVAILABLE:
        return {}

    servers_config = _load_mcp_config()
    if not servers_config:
        return {}

    # Only probe servers that are enabled in config.
    enabled = {
        k: v for k, v in servers_config.items()
        if _parse_boolish(v.get("enabled", True), default=True)
    }
    if not enabled:
        return {}

    _ensure_mcp_loop()

    # Both containers are closed over by _probe_all and filled on the loop
    # thread before _run_on_mcp_loop returns.
    result: Dict[str, List[tuple]] = {}
    probed_servers: List[MCPServerTask] = []

    async def _probe_all():
        names = list(enabled.keys())
        coros = []
        for name, cfg in enabled.items():
            ct = cfg.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
            coros.append(asyncio.wait_for(_connect_server(name, cfg), timeout=ct))

        # Connect in parallel; per-server failures don't abort the probe.
        outcomes = await asyncio.gather(*coros, return_exceptions=True)

        for name, outcome in zip(names, outcomes):
            if isinstance(outcome, Exception):
                logger.debug("Probe: failed to connect to '%s': %s", name, outcome)
                continue
            probed_servers.append(outcome)
            tools = []
            for t in outcome._tools:
                desc = getattr(t, "description", "") or ""
                tools.append((t.name, desc))
            result[name] = tools

        # Shut down all probed connections
        await asyncio.gather(
            *(s.shutdown() for s in probed_servers),
            return_exceptions=True,
        )

    try:
        _run_on_mcp_loop(_probe_all(), timeout=120)
    except Exception as exc:
        logger.debug("MCP probe failed: %s", exc)
    finally:
        # The probe owns the loop for its duration only.
        _stop_mcp_loop()

    return result
|
|
|
|
|
|
def shutdown_mcp_servers():
    """Close all MCP server connections and stop the background loop.

    Each server Task is signalled to exit its ``async with`` block so that
    the anyio cancel-scope cleanup happens in the same Task that opened it.
    All servers are shut down in parallel via ``asyncio.gather``.
    """
    with _lock:
        servers_snapshot = list(_servers.values())

    if not servers_snapshot:
        # Fast path: nothing to shut down.
        _stop_mcp_loop()
        return

    async def _shutdown():
        outcomes = await asyncio.gather(
            *(srv.shutdown() for srv in servers_snapshot),
            return_exceptions=True,
        )
        for srv, outcome in zip(servers_snapshot, outcomes):
            if isinstance(outcome, Exception):
                logger.debug(
                    "Error closing MCP server '%s': %s", srv.name, outcome,
                )
        with _lock:
            _servers.clear()

    # Read the loop reference under the lock, but do NOT hold the lock while
    # blocking on the shutdown future — _shutdown() itself acquires _lock on
    # the loop thread when clearing the registry.
    with _lock:
        loop = _mcp_loop
    if loop is not None and loop.is_running():
        try:
            future = asyncio.run_coroutine_threadsafe(_shutdown(), loop)
            future.result(timeout=15)
        except Exception as exc:
            logger.debug("Error during MCP shutdown: %s", exc)

    _stop_mcp_loop()
|
|
|
|
|
|
def _kill_orphaned_mcp_children() -> None:
    """Graceful shutdown of MCP stdio subprocesses that survived loop cleanup.

    Sends SIGTERM first, waits 2 seconds, then escalates to SIGKILL.
    This prevents shared-resource collisions when multiple hermes processes
    run on the same host (each has its own _stdio_pids dict).

    Only kills PIDs tracked in ``_stdio_pids`` — never arbitrary children.
    """
    import signal as _signal
    import time as _time

    with _lock:
        pids = dict(_stdio_pids)
        _stdio_pids.clear()

    # Fast path: nothing tracked. Skips the SIGTERM pass and — crucially —
    # the unconditional 2-second grace-period sleep that would otherwise
    # delay every shutdown even when no subprocess ever leaked.
    if not pids:
        return

    # Phase 1: SIGTERM (graceful)
    for pid, server_name in pids.items():
        try:
            os.kill(pid, _signal.SIGTERM)
            logger.debug("Sent SIGTERM to orphaned MCP process %d (%s)", pid, server_name)
        except (ProcessLookupError, PermissionError, OSError):
            pass

    # Phase 2: Wait for graceful exit
    _time.sleep(2)

    # Phase 3: SIGKILL any survivors (Windows has no SIGKILL; fall back to SIGTERM)
    _sigkill = getattr(_signal, "SIGKILL", _signal.SIGTERM)
    for pid, server_name in pids.items():
        try:
            os.kill(pid, 0)  # Check if still alive
            os.kill(pid, _sigkill)
            logger.warning(
                "Force-killed MCP process %d (%s) after SIGTERM timeout",
                pid, server_name,
            )
        except (ProcessLookupError, PermissionError, OSError):
            pass  # Good — exited after SIGTERM
|
|
|
|
|
|
def _stop_mcp_loop():
    """Stop the background event loop and join its thread."""
    global _mcp_loop, _mcp_thread

    # Swap the globals out under the lock; do the slow teardown (join/close)
    # after releasing it.
    with _lock:
        loop, thread = _mcp_loop, _mcp_thread
        _mcp_loop = None
        _mcp_thread = None

    if loop is not None:
        loop.call_soon_threadsafe(loop.stop)
        if thread is not None:
            thread.join(timeout=5)
        try:
            loop.close()
        except Exception:
            pass

    # After closing the loop, any stdio subprocesses that survived the
    # graceful shutdown are now orphaned. Force-kill them.
    _kill_orphaned_mcp_children()
|