diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 377714ef9..3796d8ced 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -2450,76 +2450,226 @@ class TestDiscoveryFailedCount: class TestMCPSelectiveToolLoading: - """Tests for tools.include / tools.exclude / enabled config keys.""" + """Tests for per-server MCP filtering and utility tool policies.""" - def _make_server(self, name, tool_names): - from tools.mcp_tool import MCPServerTask - server = MCPServerTask(name) - server.session = MagicMock() - server._tools = [_make_mcp_tool(n, n) for n in tool_names] + def _make_server(self, name, tool_names, session=None): + server = _make_mock_server( + name, + session=session or SimpleNamespace(), + tools=[_make_mcp_tool(n, n) for n in tool_names], + ) return server - def _run_discover(self, name, tool_names, config): - """Run _discover_and_register_server directly and return registered names.""" - import asyncio - from tools.mcp_tool import _discover_and_register_server - server = self._make_server(name, tool_names) + def _run_discover(self, name, tool_names, config, session=None): + from tools.registry import ToolRegistry + from tools.mcp_tool import _discover_and_register_server, _servers - async def fake_connect(n, c): + mock_registry = ToolRegistry() + server = self._make_server(name, tool_names, session=session) + + async def fake_connect(_name, _config): return server async def run(): - with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), patch("tools.mcp_tool._servers", {}): + with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ + patch("tools.registry.registry", mock_registry), \ + patch("toolsets.create_custom_toolset"): return await _discover_and_register_server(name, config) - return asyncio.run(run()) + try: + registered = asyncio.run(run()) + finally: + _servers.pop(name, None) + return registered, mock_registry - def test_include_filter_registers_only_listed_tools(self): - """tools.include whitelist: only specified tools are registered.""" - tool_names = ["create_service", "delete_service", "list_services"] - config = {"url": "https://mcp.example.com", "tools": {"include": ["create_service", "list_services"]}} - result = self._run_discover("ink", tool_names, config) - assert "mcp_ink_create_service" in result - assert "mcp_ink_list_services" in result - assert "mcp_ink_delete_service" not in result - - def test_exclude_filter_skips_listed_tools(self): - """tools.exclude blacklist: all tools except specified are registered.""" - tool_names = ["create_service", "delete_service", "list_services"] - config = {"url": "https://mcp.example.com", "tools": {"exclude": ["delete_service"]}} - result = self._run_discover("ink2", tool_names, config) - assert "mcp_ink2_create_service" in result - assert "mcp_ink2_list_services" in result - assert "mcp_ink2_delete_service" not in result - - def test_no_filter_registers_all_tools(self): - """No tools filter: all tools registered (backward compatible).""" - tool_names = ["create_service", "delete_service", "list_services"] - config = {"url": "https://mcp.example.com"} - result = self._run_discover("ink3", tool_names, config) - assert "mcp_ink3_create_service" in result - assert "mcp_ink3_delete_service" in result - assert "mcp_ink3_list_services" in result - - def test_enabled_false_skips_server(self): - """enabled: false skips the server entirely.""" - fresh_servers = {} - fake_config = { - "ink": { - "url": "https://mcp.example.com", - "enabled": False, - } + def test_include_takes_precedence_over_exclude(self): + config = { + "url": "https://mcp.example.com", + "tools": { + "include": ["create_service"], + "exclude": ["create_service", "delete_service"], + }, } + registered, _ = self._run_discover( + "ink", + ["create_service", "delete_service", "list_services"], + config, + session=SimpleNamespace(), + ) + assert registered == ["mcp_ink_create_service"] + + def test_exclude_filter_registers_all_except_listed_tools(self): + config = { + "url": "https://mcp.example.com", + "tools": {"exclude": ["delete_service"]}, + } + registered, _ = self._run_discover( + "ink_exclude", + ["create_service", "delete_service", "list_services"], + config, + session=SimpleNamespace(), + ) + assert registered == [ + "mcp_ink_exclude_create_service", + "mcp_ink_exclude_list_services", + ] + + def test_include_filter_skips_utility_tools_without_capabilities(self): + config = { + "url": "https://mcp.example.com", + "tools": {"include": ["create_service"]}, + } + registered, mock_registry = self._run_discover( + "ink_no_caps", + ["create_service", "delete_service"], + config, + session=SimpleNamespace(), + ) + assert registered == ["mcp_ink_no_caps_create_service"] + assert set(mock_registry.get_all_tool_names()) == {"mcp_ink_no_caps_create_service"} + + def test_no_filter_registers_all_server_tools_when_no_utilities_supported(self): + registered, _ = self._run_discover( + "ink_no_filter", + ["create_service", "delete_service", "list_services"], + {"url": "https://mcp.example.com"}, + session=SimpleNamespace(), + ) + assert registered == [ + "mcp_ink_no_filter_create_service", + "mcp_ink_no_filter_delete_service", + "mcp_ink_no_filter_list_services", + ] + + def test_resources_and_prompts_can_be_disabled_explicitly(self): + session = SimpleNamespace( + list_resources=AsyncMock(), + read_resource=AsyncMock(), + list_prompts=AsyncMock(), + get_prompt=AsyncMock(), + ) + config = { + "url": "https://mcp.example.com", + "tools": { + "resources": False, + "prompts": False, + }, + } + registered, _ = self._run_discover( + "ink_disabled_utils", + ["create_service"], + config, + session=session, + ) + assert registered == ["mcp_ink_disabled_utils_create_service"] + + def test_registers_only_utility_tools_supported_by_server_capabilities(self): + session = SimpleNamespace( + list_resources=AsyncMock(return_value=SimpleNamespace(resources=[])), + read_resource=AsyncMock(return_value=SimpleNamespace(contents=[])), + ) + registered, _ = self._run_discover( + "ink_resources_only", + ["create_service"], + {"url": "https://mcp.example.com"}, + session=session, + ) + assert "mcp_ink_resources_only_create_service" in registered + assert "mcp_ink_resources_only_list_resources" in registered + assert "mcp_ink_resources_only_read_resource" in registered + assert "mcp_ink_resources_only_list_prompts" not in registered + assert "mcp_ink_resources_only_get_prompt" not in registered + + def test_existing_tool_names_reflect_registered_subset(self): + from tools.mcp_tool import _existing_tool_names, _servers, _discover_and_register_server + from tools.registry import ToolRegistry + + mock_registry = ToolRegistry() + server = self._make_server( + "ink_existing", + ["create_service", "delete_service"], + session=SimpleNamespace(), + ) + + async def fake_connect(_name, _config): + return server + + async def run(): + with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ + patch("tools.registry.registry", mock_registry), \ + patch("toolsets.create_custom_toolset"): + return await _discover_and_register_server( + "ink_existing", + {"url": "https://mcp.example.com", "tools": {"include": ["create_service"]}}, + ) + + try: + registered = asyncio.run(run()) + assert registered == ["mcp_ink_existing_create_service"] + assert _existing_tool_names() == ["mcp_ink_existing_create_service"] + finally: + _servers.pop("ink_existing", None) + + def test_no_toolset_created_when_everything_is_filtered_out(self): + from tools.registry import ToolRegistry + from tools.mcp_tool import _discover_and_register_server, _servers + + mock_registry = ToolRegistry() + server = self._make_server("ink_none", ["create_service"], session=SimpleNamespace()) + mock_create = MagicMock() + + async def fake_connect(_name, _config): + return server + + async def run(): + with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ + patch("tools.registry.registry", mock_registry), \ + patch("toolsets.create_custom_toolset", mock_create): + return await _discover_and_register_server( + "ink_none", + { + "url": "https://mcp.example.com", + "tools": { + "include": ["missing_tool"], + "resources": False, + "prompts": False, + }, + }, + ) + + try: + registered = asyncio.run(run()) + assert registered == [] + mock_create.assert_not_called() + assert mock_registry.get_all_tool_names() == [] + finally: + _servers.pop("ink_none", None) + + def test_enabled_false_skips_connection_attempt(self): + from tools.mcp_tool import discover_mcp_tools + connect_called = [] async def fake_connect(name, config): connect_called.append(name) return self._make_server(name, ["create_service"]) - with patch("tools.mcp_tool._MCP_AVAILABLE", True), patch("tools.mcp_tool._servers", fresh_servers), patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), patch("tools.mcp_tool._connect_server", side_effect=fake_connect): - from tools.mcp_tool import discover_mcp_tools + fake_config = { + "ink": { + "url": "https://mcp.example.com", + "enabled": False, + } + } + fake_toolsets = { + "hermes-cli": {"tools": [], "description": "CLI", "includes": []}, + } + + with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._servers", {}), \ + patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \ + patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ + patch("toolsets.TOOLSETS", fake_toolsets): result = discover_mcp_tools() assert connect_called == [] - assert "mcp_ink_create_service" not in result - + assert result == [] diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index a8deb3ae1..7294e8be5 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -688,7 +688,7 @@ class MCPServerTask: __slots__ = ( "name", "session", "tool_timeout", "_task", "_ready", "_shutdown_event", "_tools", "_error", "_config", - "_sampling", + "_sampling", "_registered_tool_names", ) def __init__(self, name: str): @@ -702,6 +702,7 @@ class MCPServerTask: self._error: Optional[Exception] = None self._config: dict = {} self._sampling: Optional[SamplingHandler] = None + self._registered_tool_names: list[str] = [] def _is_http(self) -> bool: """Check if this server uses HTTP transport.""" @@ -1308,16 +1309,81 @@ def _build_utility_schemas(server_name: str) -> List[dict]: ] +def _normalize_name_filter(value: Any, label: str) -> set[str]: + """Normalize include/exclude config to a set of tool names.""" + if value is None: + return set() + if isinstance(value, str): + return {value} + if isinstance(value, (list, tuple, set)): + return {str(item) for item in value} + logger.warning("MCP config %s must be a string or list of strings; ignoring %r", label, value) + return set() + + +def _parse_boolish(value: Any, default: bool = True) -> bool: + """Parse a bool-like config value with safe fallback.""" + if value is None: + return default + if isinstance(value, bool): + return value + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {"true", "1", "yes", "on"}: + return True + if lowered in {"false", "0", "no", "off"}: + return False + logger.warning("MCP config expected a boolean-ish value, got %r; using default=%s", value, default) + return default + + +_UTILITY_CAPABILITY_METHODS = { + "list_resources": "list_resources", + "read_resource": "read_resource", + "list_prompts": "list_prompts", + "get_prompt": "get_prompt", +} + + +def _select_utility_schemas(server_name: str, server: MCPServerTask, config: dict) -> List[dict]: + """Select utility schemas based on config and server capabilities.""" + tools_filter = config.get("tools") or {} + resources_enabled = _parse_boolish(tools_filter.get("resources"), default=True) + prompts_enabled = _parse_boolish(tools_filter.get("prompts"), default=True) + + selected: List[dict] = [] + for entry in _build_utility_schemas(server_name): + handler_key = entry["handler_key"] + if handler_key in {"list_resources", "read_resource"} and not resources_enabled: + logger.debug("MCP server '%s': skipping utility '%s' (resources disabled)", server_name, handler_key) + continue + if handler_key in {"list_prompts", "get_prompt"} and not prompts_enabled: + logger.debug("MCP server '%s': skipping utility '%s' (prompts disabled)", server_name, handler_key) + continue + + required_method = _UTILITY_CAPABILITY_METHODS[handler_key] + if not hasattr(server.session, required_method): + logger.debug( + "MCP server '%s': skipping utility '%s' (session lacks %s)", + server_name, + handler_key, + required_method, + ) + continue + selected.append(entry) + return selected + + def _existing_tool_names() -> List[str]: """Return tool names for all currently connected servers.""" names: List[str] = [] - for sname, server in _servers.items(): + for _sname, server in _servers.items(): + if hasattr(server, "_registered_tool_names"): + names.extend(server._registered_tool_names) + continue for mcp_tool in server._tools: - schema = _convert_mcp_schema(sname, mcp_tool) + schema = _convert_mcp_schema(server.name, mcp_tool) names.append(schema["name"]) - # Also include utility tool names - for entry in _build_utility_schemas(sname): - names.append(entry["schema"]["name"]) return names @@ -1347,11 +1413,11 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: # Rules (matching issue #690 spec): # tools.include — whitelist: only these tool names are registered # tools.exclude — blacklist: all tools EXCEPT these are registered - # include and exclude are mutually exclusive; include takes precedence + # include takes precedence over exclude # Neither set → register all tools (backward-compatible default) tools_filter = config.get("tools") or {} - include_set = set(tools_filter.get("include") or []) - exclude_set = set(tools_filter.get("exclude") or []) + include_set = _normalize_name_filter(tools_filter.get("include"), f"mcp_servers.{name}.tools.include") + exclude_set = _normalize_name_filter(tools_filter.get("exclude"), f"mcp_servers.{name}.tools.exclude") def _should_register(tool_name: str) -> bool: if include_set: @@ -1378,7 +1444,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: ) registered_names.append(tool_name_prefixed) - # Register MCP Resources & Prompts utility tools + # Register MCP Resources & Prompts utility tools, filtered by config and + # only when the server actually supports the corresponding capability. _handler_factories = { "list_resources": _make_list_resources_handler, "read_resource": _make_read_resource_handler, @@ -1386,7 +1453,7 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: "get_prompt": _make_get_prompt_handler, } check_fn = _make_check_fn(name) - for entry in _build_utility_schemas(name): + for entry in _select_utility_schemas(name, server, config): schema = entry["schema"] handler_key = entry["handler_key"] handler = _handler_factories[handler_key](name, server.tool_timeout) @@ -1402,6 +1469,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: ) registered_names.append(schema["name"]) + server._registered_tool_names = list(registered_names) + # Create a custom toolset so these tools are discoverable if registered_names: create_custom_toolset( @@ -1448,8 +1517,9 @@ def discover_mcp_tools() -> List[str]: # (enabled: false skips the server entirely without removing its config) with _lock: new_servers = { - k: v for k, v in servers.items() - if k not in _servers and v.get("enabled", True) is not False + k: v + for k, v in servers.items() + if k not in _servers and _parse_boolish(v.get("enabled", True), default=True) } if not new_servers: @@ -1537,7 +1607,7 @@ def get_mcp_status() -> List[dict]: entry = { "name": name, "transport": transport, - "tools": len(server._tools), + "tools": len(server._registered_tool_names) if hasattr(server, "_registered_tool_names") else len(server._tools), "connected": True, } if server._sampling: