diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py index 11b58e5faa1..a72359227a6 100644 --- a/tests/run_agent/test_run_agent.py +++ b/tests/run_agent/test_run_agent.py @@ -2282,9 +2282,11 @@ class TestMcpParallelToolBatch: def test_mcp_tools_parallel_when_server_opted_in(self): """MCP tools from a parallel-safe server can run concurrently.""" from run_agent import _should_parallelize_tool_batch - from tools.mcp_tool import _parallel_safe_servers, _lock + from tools.mcp_tool import _mcp_tool_server_names, _parallel_safe_servers, _lock with _lock: _parallel_safe_servers.add("github") + _mcp_tool_server_names["mcp_github_list_repos"] = "github" + _mcp_tool_server_names["mcp_github_search_code"] = "github" try: tc1 = _mock_tool_call(name="mcp_github_list_repos", arguments='{"org":"openai"}', call_id="c1") tc2 = _mock_tool_call(name="mcp_github_search_code", arguments='{"q":"test"}', call_id="c2") @@ -2292,13 +2294,16 @@ class TestMcpParallelToolBatch: finally: with _lock: _parallel_safe_servers.discard("github") + _mcp_tool_server_names.pop("mcp_github_list_repos", None) + _mcp_tool_server_names.pop("mcp_github_search_code", None) def test_mixed_mcp_and_builtin_parallel(self): """MCP parallel tools mixed with built-in parallel-safe tools.""" from run_agent import _should_parallelize_tool_batch - from tools.mcp_tool import _parallel_safe_servers, _lock + from tools.mcp_tool import _mcp_tool_server_names, _parallel_safe_servers, _lock with _lock: _parallel_safe_servers.add("docs") + _mcp_tool_server_names["mcp_docs_search"] = "docs" try: tc1 = _mock_tool_call(name="mcp_docs_search", arguments='{"query":"api"}', call_id="c1") tc2 = _mock_tool_call(name="web_search", arguments='{"query":"test"}', call_id="c2") @@ -2306,14 +2311,17 @@ class TestMcpParallelToolBatch: finally: with _lock: _parallel_safe_servers.discard("docs") + _mcp_tool_server_names.pop("mcp_docs_search", None) def test_mixed_parallel_and_serial_mcp_servers(self): """One parallel MCP server + one non-parallel MCP server = sequential.""" from run_agent import _should_parallelize_tool_batch - from tools.mcp_tool import _parallel_safe_servers, _lock + from tools.mcp_tool import _mcp_tool_server_names, _parallel_safe_servers, _lock with _lock: _parallel_safe_servers.add("docs") # "github" is NOT in _parallel_safe_servers + _mcp_tool_server_names["mcp_docs_search"] = "docs" + _mcp_tool_server_names["mcp_github_list_repos"] = "github" try: tc1 = _mock_tool_call(name="mcp_docs_search", arguments='{"query":"api"}', call_id="c1") tc2 = _mock_tool_call(name="mcp_github_list_repos", arguments='{"org":"openai"}', call_id="c2") @@ -2321,6 +2329,8 @@ class TestMcpParallelToolBatch: finally: with _lock: _parallel_safe_servers.discard("docs") + _mcp_tool_server_names.pop("mcp_docs_search", None) + _mcp_tool_server_names.pop("mcp_github_list_repos", None) class TestHandleMaxIterations: diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 0a094eb5467..3212a350c37 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -3781,16 +3781,26 @@ class TestMcpParallelToolCalls: def test_is_mcp_tool_parallel_safe_no_servers(self): """MCP tool from unknown server returns False.""" - from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock + from tools.mcp_tool import ( + is_mcp_tool_parallel_safe, _mcp_tool_server_names, + _parallel_safe_servers, _lock, + ) with _lock: _parallel_safe_servers.clear() + _mcp_tool_server_names.clear() assert is_mcp_tool_parallel_safe("mcp_docs_search") is False def test_is_mcp_tool_parallel_safe_with_flag(self): """MCP tool from a parallel-safe server returns True.""" - from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock + from tools.mcp_tool import ( + is_mcp_tool_parallel_safe, _mcp_tool_server_names, + _parallel_safe_servers, _lock, + ) with _lock: _parallel_safe_servers.add("docs") + _mcp_tool_server_names["mcp_docs_search"] = "docs" + _mcp_tool_server_names["mcp_docs_read_file"] = "docs" + _mcp_tool_server_names["mcp_github_list_repos"] = "github" try: assert is_mcp_tool_parallel_safe("mcp_docs_search") is True assert is_mcp_tool_parallel_safe("mcp_docs_read_file") is True @@ -3799,23 +3809,86 @@ class TestMcpParallelToolCalls: finally: with _lock: _parallel_safe_servers.discard("docs") + _mcp_tool_server_names.pop("mcp_docs_search", None) + _mcp_tool_server_names.pop("mcp_docs_read_file", None) + _mcp_tool_server_names.pop("mcp_github_list_repos", None) def test_is_mcp_tool_parallel_safe_server_with_underscores(self): """Server names containing underscores are correctly matched.""" - from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock + from tools.mcp_tool import ( + is_mcp_tool_parallel_safe, _mcp_tool_server_names, + _parallel_safe_servers, _lock, + ) with _lock: _parallel_safe_servers.add("my_server") + _mcp_tool_server_names["mcp_my_server_query"] = "my_server" try: assert is_mcp_tool_parallel_safe("mcp_my_server_query") is True finally: with _lock: _parallel_safe_servers.discard("my_server") + _mcp_tool_server_names.pop("mcp_my_server_query", None) + + def test_is_mcp_tool_parallel_safe_uses_exact_registered_server(self): + """Ambiguous MCP names must not match a shorter parallel-safe prefix.""" + from tools.mcp_tool import ( + is_mcp_tool_parallel_safe, _mcp_tool_server_names, + _parallel_safe_servers, _lock, + ) + with _lock: + _parallel_safe_servers.add("a") + _mcp_tool_server_names["mcp_a_search"] = "a" + _mcp_tool_server_names["mcp_a_b_tool"] = "a_b" + try: + assert is_mcp_tool_parallel_safe("mcp_a_search") is True + assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is False + finally: + with _lock: + _parallel_safe_servers.discard("a") + _mcp_tool_server_names.pop("mcp_a_search", None) + _mcp_tool_server_names.pop("mcp_a_b_tool", None) + + def test_registered_tool_provenance_prevents_prefix_collision(self): + """Registration records exact server ownership for ambiguous names.""" + from tools.registry import registry + from tools.mcp_tool import ( + _mcp_tool_server_names, _parallel_safe_servers, + _register_server_tools, is_mcp_tool_parallel_safe, _lock, + ) + + server = _make_mock_server( + "a_b", + tools=[_make_mcp_tool("tool", "Ambiguous tool name")], + ) + registered = _register_server_tools("a_b", server, {}) + try: + assert registered == ["mcp_a_b_tool"] + with _lock: + assert _mcp_tool_server_names["mcp_a_b_tool"] == "a_b" + _parallel_safe_servers.add("a") + assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is False + + with _lock: + _parallel_safe_servers.add("a_b") + assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is True + finally: + for tool_name in registered: + registry.deregister(tool_name) + with _lock: + _parallel_safe_servers.discard("a") + _parallel_safe_servers.discard("a_b") + _mcp_tool_server_names.pop("mcp_a_b_tool", None) def test_is_mcp_tool_parallel_safe_no_tool_suffix(self): """Tool name that is just 'mcp_{server}' without a tool part returns False.""" - from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock + from tools.mcp_tool import ( + is_mcp_tool_parallel_safe, _mcp_tool_server_names, + _parallel_safe_servers, _lock, + ) with _lock: _parallel_safe_servers.add("docs") + _mcp_tool_server_names.pop("mcp_docs", None) + _mcp_tool_server_names.pop("mcp_docs_", None) try: # "mcp_docs" has no tool part after the server name assert is_mcp_tool_parallel_safe("mcp_docs") is False diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 9cec72524af..e1d87389d42 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -1161,6 +1161,7 @@ class MCPServerTask: } for tool_name in stale_tool_names: registry.deregister(tool_name) + _forget_mcp_tool_server(tool_name) # 3. Re-register with fresh tool list self._tools = new_mcp_tools @@ -1696,6 +1697,7 @@ class MCPServerTask: self._pending_refresh_tasks.clear() for tool_name in list(getattr(self, "_registered_tool_names", [])): registry.deregister(tool_name) + _forget_mcp_tool_server(tool_name) self._registered_tool_names = [] self.session = None @@ -2066,11 +2068,20 @@ def _handle_session_expired_and_retry( # ``is_mcp_tool_parallel_safe()`` for the parallel-execution check in run_agent. _parallel_safe_servers: set = set() +# Exact MCP tool-name provenance. MCP tool names are formatted as +# ``mcp_{sanitized_server}_{sanitized_tool}``, which is ambiguous when server +# names contain underscores (``mcp_a_b_tool`` could be server ``a`` + tool +# ``b_tool`` or server ``a_b`` + tool ``tool``). Keep the server component +# captured at registration time so parallel safety never relies on prefix +# guessing. +_mcp_tool_server_names: Dict[str, str] = {} + # Dedicated event loop running in a background daemon thread. _mcp_loop: Optional[asyncio.AbstractEventLoop] = None _mcp_thread: Optional[threading.Thread] = None -# Protects _mcp_loop, _mcp_thread, _servers, _parallel_safe_servers, and _stdio_pids. +# Protects _mcp_loop, _mcp_thread, _servers, _parallel_safe_servers, +# _mcp_tool_server_names, and _stdio_pids. _lock = threading.Lock() # PIDs of stdio MCP server subprocesses. Tracked so we can force-kill @@ -2953,6 +2964,19 @@ _UTILITY_CAPABILITY_ATTRS = { } +def _track_mcp_tool_server(tool_name: str, server_name: str) -> None: + """Remember the exact MCP server that registered *tool_name*.""" + safe_server_name = sanitize_mcp_name_component(server_name) + with _lock: + _mcp_tool_server_names[tool_name] = safe_server_name + + +def _forget_mcp_tool_server(tool_name: str) -> None: + """Forget MCP server provenance for a deregistered tool.""" + with _lock: + _mcp_tool_server_names.pop(tool_name, None) + + def _select_utility_schemas(server_name: str, server: MCPServerTask, config: dict) -> List[dict]: """Select utility schemas based on config and server capabilities.""" tools_filter = config.get("tools") or {} @@ -3087,6 +3111,7 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li is_async=False, description=schema["description"], ) + _track_mcp_tool_server(tool_name_prefixed, name) registered_names.append(tool_name_prefixed) # Register MCP Resources & Prompts utility tools, filtered by config and @@ -3123,6 +3148,7 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li is_async=False, description=schema["description"], ) + _track_mcp_tool_server(util_name, name) registered_names.append(util_name) if registered_names: @@ -3307,24 +3333,19 @@ def discover_mcp_tools() -> List[str]: def is_mcp_tool_parallel_safe(tool_name: str) -> bool: """Check if an MCP tool belongs to a server that supports parallel tool calls. - MCP tool names follow the pattern ``mcp_{server}_{tool}``. This extracts - the server component and checks it against the set of servers whose config - includes ``supports_parallel_tool_calls: true``. + MCP tool names follow the pattern ``mcp_{server}_{tool}``, but that string + shape is ambiguous when server names contain underscores. Use the exact + server provenance captured at registration time rather than prefix + matching, then check whether that server's config includes + ``supports_parallel_tool_calls: true``. Returns False for non-MCP tools or tools from servers without the flag. """ if not tool_name.startswith("mcp_"): return False - # Strip the "mcp_" prefix and extract the server name. - # Tool names are: mcp_{sanitized_server}_{sanitized_tool} - # We need to check all possible server prefixes because the server name - # itself may contain underscores after sanitization. - rest = tool_name[4:] # strip "mcp_" with _lock: - for server_name in _parallel_safe_servers: - if rest.startswith(server_name + "_") and len(rest) > len(server_name) + 1: - return True - return False + server_name = _mcp_tool_server_names.get(tool_name) + return bool(server_name and server_name in _parallel_safe_servers) def get_mcp_status() -> List[dict]: