Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/phi/kernels/top_p_sampling_kernel.h"

#ifdef PADDLE_WITH_HIP
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
Expand All @@ -28,6 +29,9 @@
#define CUDA_BFLOAT16_AVAILABLE
#include <cuda_bf16.h>
#endif
// Feature flag for bfloat16 on ROCm builds: defined only when Paddle is
// compiled with HIP and HIP_VERSION is at least 60100000.
// NOTE(review): 60100000 presumably encodes HIP 6.1.0
// (major * 10000000 + minor * 100000 + patch) — confirm against
// <hip/hip_version.h>.
#if defined(PADDLE_WITH_HIP) && HIP_VERSION >= 60100000
#define HIP_BFLOAT16_AVAILABLE
#endif

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
Expand Down Expand Up @@ -65,6 +69,13 @@ struct DataTypeTraits<phi::bfloat16> {
};
#endif

// HIP variant of the bfloat16 trait: when HIP_BFLOAT16_AVAILABLE is defined,
// DataTypeTraits<phi::bfloat16>::DataType resolves to hip_bfloat16.
#if defined(HIP_BFLOAT16_AVAILABLE)
template <>
struct DataTypeTraits<phi::bfloat16> {
  using DataType = hip_bfloat16;
};
#endif  // HIP_BFLOAT16_AVAILABLE

// 32-bit constant with every bit set.
// NOTE(review): presumably the lane mask passed to warp-level sync
// primitives later in this file — confirm against its uses.
#define FINAL_MASK 0xFFFFFFFF

#define FIXED_BLOCK_DIM_BASE(dim, ...) \
Expand Down Expand Up @@ -1260,7 +1271,7 @@ void TopPSamplingKernel(const Context& dev_ctx,

} // namespace phi

#ifdef CUDA_BFLOAT16_AVAILABLE
#if defined(CUDA_BFLOAT16_AVAILABLE) || defined(HIP_BFLOAT16_AVAILABLE)
PD_REGISTER_KERNEL(top_p_sampling,
GPU,
ALL_LAYOUT,
Expand Down
26 changes: 26 additions & 0 deletions test/legacy_test/test_top_p_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,5 +173,31 @@ def test_static(self):
self.run_static(place)


@unittest.skipIf(
    not (
        (core.is_compiled_with_cuda() or is_custom_device())
        and core.is_bfloat16_supported(get_device_place())
    ),
    "core is not compiled with CUDA or not support bfloat16",
)
class TestTopPAPIBF16(unittest.TestCase):
    """Smoke test for ``top_p_sampling`` with bfloat16 inputs in dygraph mode.

    Skipped unless the build targets CUDA or a custom device and the
    current device place reports bfloat16 support.
    """

    def test_dygraph_bfloat16(self):
        """Run top-p sampling on a 2x3 bfloat16 input and check dtypes/shapes."""
        place = get_device_place()
        with paddle.base.dygraph.guard(place):
            # Tensors are created as float32 and then cast to bfloat16.
            probs_fp32 = paddle.to_tensor(
                [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]], dtype="float32"
            )
            probs_bf16 = probs_fp32.astype("bfloat16")

            top_p_fp32 = paddle.to_tensor([[0.8], [0.8]], dtype="float32")
            top_p_bf16 = top_p_fp32.astype("bfloat16")

            out, ids = paddle.tensor.top_p_sampling(
                probs_bf16, top_p_bf16, seed=2023
            )

            # Sampled values keep the input dtype; indices are int64.
            self.assertEqual(out.dtype, paddle.bfloat16)
            self.assertEqual(ids.dtype, paddle.int64)

            # One sampled entry per input row.
            expected_shape = [2, 1]
            self.assertEqual(out.shape, expected_shape)
            self.assertEqual(ids.shape, expected_shape)

            # NOTE(review): presumably copies the ids to host so that any
            # deferred device error surfaces inside the test — confirm.
            ids.numpy()


# Run the unittest runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()