mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-30 06:41:51 +00:00
fix(mcp): prevent parallel-safe prefix collisions
This commit is contained in:
parent
874dad5cc1
commit
280c63ce91
3 changed files with 124 additions and 20 deletions
|
|
@ -2282,9 +2282,11 @@ class TestMcpParallelToolBatch:
|
|||
def test_mcp_tools_parallel_when_server_opted_in(self):
|
||||
"""MCP tools from a parallel-safe server can run concurrently."""
|
||||
from run_agent import _should_parallelize_tool_batch
|
||||
from tools.mcp_tool import _parallel_safe_servers, _lock
|
||||
from tools.mcp_tool import _mcp_tool_server_names, _parallel_safe_servers, _lock
|
||||
with _lock:
|
||||
_parallel_safe_servers.add("github")
|
||||
_mcp_tool_server_names["mcp_github_list_repos"] = "github"
|
||||
_mcp_tool_server_names["mcp_github_search_code"] = "github"
|
||||
try:
|
||||
tc1 = _mock_tool_call(name="mcp_github_list_repos", arguments='{"org":"openai"}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="mcp_github_search_code", arguments='{"q":"test"}', call_id="c2")
|
||||
|
|
@ -2292,13 +2294,16 @@ class TestMcpParallelToolBatch:
|
|||
finally:
|
||||
with _lock:
|
||||
_parallel_safe_servers.discard("github")
|
||||
_mcp_tool_server_names.pop("mcp_github_list_repos", None)
|
||||
_mcp_tool_server_names.pop("mcp_github_search_code", None)
|
||||
|
||||
def test_mixed_mcp_and_builtin_parallel(self):
|
||||
"""MCP parallel tools mixed with built-in parallel-safe tools."""
|
||||
from run_agent import _should_parallelize_tool_batch
|
||||
from tools.mcp_tool import _parallel_safe_servers, _lock
|
||||
from tools.mcp_tool import _mcp_tool_server_names, _parallel_safe_servers, _lock
|
||||
with _lock:
|
||||
_parallel_safe_servers.add("docs")
|
||||
_mcp_tool_server_names["mcp_docs_search"] = "docs"
|
||||
try:
|
||||
tc1 = _mock_tool_call(name="mcp_docs_search", arguments='{"query":"api"}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="web_search", arguments='{"query":"test"}', call_id="c2")
|
||||
|
|
@ -2306,14 +2311,17 @@ class TestMcpParallelToolBatch:
|
|||
finally:
|
||||
with _lock:
|
||||
_parallel_safe_servers.discard("docs")
|
||||
_mcp_tool_server_names.pop("mcp_docs_search", None)
|
||||
|
||||
def test_mixed_parallel_and_serial_mcp_servers(self):
|
||||
"""One parallel MCP server + one non-parallel MCP server = sequential."""
|
||||
from run_agent import _should_parallelize_tool_batch
|
||||
from tools.mcp_tool import _parallel_safe_servers, _lock
|
||||
from tools.mcp_tool import _mcp_tool_server_names, _parallel_safe_servers, _lock
|
||||
with _lock:
|
||||
_parallel_safe_servers.add("docs")
|
||||
# "github" is NOT in _parallel_safe_servers
|
||||
_mcp_tool_server_names["mcp_docs_search"] = "docs"
|
||||
_mcp_tool_server_names["mcp_github_list_repos"] = "github"
|
||||
try:
|
||||
tc1 = _mock_tool_call(name="mcp_docs_search", arguments='{"query":"api"}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="mcp_github_list_repos", arguments='{"org":"openai"}', call_id="c2")
|
||||
|
|
@ -2321,6 +2329,8 @@ class TestMcpParallelToolBatch:
|
|||
finally:
|
||||
with _lock:
|
||||
_parallel_safe_servers.discard("docs")
|
||||
_mcp_tool_server_names.pop("mcp_docs_search", None)
|
||||
_mcp_tool_server_names.pop("mcp_github_list_repos", None)
|
||||
|
||||
|
||||
class TestHandleMaxIterations:
|
||||
|
|
|
|||
|
|
@ -3781,16 +3781,26 @@ class TestMcpParallelToolCalls:
|
|||
|
||||
def test_is_mcp_tool_parallel_safe_no_servers(self):
|
||||
"""MCP tool from unknown server returns False."""
|
||||
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
|
||||
from tools.mcp_tool import (
|
||||
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
|
||||
_parallel_safe_servers, _lock,
|
||||
)
|
||||
with _lock:
|
||||
_parallel_safe_servers.clear()
|
||||
_mcp_tool_server_names.clear()
|
||||
assert is_mcp_tool_parallel_safe("mcp_docs_search") is False
|
||||
|
||||
def test_is_mcp_tool_parallel_safe_with_flag(self):
|
||||
"""MCP tool from a parallel-safe server returns True."""
|
||||
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
|
||||
from tools.mcp_tool import (
|
||||
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
|
||||
_parallel_safe_servers, _lock,
|
||||
)
|
||||
with _lock:
|
||||
_parallel_safe_servers.add("docs")
|
||||
_mcp_tool_server_names["mcp_docs_search"] = "docs"
|
||||
_mcp_tool_server_names["mcp_docs_read_file"] = "docs"
|
||||
_mcp_tool_server_names["mcp_github_list_repos"] = "github"
|
||||
try:
|
||||
assert is_mcp_tool_parallel_safe("mcp_docs_search") is True
|
||||
assert is_mcp_tool_parallel_safe("mcp_docs_read_file") is True
|
||||
|
|
@ -3799,23 +3809,86 @@ class TestMcpParallelToolCalls:
|
|||
finally:
|
||||
with _lock:
|
||||
_parallel_safe_servers.discard("docs")
|
||||
_mcp_tool_server_names.pop("mcp_docs_search", None)
|
||||
_mcp_tool_server_names.pop("mcp_docs_read_file", None)
|
||||
_mcp_tool_server_names.pop("mcp_github_list_repos", None)
|
||||
|
||||
def test_is_mcp_tool_parallel_safe_server_with_underscores(self):
|
||||
"""Server names containing underscores are correctly matched."""
|
||||
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
|
||||
from tools.mcp_tool import (
|
||||
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
|
||||
_parallel_safe_servers, _lock,
|
||||
)
|
||||
with _lock:
|
||||
_parallel_safe_servers.add("my_server")
|
||||
_mcp_tool_server_names["mcp_my_server_query"] = "my_server"
|
||||
try:
|
||||
assert is_mcp_tool_parallel_safe("mcp_my_server_query") is True
|
||||
finally:
|
||||
with _lock:
|
||||
_parallel_safe_servers.discard("my_server")
|
||||
_mcp_tool_server_names.pop("mcp_my_server_query", None)
|
||||
|
||||
def test_is_mcp_tool_parallel_safe_uses_exact_registered_server(self):
|
||||
"""Ambiguous MCP names must not match a shorter parallel-safe prefix."""
|
||||
from tools.mcp_tool import (
|
||||
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
|
||||
_parallel_safe_servers, _lock,
|
||||
)
|
||||
with _lock:
|
||||
_parallel_safe_servers.add("a")
|
||||
_mcp_tool_server_names["mcp_a_search"] = "a"
|
||||
_mcp_tool_server_names["mcp_a_b_tool"] = "a_b"
|
||||
try:
|
||||
assert is_mcp_tool_parallel_safe("mcp_a_search") is True
|
||||
assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is False
|
||||
finally:
|
||||
with _lock:
|
||||
_parallel_safe_servers.discard("a")
|
||||
_mcp_tool_server_names.pop("mcp_a_search", None)
|
||||
_mcp_tool_server_names.pop("mcp_a_b_tool", None)
|
||||
|
||||
def test_registered_tool_provenance_prevents_prefix_collision(self):
|
||||
"""Registration records exact server ownership for ambiguous names."""
|
||||
from tools.registry import registry
|
||||
from tools.mcp_tool import (
|
||||
_mcp_tool_server_names, _parallel_safe_servers,
|
||||
_register_server_tools, is_mcp_tool_parallel_safe, _lock,
|
||||
)
|
||||
|
||||
server = _make_mock_server(
|
||||
"a_b",
|
||||
tools=[_make_mcp_tool("tool", "Ambiguous tool name")],
|
||||
)
|
||||
registered = _register_server_tools("a_b", server, {})
|
||||
try:
|
||||
assert registered == ["mcp_a_b_tool"]
|
||||
with _lock:
|
||||
assert _mcp_tool_server_names["mcp_a_b_tool"] == "a_b"
|
||||
_parallel_safe_servers.add("a")
|
||||
assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is False
|
||||
|
||||
with _lock:
|
||||
_parallel_safe_servers.add("a_b")
|
||||
assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is True
|
||||
finally:
|
||||
for tool_name in registered:
|
||||
registry.deregister(tool_name)
|
||||
with _lock:
|
||||
_parallel_safe_servers.discard("a")
|
||||
_parallel_safe_servers.discard("a_b")
|
||||
_mcp_tool_server_names.pop("mcp_a_b_tool", None)
|
||||
|
||||
def test_is_mcp_tool_parallel_safe_no_tool_suffix(self):
|
||||
"""Tool name that is just 'mcp_{server}' without a tool part returns False."""
|
||||
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
|
||||
from tools.mcp_tool import (
|
||||
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
|
||||
_parallel_safe_servers, _lock,
|
||||
)
|
||||
with _lock:
|
||||
_parallel_safe_servers.add("docs")
|
||||
_mcp_tool_server_names.pop("mcp_docs", None)
|
||||
_mcp_tool_server_names.pop("mcp_docs_", None)
|
||||
try:
|
||||
# "mcp_docs" has no tool part after the server name
|
||||
assert is_mcp_tool_parallel_safe("mcp_docs") is False
|
||||
|
|
|
|||
|
|
@ -1161,6 +1161,7 @@ class MCPServerTask:
|
|||
}
|
||||
for tool_name in stale_tool_names:
|
||||
registry.deregister(tool_name)
|
||||
_forget_mcp_tool_server(tool_name)
|
||||
|
||||
# 3. Re-register with fresh tool list
|
||||
self._tools = new_mcp_tools
|
||||
|
|
@ -1696,6 +1697,7 @@ class MCPServerTask:
|
|||
self._pending_refresh_tasks.clear()
|
||||
for tool_name in list(getattr(self, "_registered_tool_names", [])):
|
||||
registry.deregister(tool_name)
|
||||
_forget_mcp_tool_server(tool_name)
|
||||
self._registered_tool_names = []
|
||||
self.session = None
|
||||
|
||||
|
|
@ -2066,11 +2068,20 @@ def _handle_session_expired_and_retry(
|
|||
# ``is_mcp_tool_parallel_safe()`` for the parallel-execution check in run_agent.
|
||||
_parallel_safe_servers: set = set()
|
||||
|
||||
# Exact MCP tool-name provenance. MCP tool names are formatted as
|
||||
# ``mcp_{sanitized_server}_{sanitized_tool}``, which is ambiguous when server
|
||||
# names contain underscores (``mcp_a_b_tool`` could be server ``a`` + tool
|
||||
# ``b_tool`` or server ``a_b`` + tool ``tool``). Keep the server component
|
||||
# captured at registration time so parallel safety never relies on prefix
|
||||
# guessing.
|
||||
_mcp_tool_server_names: Dict[str, str] = {}
|
||||
|
||||
# Dedicated event loop running in a background daemon thread.
|
||||
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
_mcp_thread: Optional[threading.Thread] = None
|
||||
|
||||
# Protects _mcp_loop, _mcp_thread, _servers, _parallel_safe_servers, and _stdio_pids.
|
||||
# Protects _mcp_loop, _mcp_thread, _servers, _parallel_safe_servers,
|
||||
# _mcp_tool_server_names, and _stdio_pids.
|
||||
_lock = threading.Lock()
|
||||
|
||||
# PIDs of stdio MCP server subprocesses. Tracked so we can force-kill
|
||||
|
|
@ -2953,6 +2964,19 @@ _UTILITY_CAPABILITY_ATTRS = {
|
|||
}
|
||||
|
||||
|
||||
def _track_mcp_tool_server(tool_name: str, server_name: str) -> None:
|
||||
"""Remember the exact MCP server that registered *tool_name*."""
|
||||
safe_server_name = sanitize_mcp_name_component(server_name)
|
||||
with _lock:
|
||||
_mcp_tool_server_names[tool_name] = safe_server_name
|
||||
|
||||
|
||||
def _forget_mcp_tool_server(tool_name: str) -> None:
|
||||
"""Forget MCP server provenance for a deregistered tool."""
|
||||
with _lock:
|
||||
_mcp_tool_server_names.pop(tool_name, None)
|
||||
|
||||
|
||||
def _select_utility_schemas(server_name: str, server: MCPServerTask, config: dict) -> List[dict]:
|
||||
"""Select utility schemas based on config and server capabilities."""
|
||||
tools_filter = config.get("tools") or {}
|
||||
|
|
@ -3087,6 +3111,7 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li
|
|||
is_async=False,
|
||||
description=schema["description"],
|
||||
)
|
||||
_track_mcp_tool_server(tool_name_prefixed, name)
|
||||
registered_names.append(tool_name_prefixed)
|
||||
|
||||
# Register MCP Resources & Prompts utility tools, filtered by config and
|
||||
|
|
@ -3123,6 +3148,7 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li
|
|||
is_async=False,
|
||||
description=schema["description"],
|
||||
)
|
||||
_track_mcp_tool_server(util_name, name)
|
||||
registered_names.append(util_name)
|
||||
|
||||
if registered_names:
|
||||
|
|
@ -3307,24 +3333,19 @@ def discover_mcp_tools() -> List[str]:
|
|||
def is_mcp_tool_parallel_safe(tool_name: str) -> bool:
|
||||
"""Check if an MCP tool belongs to a server that supports parallel tool calls.
|
||||
|
||||
MCP tool names follow the pattern ``mcp_{server}_{tool}``. This extracts
|
||||
the server component and checks it against the set of servers whose config
|
||||
includes ``supports_parallel_tool_calls: true``.
|
||||
MCP tool names follow the pattern ``mcp_{server}_{tool}``, but that string
|
||||
shape is ambiguous when server names contain underscores. Use the exact
|
||||
server provenance captured at registration time rather than prefix
|
||||
matching, then check whether that server's config includes
|
||||
``supports_parallel_tool_calls: true``.
|
||||
|
||||
Returns False for non-MCP tools or tools from servers without the flag.
|
||||
"""
|
||||
if not tool_name.startswith("mcp_"):
|
||||
return False
|
||||
# Strip the "mcp_" prefix and extract the server name.
|
||||
# Tool names are: mcp_{sanitized_server}_{sanitized_tool}
|
||||
# We need to check all possible server prefixes because the server name
|
||||
# itself may contain underscores after sanitization.
|
||||
rest = tool_name[4:] # strip "mcp_"
|
||||
with _lock:
|
||||
for server_name in _parallel_safe_servers:
|
||||
if rest.startswith(server_name + "_") and len(rest) > len(server_name) + 1:
|
||||
return True
|
||||
return False
|
||||
server_name = _mcp_tool_server_names.get(tool_name)
|
||||
return bool(server_name and server_name in _parallel_safe_servers)
|
||||
|
||||
|
||||
def get_mcp_status() -> List[dict]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue