diff --git a/paddle/fluid/pybind/place.cc b/paddle/fluid/pybind/place.cc index 9ae54dafbc7406..23ab198e297940 100644 --- a/paddle/fluid/pybind/place.cc +++ b/paddle/fluid/pybind/place.cc @@ -473,12 +473,7 @@ void BindPlace(pybind11::module &m) { // NOLINT #endif }); m.def("is_bfloat16_supported", [](const GPUPlace &place) -> bool { - // Only GPUs with Compute Capability >= 80 support bfloat16 -#ifdef PADDLE_WITH_HIP - return true; -#else - return platform::GetGPUComputeCapability(place.device) >= 80; -#endif + return platform::IsBFloat16Supported(place.device); }); #endif py::class_ xpuplace(m, "XPUPlace", R"DOC( diff --git a/paddle/phi/backends/gpu/cuda/cuda_info.cc b/paddle/phi/backends/gpu/cuda/cuda_info.cc index 99c9eb6ed0feec..658def8b74aa27 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_info.cc +++ b/paddle/phi/backends/gpu/cuda/cuda_info.cc @@ -131,6 +131,13 @@ bool TensorCoreAvailable() { return driver_version >= 70; } +bool IsBFloat16Supported(int id) { + if (id == -1) { + id = GetCurrentDeviceId(); + } + return GetGPUComputeCapability(id) >= 80 && GetGPURuntimeVersion(id) >= 11000; +} + int GetGPUMultiProcessors(int id) { PADDLE_ENFORCE_LT(id, GetGPUDeviceCount(), diff --git a/paddle/phi/backends/gpu/gpu_info.h b/paddle/phi/backends/gpu/gpu_info.h index 73f167d7e865af..f520210524476b 100644 --- a/paddle/phi/backends/gpu/gpu_info.h +++ b/paddle/phi/backends/gpu/gpu_info.h @@ -45,6 +45,9 @@ int GetGPUDriverVersion(int id); //! Whether the current device support TensorCore PADDLE_API bool TensorCoreAvailable(); +//! Whether the ith GPU supports native bfloat16. +PADDLE_API bool IsBFloat16Supported(int id); + //! Get the MultiProcessors of the ith GPU. int GetGPUMultiProcessors(int id); diff --git a/paddle/phi/backends/gpu/rocm/rocm_info.cc b/paddle/phi/backends/gpu/rocm/rocm_info.cc index 8c65390463b2a4..0d2aa6ab222ae7 100644 --- a/paddle/phi/backends/gpu/rocm/rocm_info.cc +++ b/paddle/phi/backends/gpu/rocm/rocm_info.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include "paddle/phi/backends/gpu/gpu_info.h" @@ -129,6 +130,16 @@ int GetGPUDriverVersion(int id) { bool TensorCoreAvailable() { return false; } +bool IsBFloat16Supported(int id) { + if (id == -1) { + id = GetCurrentDeviceId(); + } + const auto &prop = GetDeviceProperties(id); + std::string arch(prop.gcnArchName); + return arch.rfind("gfx90a", 0) == 0 || arch.rfind("gfx94", 0) == 0 || + arch.rfind("gfx95", 0) == 0; +} + int GetGPUMultiProcessors(int id) { PADDLE_ENFORCE_LT(id, GetGPUDeviceCount(), diff --git a/paddle/phi/core/platform/device/gpu/gpu_info.cc b/paddle/phi/core/platform/device/gpu/gpu_info.cc index 143eeb43f686cb..765b612261e150 100644 --- a/paddle/phi/core/platform/device/gpu/gpu_info.cc +++ b/paddle/phi/core/platform/device/gpu/gpu_info.cc @@ -649,6 +649,10 @@ int GetGPUDriverVersion(int id) { bool TensorCoreAvailable() { return phi::backends::gpu::TensorCoreAvailable(); } +bool IsBFloat16Supported(int id) { + return phi::backends::gpu::IsBFloat16Supported(id); +} + int GetGPUMultiProcessors(int id) { return phi::backends::gpu::GetGPUMultiProcessors(id); } diff --git a/paddle/phi/core/platform/device/gpu/gpu_info.h b/paddle/phi/core/platform/device/gpu/gpu_info.h index df352c97f506ee..6b1f3bdbdc8cb0 100644 --- a/paddle/phi/core/platform/device/gpu/gpu_info.h +++ b/paddle/phi/core/platform/device/gpu/gpu_info.h @@ -46,6 +46,9 @@ int GetGPUDriverVersion(int id); //! Whether the current device support TensorCore bool TensorCoreAvailable(); +//! Whether the ith GPU supports native bfloat16. +PADDLE_API bool IsBFloat16Supported(int id); + //! Get the MultiProcessors of the ith GPU. int GetGPUMultiProcessors(int id); diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py index 4792d7fd04c96e..dae640e5e98e6b 100644 --- a/python/paddle/amp/auto_cast.py +++ b/python/paddle/amp/auto_cast.py @@ -190,13 +190,7 @@ def _is_gpu_bfloat16_supported() -> bool: """ Judge whether current gpu support bfloat16 amp. """ - prop = paddle.device.cuda.get_device_capability() - cuda_version = paddle.version.cuda() - if cuda_version is not None and cuda_version != 'False': - cuda_version_check = int(cuda_version.split('.')[0]) >= 11 - else: - cuda_version_check = False - return prop[0] >= 8 and cuda_version_check or paddle.is_compiled_with_rocm() + return core.is_bfloat16_supported(_current_expected_place()) def _is_xpu_float16_supported() -> bool: @@ -636,11 +630,16 @@ def amp_guard( ) enable = False elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported(): - prop = paddle.device.cuda.get_device_capability() - cuda_version = paddle.version.cuda() - warnings.warn( - f"For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: {paddle.device.cuda.get_device_name()}, with Compute Capability: {prop[0]}.{prop[1]}, current CUDA Version is: {cuda_version}." - ) + if paddle.is_compiled_with_rocm(): + warnings.warn( + f"For bfloat16, amp only supports ROCm GPUs with native bfloat16 support, current GPU is: {paddle.device.cuda.get_device_name()}." + ) + else: + prop = paddle.device.cuda.get_device_capability() + cuda_version = paddle.version.cuda() + warnings.warn( + f"For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: {paddle.device.cuda.get_device_name()}, with Compute Capability: {prop[0]}.{prop[1]}, current CUDA Version is: {cuda_version}." + ) enable = False if not enable: diff --git a/test/amp/test_layer_convert_dtype.py b/test/amp/test_layer_convert_dtype.py index 2ea1d3fde27912..91af9e2c3e41f6 100644 --- a/test/amp/test_layer_convert_dtype.py +++ b/test/amp/test_layer_convert_dtype.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib import unittest +from unittest.mock import patch import paddle import paddle.nn.functional as F @@ -178,6 +180,30 @@ def test_gpu_bf16_supported(self): res = paddle.amp.is_bfloat16_supported('gpu') self.assertEqual(res, True) + def test_auto_cast_gpu_bf16_uses_core_support(self): + auto_cast_module = importlib.import_module('paddle.amp.auto_cast') + expected_place = object() + with ( + patch.object( + auto_cast_module, + '_current_expected_place', + return_value=expected_place, + ), + patch.object( + auto_cast_module.core, + 'is_bfloat16_supported', + return_value=False, + ) as support_check, + patch.object( + auto_cast_module.paddle, + 'is_compiled_with_rocm', + return_value=True, + ), + ): + self.assertFalse(auto_cast_module._is_gpu_bfloat16_supported()) + + support_check.assert_called_once_with(expected_place) + def test_device_value_error(self): self.assertRaises( ValueError, paddle.amp.is_float16_supported, device='xxx'