feat: devex help, add Makefile, ruff, pre-commit, and modernize CI

This commit is contained in:
Brooklyn Nicholson 2026-03-09 20:36:51 -05:00
parent 172a38c344
commit f4d7e6a29e
111 changed files with 11655 additions and 10200 deletions

View file

@ -77,7 +77,7 @@ import os
import re
import threading
import time
from typing import Any, Dict, List, Optional
from typing import Any
logger = logging.getLogger(__name__)
@ -91,9 +91,11 @@ _MCP_SAMPLING_TYPES = False
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
_MCP_AVAILABLE = True
try:
from mcp.client.streamable_http import streamablehttp_client
_MCP_HTTP_AVAILABLE = True
except ImportError:
_MCP_HTTP_AVAILABLE = False
@ -108,6 +110,7 @@ try:
TextContent,
ToolUseContent,
)
_MCP_SAMPLING_TYPES = True
except ImportError:
logger.debug("MCP sampling types not available -- sampling disabled")
@ -118,27 +121,36 @@ except ImportError:
# Constants
# ---------------------------------------------------------------------------
_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls
_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection per server
_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls
_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection per server
_MAX_RECONNECT_RETRIES = 5
_MAX_BACKOFF_SECONDS = 60
# Environment variables that are safe to pass to stdio subprocesses
_SAFE_ENV_KEYS = frozenset({
"PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", "SHELL", "TMPDIR",
})
_SAFE_ENV_KEYS = frozenset(
{
"PATH",
"HOME",
"USER",
"LANG",
"LC_ALL",
"TERM",
"SHELL",
"TMPDIR",
}
)
# Regex for credential patterns to strip from error messages
_CREDENTIAL_PATTERN = re.compile(
r"(?:"
r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT
r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key
r"|Bearer\s+\S+" # Bearer token
r"|token=[^\s&,;\"']{1,255}" # token=...
r"|key=[^\s&,;\"']{1,255}" # key=...
r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=...
r"|password=[^\s&,;\"']{1,255}" # password=...
r"|secret=[^\s&,;\"']{1,255}" # secret=...
r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT
r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key
r"|Bearer\s+\S+" # Bearer token
r"|token=[^\s&,;\"']{1,255}" # token=...
r"|key=[^\s&,;\"']{1,255}" # key=...
r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=...
r"|password=[^\s&,;\"']{1,255}" # password=...
r"|secret=[^\s&,;\"']{1,255}" # secret=...
r")",
re.IGNORECASE,
)
@ -148,7 +160,8 @@ _CREDENTIAL_PATTERN = re.compile(
# Security helpers
# ---------------------------------------------------------------------------
def _build_safe_env(user_env: Optional[dict]) -> dict:
def _build_safe_env(user_env: dict | None) -> dict:
"""Build a filtered environment dict for stdio subprocesses.
Only passes through safe baseline variables (PATH, HOME, etc.) and XDG_*
@ -180,6 +193,7 @@ def _sanitize_error(text: str) -> str:
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
# ---------------------------------------------------------------------------
def _safe_numeric(value, default, coerce=int, minimum=1):
"""Coerce a config value to a numeric type, returning *default* on failure.
@ -216,18 +230,22 @@ class SamplingHandler:
self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
self.max_tool_rounds = _safe_numeric(
config.get("max_tool_rounds", 5), 5, int, minimum=0,
config.get("max_tool_rounds", 5),
5,
int,
minimum=0,
)
self.model_override = config.get("model")
self.allowed_models = config.get("allowed_models", [])
_log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
self.audit_level = _log_levels.get(
str(config.get("log_level", "info")).lower(), logging.INFO,
str(config.get("log_level", "info")).lower(),
logging.INFO,
)
# Per-instance state
self._rate_timestamps: List[float] = []
self._rate_timestamps: list[float] = []
self._tool_loop_count = 0
self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}
@ -245,7 +263,7 @@ class SamplingHandler:
# -- Model resolution ----------------------------------------------------
def _resolve_model(self, preferences) -> Optional[str]:
def _resolve_model(self, preferences) -> str | None:
"""Config override > server hint > None (use default)."""
if self.model_override:
return self.model_override
@ -265,7 +283,7 @@ class SamplingHandler:
items = block.content if isinstance(block.content, list) else [block.content]
return "\n".join(item.text for item in items if hasattr(item, "text"))
def _convert_messages(self, params) -> List[dict]:
def _convert_messages(self, params) -> list[dict]:
"""Convert MCP SamplingMessages to OpenAI format.
Uses ``msg.content_as_list`` (SDK helper) so single-block and
@ -273,37 +291,47 @@ class SamplingHandler:
with ``isinstance`` on real SDK types when available, falling back
to duck-typing via ``hasattr`` for compatibility.
"""
messages: List[dict] = []
messages: list[dict] = []
for msg in params.messages:
blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
msg.content if isinstance(msg.content, list) else [msg.content]
blocks = (
msg.content_as_list
if hasattr(msg, "content_as_list")
else (msg.content if isinstance(msg.content, list) else [msg.content])
)
# Separate blocks by kind
tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
tool_uses = [b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")]
content_blocks = [b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))]
tool_uses = [
b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")
]
content_blocks = [
b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))
]
# Emit tool result messages (role: tool)
for tr in tool_results:
messages.append({
"role": "tool",
"tool_call_id": tr.toolUseId,
"content": self._extract_tool_result_text(tr),
})
messages.append(
{
"role": "tool",
"tool_call_id": tr.toolUseId,
"content": self._extract_tool_result_text(tr),
}
)
# Emit assistant tool_calls message
if tool_uses:
tc_list = []
for tu in tool_uses:
tc_list.append({
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
"type": "function",
"function": {
"name": tu.name,
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
},
})
tc_list.append(
{
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
"type": "function",
"function": {
"name": tu.name,
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
},
}
)
msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
# Include any accompanying text
text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
@ -320,10 +348,12 @@ class SamplingHandler:
if hasattr(block, "text"):
parts.append({"type": "text", "text": block.text})
elif hasattr(block, "data") and hasattr(block, "mimeType"):
parts.append({
"type": "image_url",
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
})
parts.append(
{
"type": "image_url",
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
}
)
else:
logger.warning(
"Unsupported sampling content block type: %s (skipped)",
@ -352,16 +382,13 @@ class SamplingHandler:
# Tool loop governance
if self.max_tool_rounds == 0:
self._tool_loop_count = 0
return self._error(
f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
)
return self._error(f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)")
self._tool_loop_count += 1
if self._tool_loop_count > self.max_tool_rounds:
self._tool_loop_count = 0
return self._error(
f"Tool loop limit exceeded for server '{self.server_name}' "
f"(max {self.max_tool_rounds} rounds)"
f"Tool loop limit exceeded for server '{self.server_name}' (max {self.max_tool_rounds} rounds)"
)
content_blocks = []
@ -372,25 +399,28 @@ class SamplingHandler:
parsed = json.loads(args)
except (json.JSONDecodeError, ValueError):
logger.warning(
"MCP server '%s': malformed tool_calls arguments "
"from LLM (wrapping as raw): %.100s",
self.server_name, args,
"MCP server '%s': malformed tool_calls arguments from LLM (wrapping as raw): %.100s",
self.server_name,
args,
)
parsed = {"_raw": args}
else:
parsed = args if isinstance(args, dict) else {"_raw": str(args)}
content_blocks.append(ToolUseContent(
type="tool_use",
id=tc.id,
name=tc.function.name,
input=parsed,
))
content_blocks.append(
ToolUseContent(
type="tool_use",
id=tc.id,
name=tc.function.name,
input=parsed,
)
)
logger.log(
self.audit_level,
"MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
self.server_name, response.model,
self.server_name,
response.model,
getattr(getattr(response, "usage", None), "total_tokens", "?"),
len(content_blocks),
)
@ -410,7 +440,8 @@ class SamplingHandler:
logger.log(
self.audit_level,
"MCP server '%s' sampling response: model=%s, tokens=%s",
self.server_name, response.model,
self.server_name,
response.model,
getattr(getattr(response, "usage", None), "total_tokens", "?"),
)
@ -445,12 +476,12 @@ class SamplingHandler:
if not self._check_rate_limit():
logger.warning(
"MCP server '%s' sampling rate limit exceeded (%d/min)",
self.server_name, self.max_rpm,
self.server_name,
self.max_rpm,
)
self.metrics["errors"] += 1
return self._error(
f"Sampling rate limit exceeded for server '{self.server_name}' "
f"({self.max_rpm} requests/minute)"
f"Sampling rate limit exceeded for server '{self.server_name}' ({self.max_rpm} requests/minute)"
)
# Resolve model
@ -458,6 +489,7 @@ class SamplingHandler:
# Get auxiliary LLM client
from agent.auxiliary_client import get_text_auxiliary_client
client, default_model = get_text_auxiliary_client()
if client is None:
self.metrics["errors"] += 1
@ -469,7 +501,8 @@ class SamplingHandler:
if self.allowed_models and resolved_model not in self.allowed_models:
logger.warning(
"MCP server '%s' requested model '%s' not in allowed_models",
self.server_name, resolved_model,
self.server_name,
resolved_model,
)
self.metrics["errors"] += 1
return self._error(
@ -515,7 +548,10 @@ class SamplingHandler:
logger.log(
self.audit_level,
"MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
self.server_name, resolved_model, max_tokens, len(messages),
self.server_name,
resolved_model,
max_tokens,
len(messages),
)
# Offload sync LLM call to thread (non-blocking)
@ -524,19 +560,15 @@ class SamplingHandler:
try:
response = await asyncio.wait_for(
asyncio.to_thread(_sync_call), timeout=self.timeout,
asyncio.to_thread(_sync_call),
timeout=self.timeout,
)
except asyncio.TimeoutError:
except TimeoutError:
self.metrics["errors"] += 1
return self._error(
f"Sampling LLM call timed out after {self.timeout}s "
f"for server '{self.server_name}'"
)
return self._error(f"Sampling LLM call timed out after {self.timeout}s for server '{self.server_name}'")
except Exception as exc:
self.metrics["errors"] += 1
return self._error(
f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
)
return self._error(f"Sampling LLM call failed: {_sanitize_error(str(exc))}")
# Track metrics
choice = response.choices[0]
@ -546,11 +578,7 @@ class SamplingHandler:
self.metrics["tokens_used"] += total_tokens
# Dispatch based on response type
if (
choice.finish_reason == "tool_calls"
and hasattr(choice.message, "tool_calls")
and choice.message.tool_calls
):
if choice.finish_reason == "tool_calls" and hasattr(choice.message, "tool_calls") and choice.message.tool_calls:
return self._build_tool_use_result(choice, response)
return self._build_text_result(choice, response)
@ -560,6 +588,7 @@ class SamplingHandler:
# Server task -- each MCP server lives in one long-lived asyncio Task
# ---------------------------------------------------------------------------
class MCPServerTask:
"""Manages a single MCP server connection in a dedicated asyncio Task.
@ -571,22 +600,29 @@ class MCPServerTask:
"""
__slots__ = (
"name", "session", "tool_timeout",
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
"name",
"session",
"tool_timeout",
"_task",
"_ready",
"_shutdown_event",
"_tools",
"_error",
"_config",
"_sampling",
)
def __init__(self, name: str):
self.name = name
self.session: Optional[Any] = None
self.session: Any | None = None
self.tool_timeout: float = _DEFAULT_TOOL_TIMEOUT
self._task: Optional[asyncio.Task] = None
self._task: asyncio.Task | None = None
self._ready = asyncio.Event()
self._shutdown_event = asyncio.Event()
self._tools: list = []
self._error: Optional[Exception] = None
self._error: Exception | None = None
self._config: dict = {}
self._sampling: Optional[SamplingHandler] = None
self._sampling: SamplingHandler | None = None
def _is_http(self) -> bool:
"""Check if this server uses HTTP transport."""
@ -599,9 +635,7 @@ class MCPServerTask:
user_env = config.get("env")
if not command:
raise ValueError(
f"MCP server '{self.name}' has no 'command' in config"
)
raise ValueError(f"MCP server '{self.name}' has no 'command' in config")
safe_env = _build_safe_env(user_env)
server_params = StdioServerParameters(
@ -650,11 +684,7 @@ class MCPServerTask:
if self.session is None:
return
tools_result = await self.session.list_tools()
self._tools = (
tools_result.tools
if hasattr(tools_result, "tools")
else []
)
self._tools = tools_result.tools if hasattr(tools_result, "tools") else []
async def run(self, config: dict):
"""Long-lived coroutine: connect, discover tools, wait, disconnect.
@ -704,24 +734,28 @@ class MCPServerTask:
if self._shutdown_event.is_set():
logger.debug(
"MCP server '%s' disconnected during shutdown: %s",
self.name, exc,
self.name,
exc,
)
return
retries += 1
if retries > _MAX_RECONNECT_RETRIES:
logger.warning(
"MCP server '%s' failed after %d reconnection attempts, "
"giving up: %s",
self.name, _MAX_RECONNECT_RETRIES, exc,
"MCP server '%s' failed after %d reconnection attempts, giving up: %s",
self.name,
_MAX_RECONNECT_RETRIES,
exc,
)
return
logger.warning(
"MCP server '%s' connection lost (attempt %d/%d), "
"reconnecting in %.0fs: %s",
self.name, retries, _MAX_RECONNECT_RETRIES,
backoff, exc,
"MCP server '%s' connection lost (attempt %d/%d), reconnecting in %.0fs: %s",
self.name,
retries,
_MAX_RECONNECT_RETRIES,
backoff,
exc,
)
await asyncio.sleep(backoff)
backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS)
@ -745,7 +779,7 @@ class MCPServerTask:
if self._task and not self._task.done():
try:
await asyncio.wait_for(self._task, timeout=10)
except asyncio.TimeoutError:
except TimeoutError:
logger.warning(
"MCP server '%s' shutdown timed out, cancelling task",
self.name,
@ -762,11 +796,11 @@ class MCPServerTask:
# Module-level state
# ---------------------------------------------------------------------------
_servers: Dict[str, MCPServerTask] = {}
_servers: dict[str, MCPServerTask] = {}
# Dedicated event loop running in a background daemon thread.
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
_mcp_thread: Optional[threading.Thread] = None
_mcp_loop: asyncio.AbstractEventLoop | None = None
_mcp_thread: threading.Thread | None = None
# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access.
_lock = threading.Lock()
@ -801,7 +835,8 @@ def _run_on_mcp_loop(coro, timeout: float = 30):
# Config loading
# ---------------------------------------------------------------------------
def _load_mcp_config() -> Dict[str, dict]:
def _load_mcp_config() -> dict[str, dict]:
"""Read ``mcp_servers`` from the Hermes config file.
Returns a dict of ``{server_name: server_config}`` or empty dict.
@ -811,6 +846,7 @@ def _load_mcp_config() -> Dict[str, dict]:
"""
try:
from hermes_cli.config import load_config
config = load_config()
servers = config.get("mcp_servers")
if not servers or not isinstance(servers, dict):
@ -825,6 +861,7 @@ def _load_mcp_config() -> Dict[str, dict]:
# Server connection helper
# ---------------------------------------------------------------------------
async def _connect_server(name: str, config: dict) -> MCPServerTask:
"""Create an MCPServerTask, start it, and return when ready.
@ -845,6 +882,7 @@ async def _connect_server(name: str, config: dict) -> MCPServerTask:
# Handler / check-fn factories
# ---------------------------------------------------------------------------
def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
"""Return a sync handler that calls an MCP tool via the background loop.
@ -856,27 +894,21 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
async def _call():
result = await server.session.call_tool(tool_name, arguments=args)
# MCP CallToolResult has .content (list of content blocks) and .isError
if result.isError:
error_text = ""
for block in (result.content or []):
for block in result.content or []:
if hasattr(block, "text"):
error_text += block.text
return json.dumps({
"error": _sanitize_error(
error_text or "MCP tool returned an error"
)
})
return json.dumps({"error": _sanitize_error(error_text or "MCP tool returned an error")})
# Collect text from content blocks
parts: List[str] = []
for block in (result.content or []):
parts: list[str] = []
for block in result.content or []:
if hasattr(block, "text"):
parts.append(block.text)
return json.dumps({"result": "\n".join(parts) if parts else ""})
@ -886,13 +918,11 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
except Exception as exc:
logger.error(
"MCP tool %s/%s call failed: %s",
server_name, tool_name, exc,
server_name,
tool_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@ -904,14 +934,12 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
async def _call():
result = await server.session.list_resources()
resources = []
for r in (result.resources if hasattr(result, "resources") else []):
for r in result.resources if hasattr(result, "resources") else []:
entry = {}
if hasattr(r, "uri"):
entry["uri"] = str(r.uri)
@ -928,13 +956,11 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except Exception as exc:
logger.error(
"MCP %s/list_resources failed: %s", server_name, exc,
"MCP %s/list_resources failed: %s",
server_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@ -946,9 +972,7 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
uri = args.get("uri")
if not uri:
@ -957,7 +981,7 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
async def _call():
result = await server.session.read_resource(uri)
# read_resource returns ReadResourceResult with .contents list
parts: List[str] = []
parts: list[str] = []
contents = result.contents if hasattr(result, "contents") else []
for block in contents:
if hasattr(block, "text"):
@ -970,13 +994,11 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except Exception as exc:
logger.error(
"MCP %s/read_resource failed: %s", server_name, exc,
"MCP %s/read_resource failed: %s",
server_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@ -988,14 +1010,12 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
async def _call():
result = await server.session.list_prompts()
prompts = []
for p in (result.prompts if hasattr(result, "prompts") else []):
for p in result.prompts if hasattr(result, "prompts") else []:
entry = {}
if hasattr(p, "name"):
entry["name"] = p.name
@ -1017,13 +1037,11 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except Exception as exc:
logger.error(
"MCP %s/list_prompts failed: %s", server_name, exc,
"MCP %s/list_prompts failed: %s",
server_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@ -1035,9 +1053,7 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
name = args.get("name")
if not name:
@ -1048,7 +1064,7 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
result = await server.session.get_prompt(name, arguments=arguments)
# GetPromptResult has .messages list
messages = []
for msg in (result.messages if hasattr(result, "messages") else []):
for msg in result.messages if hasattr(result, "messages") else []:
entry = {}
if hasattr(msg, "role"):
entry["role"] = msg.role
@ -1070,13 +1086,11 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except Exception as exc:
logger.error(
"MCP %s/get_prompt failed: %s", server_name, exc,
"MCP %s/get_prompt failed: %s",
server_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@ -1096,6 +1110,7 @@ def _make_check_fn(server_name: str):
# Discovery & registration
# ---------------------------------------------------------------------------
def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
"""Convert an MCP tool listing to the Hermes registry schema format.
@ -1114,14 +1129,16 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
return {
"name": prefixed_name,
"description": mcp_tool.description or f"MCP tool {mcp_tool.name} from {server_name}",
"parameters": mcp_tool.inputSchema if mcp_tool.inputSchema else {
"parameters": mcp_tool.inputSchema
if mcp_tool.inputSchema
else {
"type": "object",
"properties": {},
},
}
def _build_utility_schemas(server_name: str) -> List[dict]:
def _build_utility_schemas(server_name: str) -> list[dict]:
"""Build schemas for the MCP utility tools (resources & prompts).
Returns a list of (schema, handler_factory_name) tuples encoded as dicts
@ -1192,9 +1209,9 @@ def _build_utility_schemas(server_name: str) -> List[dict]:
]
def _existing_tool_names() -> List[str]:
def _existing_tool_names() -> list[str]:
"""Return tool names for all currently connected servers."""
names: List[str] = []
names: list[str] = []
for sname, server in _servers.items():
for mcp_tool in server._tools:
schema = _convert_mcp_schema(sname, mcp_tool)
@ -1205,7 +1222,7 @@ def _existing_tool_names() -> List[str]:
return names
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
async def _discover_and_register_server(name: str, config: dict) -> list[str]:
"""Connect to a single MCP server, discover tools, and register them.
Also registers utility tools for MCP Resources and Prompts support
@ -1224,7 +1241,7 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
with _lock:
_servers[name] = server
registered_names: List[str] = []
registered_names: list[str] = []
toolset_name = f"mcp-{name}"
for mcp_tool in server._tools:
@ -1277,7 +1294,9 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
transport_type = "HTTP" if "url" in config else "stdio"
logger.info(
"MCP server '%s' (%s): registered %d tool(s): %s",
name, transport_type, len(registered_names),
name,
transport_type,
len(registered_names),
", ".join(registered_names),
)
return registered_names
@ -1287,7 +1306,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
# Public API
# ---------------------------------------------------------------------------
def discover_mcp_tools() -> List[str]:
def discover_mcp_tools() -> list[str]:
"""Entry point: load config, connect to MCP servers, register tools.
Called from ``model_tools._discover_tools()``. Safe to call even when
@ -1318,12 +1338,12 @@ def discover_mcp_tools() -> List[str]:
# Start the background event loop for MCP connections
_ensure_mcp_loop()
all_tools: List[str] = []
all_tools: list[str] = []
failed_count = 0
async def _discover_one(name: str, cfg: dict) -> List[str]:
async def _discover_one(name: str, cfg: dict) -> list[str]:
"""Connect to a single server and return its registered tool names."""
transport_desc = cfg.get("url", f'{cfg.get("command", "?")} {" ".join(cfg.get("args", [])[:2])}')
transport_desc = cfg.get("url", f"{cfg.get('command', '?')} {' '.join(cfg.get('args', [])[:2])}")
try:
registered = await _discover_and_register_server(name, cfg)
transport_type = "HTTP" if "url" in cfg else "stdio"
@ -1331,7 +1351,8 @@ def discover_mcp_tools() -> List[str]:
except Exception as exc:
logger.warning(
"Failed to connect to MCP server '%s': %s",
name, exc,
name,
exc,
)
return []
@ -1358,6 +1379,7 @@ def discover_mcp_tools() -> List[str]:
if all_tools:
# Dynamically inject into all hermes-* platform toolsets
from toolsets import TOOLSETS
for ts_name, ts in TOOLSETS.items():
if ts_name.startswith("hermes-"):
for tool_name in all_tools:
@ -1377,13 +1399,13 @@ def discover_mcp_tools() -> List[str]:
return _existing_tool_names()
def get_mcp_status() -> List[dict]:
def get_mcp_status() -> list[dict]:
"""Return status of all configured MCP servers for banner display.
Returns a list of dicts with keys: name, transport, tools, connected.
Includes both successfully connected servers and configured-but-failed ones.
"""
result: List[dict] = []
result: list[dict] = []
# Get configured servers from config
configured = _load_mcp_config()
@ -1407,12 +1429,14 @@ def get_mcp_status() -> List[dict]:
entry["sampling"] = dict(server._sampling.metrics)
result.append(entry)
else:
result.append({
"name": name,
"transport": transport,
"tools": 0,
"connected": False,
})
result.append(
{
"name": name,
"transport": transport,
"tools": 0,
"connected": False,
}
)
return result
@ -1440,7 +1464,9 @@ def shutdown_mcp_servers():
for server, result in zip(servers_snapshot, results):
if isinstance(result, Exception):
logger.debug(
"Error closing MCP server '%s': %s", server.name, result,
"Error closing MCP server '%s': %s",
server.name,
result,
)
with _lock:
_servers.clear()