fix(mcp): prevent parallel-safe prefix collisions

This commit is contained in:
soynchux 2026-05-16 22:05:34 +03:00 committed by Teknium
parent 874dad5cc1
commit 280c63ce91
3 changed files with 124 additions and 20 deletions

View file

@ -1161,6 +1161,7 @@ class MCPServerTask:
}
for tool_name in stale_tool_names:
registry.deregister(tool_name)
_forget_mcp_tool_server(tool_name)
# 3. Re-register with fresh tool list
self._tools = new_mcp_tools
@ -1696,6 +1697,7 @@ class MCPServerTask:
self._pending_refresh_tasks.clear()
for tool_name in list(getattr(self, "_registered_tool_names", [])):
registry.deregister(tool_name)
_forget_mcp_tool_server(tool_name)
self._registered_tool_names = []
self.session = None
@ -2066,11 +2068,20 @@ def _handle_session_expired_and_retry(
# ``is_mcp_tool_parallel_safe()`` for the parallel-execution check in run_agent.
_parallel_safe_servers: set = set()
# Exact MCP tool-name provenance. MCP tool names are formatted as
# ``mcp_{sanitized_server}_{sanitized_tool}``, which is ambiguous when server
# names contain underscores (``mcp_a_b_tool`` could be server ``a`` + tool
# ``b_tool`` or server ``a_b`` + tool ``tool``). Keep the server component
# captured at registration time so parallel safety never relies on prefix
# guessing.
# Keys are the registered (prefixed) tool names; values are the sanitized
# names of the servers that registered them.
_mcp_tool_server_names: Dict[str, str] = {}
# Dedicated event loop running in a background daemon thread.
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
_mcp_thread: Optional[threading.Thread] = None
# Protects _mcp_loop, _mcp_thread, _servers, _parallel_safe_servers,
# _mcp_tool_server_names, and _stdio_pids.
_lock = threading.Lock()
# PIDs of stdio MCP server subprocesses. Tracked so we can force-kill
@ -2953,6 +2964,19 @@ _UTILITY_CAPABILITY_ATTRS = {
}
def _track_mcp_tool_server(tool_name: str, server_name: str) -> None:
    """Record which MCP server registered *tool_name*.

    The server name is stored in its sanitized form (the result of
    ``sanitize_mcp_name_component``). A later registration of the same
    tool name overwrites the earlier entry.

    Args:
        tool_name: The registered tool name used as the lookup key.
        server_name: The server name to sanitize and store.
    """
    # Sanitize before taking the lock so the critical section is only the
    # dictionary write.
    provenance = {tool_name: sanitize_mcp_name_component(server_name)}
    with _lock:
        _mcp_tool_server_names.update(provenance)
def _forget_mcp_tool_server(tool_name: str) -> None:
    """Drop the recorded server provenance for *tool_name*, if any.

    Tool names that were never tracked (or were already forgotten) are
    ignored, so this is safe to call for every deregistered tool.

    Args:
        tool_name: The registered tool name whose entry should be removed.
    """
    with _lock:
        try:
            del _mcp_tool_server_names[tool_name]
        except KeyError:
            # Nothing was recorded for this tool; nothing to clean up.
            pass
def _select_utility_schemas(server_name: str, server: MCPServerTask, config: dict) -> List[dict]:
"""Select utility schemas based on config and server capabilities."""
tools_filter = config.get("tools") or {}
@ -3087,6 +3111,7 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li
is_async=False,
description=schema["description"],
)
_track_mcp_tool_server(tool_name_prefixed, name)
registered_names.append(tool_name_prefixed)
# Register MCP Resources & Prompts utility tools, filtered by config and
@ -3123,6 +3148,7 @@ def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> Li
is_async=False,
description=schema["description"],
)
_track_mcp_tool_server(util_name, name)
registered_names.append(util_name)
if registered_names:
@ -3307,24 +3333,19 @@ def discover_mcp_tools() -> List[str]:
def is_mcp_tool_parallel_safe(tool_name: str) -> bool:
    """Check if an MCP tool belongs to a server that supports parallel tool calls.

    MCP tool names follow the pattern ``mcp_{server}_{tool}``, but that string
    shape is ambiguous when server names contain underscores
    (``mcp_a_b_tool`` could belong to server ``a`` or to server ``a_b``).
    Use the exact server provenance captured at registration time rather than
    prefix matching, then check whether that server's config includes
    ``supports_parallel_tool_calls: true``.

    Args:
        tool_name: The registered (prefixed) tool name to check.

    Returns:
        True only when *tool_name* has recorded provenance and the recorded
        server is in the parallel-safe set. False for non-MCP tools, for
        tools with no recorded provenance, and for tools from servers
        without the flag.
    """
    # Cheap rejection for non-MCP tools; avoids taking the lock at all.
    if not tool_name.startswith("mcp_"):
        return False
    # Read both shared structures under the same lock so the provenance
    # lookup and the membership test see a consistent snapshot.
    with _lock:
        server_name = _mcp_tool_server_names.get(tool_name)
        return bool(server_name and server_name in _parallel_safe_servers)
def get_mcp_status() -> List[dict]: