fix(mcp): prevent parallel-safe prefix collisions

This commit is contained in:
soynchux 2026-05-16 22:05:34 +03:00 committed by Teknium
parent 874dad5cc1
commit 280c63ce91
3 changed files with 124 additions and 20 deletions

View file

@ -2282,9 +2282,11 @@ class TestMcpParallelToolBatch:
def test_mcp_tools_parallel_when_server_opted_in(self):
"""MCP tools from a parallel-safe server can run concurrently."""
from run_agent import _should_parallelize_tool_batch
from tools.mcp_tool import _parallel_safe_servers, _lock
from tools.mcp_tool import _mcp_tool_server_names, _parallel_safe_servers, _lock
with _lock:
_parallel_safe_servers.add("github")
_mcp_tool_server_names["mcp_github_list_repos"] = "github"
_mcp_tool_server_names["mcp_github_search_code"] = "github"
try:
tc1 = _mock_tool_call(name="mcp_github_list_repos", arguments='{"org":"openai"}', call_id="c1")
tc2 = _mock_tool_call(name="mcp_github_search_code", arguments='{"q":"test"}', call_id="c2")
@ -2292,13 +2294,16 @@ class TestMcpParallelToolBatch:
finally:
with _lock:
_parallel_safe_servers.discard("github")
_mcp_tool_server_names.pop("mcp_github_list_repos", None)
_mcp_tool_server_names.pop("mcp_github_search_code", None)
def test_mixed_mcp_and_builtin_parallel(self):
"""MCP parallel tools mixed with built-in parallel-safe tools."""
from run_agent import _should_parallelize_tool_batch
from tools.mcp_tool import _parallel_safe_servers, _lock
from tools.mcp_tool import _mcp_tool_server_names, _parallel_safe_servers, _lock
with _lock:
_parallel_safe_servers.add("docs")
_mcp_tool_server_names["mcp_docs_search"] = "docs"
try:
tc1 = _mock_tool_call(name="mcp_docs_search", arguments='{"query":"api"}', call_id="c1")
tc2 = _mock_tool_call(name="web_search", arguments='{"query":"test"}', call_id="c2")
@ -2306,14 +2311,17 @@ class TestMcpParallelToolBatch:
finally:
with _lock:
_parallel_safe_servers.discard("docs")
_mcp_tool_server_names.pop("mcp_docs_search", None)
def test_mixed_parallel_and_serial_mcp_servers(self):
"""One parallel MCP server + one non-parallel MCP server = sequential."""
from run_agent import _should_parallelize_tool_batch
from tools.mcp_tool import _parallel_safe_servers, _lock
from tools.mcp_tool import _mcp_tool_server_names, _parallel_safe_servers, _lock
with _lock:
_parallel_safe_servers.add("docs")
# "github" is NOT in _parallel_safe_servers
_mcp_tool_server_names["mcp_docs_search"] = "docs"
_mcp_tool_server_names["mcp_github_list_repos"] = "github"
try:
tc1 = _mock_tool_call(name="mcp_docs_search", arguments='{"query":"api"}', call_id="c1")
tc2 = _mock_tool_call(name="mcp_github_list_repos", arguments='{"org":"openai"}', call_id="c2")
@ -2321,6 +2329,8 @@ class TestMcpParallelToolBatch:
finally:
with _lock:
_parallel_safe_servers.discard("docs")
_mcp_tool_server_names.pop("mcp_docs_search", None)
_mcp_tool_server_names.pop("mcp_github_list_repos", None)
class TestHandleMaxIterations:

View file

@ -3781,16 +3781,26 @@ class TestMcpParallelToolCalls:
def test_is_mcp_tool_parallel_safe_no_servers(self):
"""MCP tool from unknown server returns False."""
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.clear()
_mcp_tool_server_names.clear()
assert is_mcp_tool_parallel_safe("mcp_docs_search") is False
def test_is_mcp_tool_parallel_safe_with_flag(self):
"""MCP tool from a parallel-safe server returns True."""
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.add("docs")
_mcp_tool_server_names["mcp_docs_search"] = "docs"
_mcp_tool_server_names["mcp_docs_read_file"] = "docs"
_mcp_tool_server_names["mcp_github_list_repos"] = "github"
try:
assert is_mcp_tool_parallel_safe("mcp_docs_search") is True
assert is_mcp_tool_parallel_safe("mcp_docs_read_file") is True
@ -3799,23 +3809,86 @@ class TestMcpParallelToolCalls:
finally:
with _lock:
_parallel_safe_servers.discard("docs")
_mcp_tool_server_names.pop("mcp_docs_search", None)
_mcp_tool_server_names.pop("mcp_docs_read_file", None)
_mcp_tool_server_names.pop("mcp_github_list_repos", None)
def test_is_mcp_tool_parallel_safe_server_with_underscores(self):
"""Server names containing underscores are correctly matched."""
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.add("my_server")
_mcp_tool_server_names["mcp_my_server_query"] = "my_server"
try:
assert is_mcp_tool_parallel_safe("mcp_my_server_query") is True
finally:
with _lock:
_parallel_safe_servers.discard("my_server")
_mcp_tool_server_names.pop("mcp_my_server_query", None)
def test_is_mcp_tool_parallel_safe_uses_exact_registered_server(self):
"""Ambiguous MCP names must not match a shorter parallel-safe prefix."""
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.add("a")
_mcp_tool_server_names["mcp_a_search"] = "a"
_mcp_tool_server_names["mcp_a_b_tool"] = "a_b"
try:
assert is_mcp_tool_parallel_safe("mcp_a_search") is True
assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is False
finally:
with _lock:
_parallel_safe_servers.discard("a")
_mcp_tool_server_names.pop("mcp_a_search", None)
_mcp_tool_server_names.pop("mcp_a_b_tool", None)
def test_registered_tool_provenance_prevents_prefix_collision(self):
"""Registration records exact server ownership for ambiguous names."""
from tools.registry import registry
from tools.mcp_tool import (
_mcp_tool_server_names, _parallel_safe_servers,
_register_server_tools, is_mcp_tool_parallel_safe, _lock,
)
server = _make_mock_server(
"a_b",
tools=[_make_mcp_tool("tool", "Ambiguous tool name")],
)
registered = _register_server_tools("a_b", server, {})
try:
assert registered == ["mcp_a_b_tool"]
with _lock:
assert _mcp_tool_server_names["mcp_a_b_tool"] == "a_b"
_parallel_safe_servers.add("a")
assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is False
with _lock:
_parallel_safe_servers.add("a_b")
assert is_mcp_tool_parallel_safe("mcp_a_b_tool") is True
finally:
for tool_name in registered:
registry.deregister(tool_name)
with _lock:
_parallel_safe_servers.discard("a")
_parallel_safe_servers.discard("a_b")
_mcp_tool_server_names.pop("mcp_a_b_tool", None)
def test_is_mcp_tool_parallel_safe_no_tool_suffix(self):
"""Tool name that is just 'mcp_{server}' without a tool part returns False."""
from tools.mcp_tool import is_mcp_tool_parallel_safe, _parallel_safe_servers, _lock
from tools.mcp_tool import (
is_mcp_tool_parallel_safe, _mcp_tool_server_names,
_parallel_safe_servers, _lock,
)
with _lock:
_parallel_safe_servers.add("docs")
_mcp_tool_server_names.pop("mcp_docs", None)
_mcp_tool_server_names.pop("mcp_docs_", None)
try:
# "mcp_docs" has no tool part after the server name
assert is_mcp_tool_parallel_safe("mcp_docs") is False

View file

@ -1161,6 +1161,7 @@ class MCPServerTask:
}
for tool_name in stale_tool_names:
registry.deregister(tool_name)
_forget_mcp_tool_server(tool_name)
# 3. Re-register with fresh tool list
self._tools = new_mcp_tools
@ -1696,6 +1697,7 @@ class MCPServerTask:
self._pending_refresh_tasks.clear()
for tool_name in list(getattr(self, "_registered_tool_names", [])):
registry.deregister(tool_name)
_forget_mcp_tool_server(tool_name)
self._registered_tool_names = []
self.session = None
@ -2066,11 +2068,20 @@ def _handle_session_expired_and_retry(
# ``is_mcp_tool_parallel_safe()`` for the parallel-execution check in run_agent.
_parallel_safe_servers: set = set()
# Exact MCP tool-name provenance. MCP tool names are formatted as
# ``mcp_{sanitized_server}_{sanitized_tool}``, which is ambiguous when server
# names contain underscores (``mcp_a_b_tool`` could be server ``a`` + tool
# ``b_tool`` or server ``a_b`` + tool ``tool``). Keep the server component
# captured at registration time so parallel safety never relies on prefix
# guessing.
_mcp_tool_server_names: Dict[str, str] = {}
# Dedicated event loop running in a background daemon thread.
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
_mcp_thread: Optional[threading.Thread] = None
# Protects _mcp_loop, _mcp_thread, _servers, _parallel_safe_servers, and _stdio_pids.
# Protects _mcp_loop, _mcp_thread, _servers, _parallel_safe_servers,
# _mcp_tool_server_names, and _stdio_pids.
_lock = threading.Lock()
# PIDs of stdio MCP server subprocesses. Tracked so we can force-kill
@ -2953,6 +2964,19 @@ _UTILITY_CAPABILITY_ATTRS = {
}
def _track_mcp_tool_server(tool_name: str, server_name: str) -> None:
"""Remember the exact MCP server that registered *tool_name*."""
safe_server_name = sanitize_mcp_name_component(server_name)
with _lock:
_mcp_tool_server_names[tool_name] = safe_server_name
def _forget_mcp_tool_server(tool_name: str) -> None:
"""Forget MCP server provenance for a deregistered tool."""
with _lock:
_mcp_tool_server_names.pop(tool_name, None)
def _select_utility_schemas(server_name: str, server: MCPServerTask, config: dict) -> List[dict]:
"""Select utility schemas based on config and server capabilities."""
tools_filter = config.get("tools") or {}
@ -3087,6 +3111,7 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li
is_async=False,
description=schema["description"],
)
_track_mcp_tool_server(tool_name_prefixed, name)
registered_names.append(tool_name_prefixed)
# Register MCP Resources & Prompts utility tools, filtered by config and
@ -3123,6 +3148,7 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li
is_async=False,
description=schema["description"],
)
_track_mcp_tool_server(util_name, name)
registered_names.append(util_name)
if registered_names:
@ -3307,24 +3333,19 @@ def discover_mcp_tools() -> List[str]:
def is_mcp_tool_parallel_safe(tool_name: str) -> bool:
"""Check if an MCP tool belongs to a server that supports parallel tool calls.
MCP tool names follow the pattern ``mcp_{server}_{tool}``. This extracts
the server component and checks it against the set of servers whose config
includes ``supports_parallel_tool_calls: true``.
MCP tool names follow the pattern ``mcp_{server}_{tool}``, but that string
shape is ambiguous when server names contain underscores. Use the exact
server provenance captured at registration time rather than prefix
matching, then check whether that server's config includes
``supports_parallel_tool_calls: true``.
Returns False for non-MCP tools or tools from servers without the flag.
"""
if not tool_name.startswith("mcp_"):
return False
# Strip the "mcp_" prefix and extract the server name.
# Tool names are: mcp_{sanitized_server}_{sanitized_tool}
# We need to check all possible server prefixes because the server name
# itself may contain underscores after sanitization.
rest = tool_name[4:] # strip "mcp_"
with _lock:
for server_name in _parallel_safe_servers:
if rest.startswith(server_name + "_") and len(rest) > len(server_name) + 1:
return True
return False
server_name = _mcp_tool_server_names.get(tool_name)
return bool(server_name and server_name in _parallel_safe_servers)
def get_mcp_status() -> List[dict]: