Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions paddle/fluid/pybind/place.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,12 +473,7 @@ void BindPlace(pybind11::module &m) { // NOLINT
#endif
});
m.def("is_bfloat16_supported", [](const GPUPlace &place) -> bool {
// Only GPUs with Compute Capability >= 80 support bfloat16
#ifdef PADDLE_WITH_HIP
return true;
#else
return platform::GetGPUComputeCapability(place.device) >= 80;
#endif
return platform::IsBFloat16Supported(place.device);
});
#endif
py::class_<phi::XPUPlace, phi::Place> xpuplace(m, "XPUPlace", R"DOC(
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/backends/gpu/cuda/cuda_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ bool TensorCoreAvailable() {
return driver_version >= 70;
}

// Reports whether GPU `id` supports native bfloat16. Two conditions must
// both hold: the device's compute capability is at least 80 and its GPU
// runtime version is at least 11000. Passing id == -1 queries the current
// device instead.
bool IsBFloat16Supported(int id) {
  const int device = (id == -1) ? GetCurrentDeviceId() : id;
  // Check the capability first so the runtime version is only queried for
  // devices that already pass the hardware requirement.
  if (GetGPUComputeCapability(device) < 80) {
    return false;
  }
  return GetGPURuntimeVersion(device) >= 11000;
}

int GetGPUMultiProcessors(int id) {
PADDLE_ENFORCE_LT(id,
GetGPUDeviceCount(),
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/backends/gpu/gpu_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ int GetGPUDriverVersion(int id);
//! Whether the current device supports TensorCore
PADDLE_API bool TensorCoreAvailable();

//! Whether the ith GPU supports native bfloat16.
//! Passing id == -1 queries the current device.
PADDLE_API bool IsBFloat16Supported(int id);

//! Get the MultiProcessors of the ith GPU.
int GetGPUMultiProcessors(int id);

Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/backends/gpu/rocm/rocm_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include <array>
#include <string>

#include "paddle/phi/backends/gpu/gpu_info.h"

Expand Down Expand Up @@ -129,6 +130,16 @@ int GetGPUDriverVersion(int id) {

// ROCm backend: TensorCore is never reported as available.
bool TensorCoreAvailable() { return false; }

// Reports whether GPU `id` supports native bfloat16 on the ROCm backend.
// The decision is based on the device's gcnArchName: it must begin with
// "gfx90a", "gfx94" or "gfx95". Passing id == -1 queries the current device.
bool IsBFloat16Supported(int id) {
  const int device = (id == -1) ? GetCurrentDeviceId() : id;
  const std::string arch_name(GetDeviceProperties(device).gcnArchName);
  // Prefix test: compare the leading prefix.size() characters of arch_name
  // against the prefix (a shorter arch_name simply compares unequal).
  const auto has_prefix = [&arch_name](const std::string &prefix) {
    return arch_name.compare(0, prefix.size(), prefix) == 0;
  };
  return has_prefix("gfx90a") || has_prefix("gfx94") || has_prefix("gfx95");
}

int GetGPUMultiProcessors(int id) {
PADDLE_ENFORCE_LT(id,
GetGPUDeviceCount(),
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/core/platform/device/gpu/gpu_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,10 @@ int GetGPUDriverVersion(int id) {

// Forwards to phi::backends::gpu::TensorCoreAvailable().
bool TensorCoreAvailable() { return phi::backends::gpu::TensorCoreAvailable(); }

// Forwards to phi::backends::gpu::IsBFloat16Supported(); `id` is passed
// through unchanged.
bool IsBFloat16Supported(int id) {
return phi::backends::gpu::IsBFloat16Supported(id);
}

// Forwards to phi::backends::gpu::GetGPUMultiProcessors(); `id` is passed
// through unchanged.
int GetGPUMultiProcessors(int id) {
return phi::backends::gpu::GetGPUMultiProcessors(id);
}
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/core/platform/device/gpu/gpu_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ int GetGPUDriverVersion(int id);
//! Whether the current device supports TensorCore
bool TensorCoreAvailable();

//! Whether the ith GPU supports native bfloat16.
//! The id is forwarded unchanged to phi::backends::gpu::IsBFloat16Supported.
PADDLE_API bool IsBFloat16Supported(int id);

//! Get the MultiProcessors of the ith GPU.
int GetGPUMultiProcessors(int id);

Expand Down
23 changes: 11 additions & 12 deletions python/paddle/amp/auto_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,7 @@ def _is_gpu_bfloat16_supported() -> bool:
"""
Judge whether current gpu support bfloat16 amp.
"""
prop = paddle.device.cuda.get_device_capability()
cuda_version = paddle.version.cuda()
if cuda_version is not None and cuda_version != 'False':
cuda_version_check = int(cuda_version.split('.')[0]) >= 11
else:
cuda_version_check = False
return prop[0] >= 8 and cuda_version_check or paddle.is_compiled_with_rocm()
return core.is_bfloat16_supported(_current_expected_place())


def _is_xpu_float16_supported() -> bool:
Expand Down Expand Up @@ -636,11 +630,16 @@ def amp_guard(
)
enable = False
elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported():
prop = paddle.device.cuda.get_device_capability()
cuda_version = paddle.version.cuda()
warnings.warn(
f"For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: {paddle.device.cuda.get_device_name()}, with Compute Capability: {prop[0]}.{prop[1]}, current CUDA Version is: {cuda_version}."
)
if paddle.is_compiled_with_rocm():
warnings.warn(
f"For bfloat16, amp only supports ROCm GPUs with native bfloat16 support, current GPU is: {paddle.device.cuda.get_device_name()}."
)
else:
prop = paddle.device.cuda.get_device_capability()
cuda_version = paddle.version.cuda()
warnings.warn(
f"For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: {paddle.device.cuda.get_device_name()}, with Compute Capability: {prop[0]}.{prop[1]}, current CUDA Version is: {cuda_version}."
)
enable = False

if not enable:
Expand Down
26 changes: 26 additions & 0 deletions test/amp/test_layer_convert_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import unittest
from unittest.mock import patch

import paddle
import paddle.nn.functional as F
Expand Down Expand Up @@ -178,6 +180,30 @@ def test_gpu_bf16_supported(self):
res = paddle.amp.is_bfloat16_supported('gpu')
self.assertEqual(res, True)

def test_auto_cast_gpu_bf16_uses_core_support(self):
    """Check that _is_gpu_bfloat16_supported defers to the core query.

    With core.is_bfloat16_supported mocked to return False and
    paddle.is_compiled_with_rocm mocked to return True, the helper must
    still return False, and the core query must have been called exactly
    once with the value returned by _current_expected_place.
    """
    amp_auto_cast = importlib.import_module('paddle.amp.auto_cast')
    sentinel_place = object()

    # Build the three patchers up front; they only take effect inside the
    # `with` statement below.
    place_patch = patch.object(
        amp_auto_cast,
        '_current_expected_place',
        return_value=sentinel_place,
    )
    core_patch = patch.object(
        amp_auto_cast.core,
        'is_bfloat16_supported',
        return_value=False,
    )
    rocm_patch = patch.object(
        amp_auto_cast.paddle,
        'is_compiled_with_rocm',
        return_value=True,
    )

    with place_patch, core_patch as core_check, rocm_patch:
        self.assertFalse(amp_auto_cast._is_gpu_bfloat16_supported())

    core_check.assert_called_once_with(sentinel_place)

def test_device_value_error(self):
self.assertRaises(
ValueError, paddle.amp.is_float16_supported, device='xxx'
Expand Down