fix(mcp): prevent parallel-safe prefix collisions

This commit is contained in:
soynchux 2026-05-16 22:05:34 +03:00 committed by Teknium
parent 874dad5cc1
commit 280c63ce91
3 changed files with 124 additions and 20 deletions

View file

@ -3781,16 +3781,26 @@ class TestMcpParallelToolCalls:
def test_is_mcp_tool_parallel_safe_no_servers(self):
"""MCP tool from unknown server returns False."""
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.clear()
_mcp_tool_server_names.clear()
assert is_mcp_tool_parallel_safe("mcp_docs_search") is False
def test_is_mcp_tool_parallel_safe_with_flag(self):
"""MCP tool from a parallel-safe server returns True."""
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.add("docs")
_mcp_tool_server_names["mcp_docs_search"] = "docs"
_mcp_tool_server_names["mcp_docs_read_file"] = "docs"
_mcp_tool_server_names["mcp_github_list_repos"] = "github"
try:
assert is_mcp_tool_parallel_safe("mcp_docs_search") is True
assert is_mcp_tool_parallel_safe("mcp_docs_read_file") is True
@ -3799,23 +3809,86 @@ class TestMcpParallelToolCalls:
finally:
with _lock:
_parallel_safe_servers.discard("docs")
_mcp_tool_server_names.pop("mcp_docs_search", None)
_mcp_tool_server_names.pop("mcp_docs_read_file", None)
_mcp_tool_server_names.pop("mcp_github_list_repos", None)
def test_is_mcp_tool_parallel_safe_server_with_underscores(self):
"""Server names containing underscores are correctly matched."""
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.add("my_server")
_mcp_tool_server_names["mcp_my_server_query"] = "my_server"
try:
assert is_mcp_tool_parallel_safe("mcp_my_server_query") is True
finally:
with _lock:
_parallel_safe_servers.discard("my_server")
_mcp_tool_server_names.pop("mcp_my_server_query", None)
def test_is_mcp_tool_parallel_safe_uses_exact_registered_server(self):
"""Ambiguous MCP names must not match a shorter parallel-safe prefix."""
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.add("a")
_mcp_tool_server_names["mcp_a_search"] = "a"
_mcp_tool_server_names["mcp_a_b_tool"] = "a_b"
try:
assert is_mcp_tool_parallel_safe("mcp_a_search") is True
assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is False
finally:
with _lock:
_parallel_safe_servers.discard("a")
_mcp_tool_server_names.pop("mcp_a_search", None)
_mcp_tool_server_names.pop("mcp_a_b_tool", None)
def test_registered_tool_provenance_prevents_prefix_collision(self):
"""Registration records exact server ownership for ambiguous names."""
from tools.registry import registry
from tools.mcp_tool import (
_mcp_tool_server_names, _parallel_safe_servers,
_register_server_tools, is_mcp_tool_parallel_safe, _lock,
)
server = _make_mock_server(
"a_b",
tools=[_make_mcp_tool("tool", "Ambiguous tool name")],
)
registered = _register_server_tools("a_b", server, {})
try:
assert registered == ["mcp_a_b_tool"]
with _lock:
assert _mcp_tool_server_names["mcp_a_b_tool"] == "a_b"
_parallel_safe_servers.add("a")
assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is False
with _lock:
_parallel_safe_servers.add("a_b")
assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is True
finally:
for tool_name in registered:
registry.deregister(tool_name)
with _lock:
_parallel_safe_servers.discard("a")
_parallel_safe_servers.discard("a_b")
_mcp_tool_server_names.pop("mcp_a_b_tool", None)
def test_is_mcp_tool_parallel_safe_no_tool_suffix(self):
"""Tool name that is just 'mcp_{server}' without a tool part returns False."""
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.add("docs")
_mcp_tool_server_names.pop("mcp_docs", None)
_mcp_tool_server_names.pop("mcp_docs_", None)
try:
# "mcp_docs" has no tool part after the server name
assert is_mcp_tool_parallel_safe("mcp_docs") is False