diff --git a/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu b/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu index 80378c9f9af1a..937b246d10466 100644 --- a/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu +++ b/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/top_p_sampling_kernel.h" #ifdef PADDLE_WITH_HIP +#include #include #include #include @@ -28,6 +29,9 @@ #define CUDA_BFLOAT16_AVAILABLE #include #endif +#if defined(PADDLE_WITH_HIP) && HIP_VERSION >= 60100000 +#define HIP_BFLOAT16_AVAILABLE +#endif #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" @@ -65,6 +69,13 @@ struct DataTypeTraits { }; #endif +#ifdef HIP_BFLOAT16_AVAILABLE +template <> +struct DataTypeTraits { + using DataType = hip_bfloat16; +}; +#endif + #define FINAL_MASK 0xFFFFFFFF #define FIXED_BLOCK_DIM_BASE(dim, ...) \ @@ -1260,7 +1271,7 @@ void TopPSamplingKernel(const Context& dev_ctx, } // namespace phi -#ifdef CUDA_BFLOAT16_AVAILABLE +#if defined(CUDA_BFLOAT16_AVAILABLE) || defined(HIP_BFLOAT16_AVAILABLE) PD_REGISTER_KERNEL(top_p_sampling, GPU, ALL_LAYOUT, diff --git a/test/legacy_test/test_top_p_sampling.py b/test/legacy_test/test_top_p_sampling.py index 403e5a3d6ffa4..b30deba7da465 100644 --- a/test/legacy_test/test_top_p_sampling.py +++ b/test/legacy_test/test_top_p_sampling.py @@ -173,5 +173,31 @@ def test_static(self): self.run_static(place) +@unittest.skipIf( + not (core.is_compiled_with_cuda() or is_custom_device()) + or not core.is_bfloat16_supported(get_device_place()), + "core is not compiled with CUDA or not support bfloat16", +) +class TestTopPAPIBF16(unittest.TestCase): + def test_dygraph_bfloat16(self): + with paddle.base.dygraph.guard(get_device_place()): + input_tensor = paddle.to_tensor( + [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]], dtype="float32" + ).astype("bfloat16") + topp_tensor = paddle.to_tensor( + [[0.8], [0.8]], dtype="float32" + ).astype("bfloat16") + + out, ids = paddle.tensor.top_p_sampling( + input_tensor, topp_tensor, seed=2023 + ) + + self.assertEqual(out.dtype, paddle.bfloat16) + self.assertEqual(ids.dtype, paddle.int64) + self.assertEqual(out.shape, [2, 1]) + self.assertEqual(ids.shape, [2, 1]) + ids.numpy() + + if __name__ == "__main__": unittest.main()