mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-03 07:21:54 +00:00
fix(mcp): prevent parallel-safe prefix collisions
This commit is contained in:
parent
874dad5cc1
commit
280c63ce91
3 changed files with 124 additions and 20 deletions
|
|
@ -1161,6 +1161,7 @@ class MCPServerTask:
|
|||
}
|
||||
for tool_name in stale_tool_names:
|
||||
registry.deregister(tool_name)
|
||||
_forget_mcp_tool_server(tool_name)
|
||||
|
||||
# 3. Re-register with fresh tool list
|
||||
self._tools = new_mcp_tools
|
||||
|
|
@ -1696,6 +1697,7 @@ class MCPServerTask:
|
|||
self._pending_refresh_tasks.clear()
|
||||
for tool_name in list(getattr(self, "_registered_tool_names", [])):
|
||||
registry.deregister(tool_name)
|
||||
_forget_mcp_tool_server(tool_name)
|
||||
self._registered_tool_names = []
|
||||
self.session = None
|
||||
|
||||
|
|
@ -2066,11 +2068,20 @@ def _handle_session_expired_and_retry(
|
|||
# ``is_mcp_tool_parallel_safe()`` for the parallel-execution check in run_agent.
|
||||
_parallel_safe_servers: set = set()
|
||||
|
||||
# Exact MCP tool-name provenance. MCP tool names are formatted as
|
||||
# ``mcp_{sanitized_server}_{sanitized_tool}``, which is ambiguous when server
|
||||
# names contain underscores (``mcp_a_b_tool`` could be server ``a`` + tool
|
||||
# ``b_tool`` or server ``a_b`` + tool ``tool``). Keep the server component
|
||||
# captured at registration time so parallel safety never relies on prefix
|
||||
# guessing.
|
||||
_mcp_tool_server_names: Dict[str, str] = {}
|
||||
|
||||
# Dedicated event loop running in a background daemon thread.
|
||||
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
_mcp_thread: Optional[threading.Thread] = None
|
||||
|
||||
# Protects _mcp_loop, _mcp_thread, _servers, _parallel_safe_servers, and _stdio_pids.
|
||||
# Protects _mcp_loop, _mcp_thread, _servers, _parallel_safe_servers,
|
||||
# _mcp_tool_server_names, and _stdio_pids.
|
||||
_lock = threading.Lock()
|
||||
|
||||
# PIDs of stdio MCP server subprocesses. Tracked so we can force-kill
|
||||
|
|
@ -2953,6 +2964,19 @@ _UTILITY_CAPABILITY_ATTRS = {
|
|||
}
|
||||
|
||||
|
||||
def _track_mcp_tool_server(tool_name: str, server_name: str) -> None:
|
||||
"""Remember the exact MCP server that registered *tool_name*."""
|
||||
safe_server_name = sanitize_mcp_name_component(server_name)
|
||||
with _lock:
|
||||
_mcp_tool_server_names[tool_name] = safe_server_name
|
||||
|
||||
|
||||
def _forget_mcp_tool_server(tool_name: str) -> None:
|
||||
"""Forget MCP server provenance for a deregistered tool."""
|
||||
with _lock:
|
||||
_mcp_tool_server_names.pop(tool_name, None)
|
||||
|
||||
|
||||
def _select_utility_schemas(server_name: str, server: MCPServerTask, config: dict) -> List[dict]:
|
||||
"""Select utility schemas based on config and server capabilities."""
|
||||
tools_filter = config.get("tools") or {}
|
||||
|
|
@ -3087,6 +3111,7 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li
|
|||
is_async=False,
|
||||
description=schema["description"],
|
||||
)
|
||||
_track_mcp_tool_server(tool_name_prefixed, name)
|
||||
registered_names.append(tool_name_prefixed)
|
||||
|
||||
# Register MCP Resources & Prompts utility tools, filtered by config and
|
||||
|
|
@ -3123,6 +3148,7 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li
|
|||
is_async=False,
|
||||
description=schema["description"],
|
||||
)
|
||||
_track_mcp_tool_server(util_name, name)
|
||||
registered_names.append(util_name)
|
||||
|
||||
if registered_names:
|
||||
|
|
@ -3307,24 +3333,19 @@ def discover_mcp_tools() -> List[str]:
|
|||
def is_mcp_tool_parallel_safe(tool_name: str) -> bool:
|
||||
"""Check if an MCP tool belongs to a server that supports parallel tool calls.
|
||||
|
||||
MCP tool names follow the pattern ``mcp_{server}_{tool}``. This extracts
|
||||
the server component and checks it against the set of servers whose config
|
||||
includes ``supports_parallel_tool_calls: true``.
|
||||
MCP tool names follow the pattern ``mcp_{server}_{tool}``, but that string
|
||||
shape is ambiguous when server names contain underscores. Use the exact
|
||||
server provenance captured at registration time rather than prefix
|
||||
matching, then check whether that server's config includes
|
||||
``supports_parallel_tool_calls: true``.
|
||||
|
||||
Returns False for non-MCP tools or tools from servers without the flag.
|
||||
"""
|
||||
if not tool_name.startswith("mcp_"):
|
||||
return False
|
||||
# Strip the "mcp_" prefix and extract the server name.
|
||||
# Tool names are: mcp_{sanitized_server}_{sanitized_tool}
|
||||
# We need to check all possible server prefixes because the server name
|
||||
# itself may contain underscores after sanitization.
|
||||
rest = tool_name[4:] # strip "mcp_"
|
||||
with _lock:
|
||||
for server_name in _parallel_safe_servers:
|
||||
if rest.startswith(server_name + "_") and len(rest) > len(server_name) + 1:
|
||||
return True
|
||||
return False
|
||||
server_name = _mcp_tool_server_names.get(tool_name)
|
||||
return bool(server_name and server_name in _parallel_safe_servers)
|
||||
|
||||
|
||||
def get_mcp_status() -> List[dict]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue