fix(mcp): prevent parallel-safe prefix collisions

This commit is contained in:
soynchux 2026-05-16 22:05:34 +03:00 committed by Teknium
parent 874dad5cc1
commit 280c63ce91
3 changed files with 124 additions and 20 deletions

View file

@ -2282,9 +2282,11 @@ class TestMcpParallelToolBatch:
def test_mcp_tools_parallel_when_server_opted_in(self):
"""MCP tools from a parallel-safe server can run concurrently."""
from run_agent import _should_parallelize_tool_batch
from tools.mcp_tool import _parallel_safe_servers, _lock
from tools.mcp_tool import _mcp_tool_server_names, _parallel_safe_servers, _lock
with _lock:
_parallel_safe_servers.add("github")
_mcp_tool_server_names["mcp_github_list_repos"] = "github"
_mcp_tool_server_names["mcp_github_search_code"] = "github"
try:
tc1 = _mock_tool_call(name="mcp_github_list_repos", arguments='{"org":"openai"}', call_id="c1")
tc2 = _mock_tool_call(name="mcp_github_search_code", arguments='{"q":"test"}', call_id="c2")
@ -2292,13 +2294,16 @@ class TestMcpParallelToolBatch:
finally:
with _lock:
_parallel_safe_servers.discard("github")
_mcp_tool_server_names.pop("mcp_github_list_repos", None)
_mcp_tool_server_names.pop("mcp_github_search_code", None)
def test_mixed_mcp_and_builtin_parallel(self):
"""MCP parallel tools mixed with built-in parallel-safe tools."""
from run_agent import _should_parallelize_tool_batch
from tools.mcp_tool import _parallel_safe_servers, _lock
from tools.mcp_tool import _mcp_tool_server_names, _parallel_safe_servers, _lock
with _lock:
_parallel_safe_servers.add("docs")
_mcp_tool_server_names["mcp_docs_search"] = "docs"
try:
tc1 = _mock_tool_call(name="mcp_docs_search", arguments='{"query":"api"}', call_id="c1")
tc2 = _mock_tool_call(name="web_search", arguments='{"query":"test"}', call_id="c2")
@ -2306,14 +2311,17 @@ class TestMcpParallelToolBatch:
finally:
with _lock:
_parallel_safe_servers.discard("docs")
_mcp_tool_server_names.pop("mcp_docs_search", None)
def test_mixed_parallel_and_serial_mcp_servers(self):
"""One parallel MCP server + one non-parallel MCP server = sequential."""
from run_agent import _should_parallelize_tool_batch
from tools.mcp_tool import _parallel_safe_servers, _lock
from tools.mcp_tool import _mcp_tool_server_names, _parallel_safe_servers, _lock
with _lock:
_parallel_safe_servers.add("docs")
# "github" is NOT in _parallel_safe_servers
_mcp_tool_server_names["mcp_docs_search"] = "docs"
_mcp_tool_server_names["mcp_github_list_repos"] = "github"
try:
tc1 = _mock_tool_call(name="mcp_docs_search", arguments='{"query":"api"}', call_id="c1")
tc2 = _mock_tool_call(name="mcp_github_list_repos", arguments='{"org":"openai"}', call_id="c2")
@ -2321,6 +2329,8 @@ class TestMcpParallelToolBatch:
finally:
with _lock:
_parallel_safe_servers.discard("docs")
_mcp_tool_server_names.pop("mcp_docs_search", None)
_mcp_tool_server_names.pop("mcp_github_list_repos", None)
class TestHandleMaxIterations:

View file

@ -3781,16 +3781,26 @@ class TestMcpParallelToolCalls:
def test_is_mcp_tool_parallel_safe_no_servers(self):
"""MCP tool from unknown server returns False."""
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.clear()
_mcp_tool_server_names.clear()
assert is_mcp_tool_parallel_safe("mcp_docs_search") is False
def test_is_mcp_tool_parallel_safe_with_flag(self):
"""MCP tool from a parallel-safe server returns True."""
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.add("docs")
_mcp_tool_server_names["mcp_docs_search"] = "docs"
_mcp_tool_server_names["mcp_docs_read_file"] = "docs"
_mcp_tool_server_names["mcp_github_list_repos"] = "github"
try:
assert is_mcp_tool_parallel_safe("mcp_docs_search") is True
assert is_mcp_tool_parallel_safe("mcp_docs_read_file") is True
@ -3799,23 +3809,86 @@ class TestMcpParallelToolCalls:
finally:
with _lock:
_parallel_safe_servers.discard("docs")
_mcp_tool_server_names.pop("mcp_docs_search", None)
_mcp_tool_server_names.pop("mcp_docs_read_file", None)
_mcp_tool_server_names.pop("mcp_github_list_repos", None)
def test_is_mcp_tool_parallel_safe_server_with_underscores(self):
"""Server names containing underscores are correctly matched."""
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.add("my_server")
_mcp_tool_server_names["mcp_my_server_query"] = "my_server"
try:
assert is_mcp_tool_parallel_safe("mcp_my_server_query") is True
finally:
with _lock:
_parallel_safe_servers.discard("my_server")
_mcp_tool_server_names.pop("mcp_my_server_query", None)
def test_is_mcp_tool_parallel_safe_uses_exact_registered_server(self):
"""Ambiguous MCP names must not match a shorter parallel-safe prefix."""
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.add("a")
_mcp_tool_server_names["mcp_a_search"] = "a"
_mcp_tool_server_names["mcp_a_b_tool"] = "a_b"
try:
assert is_mcp_tool_parallel_safe("mcp_a_search") is True
assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is False
finally:
with _lock:
_parallel_safe_servers.discard("a")
_mcp_tool_server_names.pop("mcp_a_search", None)
_mcp_tool_server_names.pop("mcp_a_b_tool", None)
def test_registered_tool_provenance_prevents_prefix_collision(self):
"""Registration records exact server ownership for ambiguous names."""
from tools.registry import registry
from tools.mcp_tool import (
_mcp_tool_server_names, _parallel_safe_servers,
_register_server_tools, is_mcp_tool_parallel_safe, _lock,
)
server = _make_mock_server(
"a_b",
tools=[_make_mcp_tool("tool", "Ambiguous tool name")],
)
registered = _register_server_tools("a_b", server, {})
try:
assert registered == ["mcp_a_b_tool"]
with _lock:
assert _mcp_tool_server_names["mcp_a_b_tool"] == "a_b"
_parallel_safe_servers.add("a")
assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is False
with _lock:
_parallel_safe_servers.add("a_b")
assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is True
finally:
for tool_name in registered:
registry.deregister(tool_name)
with _lock:
_parallel_safe_servers.discard("a")
_parallel_safe_servers.discard("a_b")
_mcp_tool_server_names.pop("mcp_a_b_tool", None)
def test_is_mcp_tool_parallel_safe_no_tool_suffix(self):
"""Tool name that is just 'mcp_{server}' without a tool part returns False."""
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.add("docs")
_mcp_tool_server_names.pop("mcp_docs", None)
_mcp_tool_server_names.pop("mcp_docs_", None)
try:
# "mcp_docs" has no tool part after the server name
assert is_mcp_tool_parallel_safe("mcp_docs") is False