From 01f76629c031a47183dd02036d88d8fb6a1ebf11 Mon Sep 17 00:00:00 2001
From: austin1997
Date: Tue, 23 Jun 2026 04:50:55 +0000
Subject: [PATCH] [ROCm] Use native HIP BF16 rounding in phi bfloat16

---
 paddle/common/backend_header.h |  5 ++++
 paddle/phi/common/bfloat16.h   | 48 ++++++++++++++++++----------------
 2 files changed, 30 insertions(+), 23 deletions(-)

diff --git a/paddle/common/backend_header.h b/paddle/common/backend_header.h
index 76e56f1d2e461..fe0f9c076109e 100644
--- a/paddle/common/backend_header.h
+++ b/paddle/common/backend_header.h
@@ -23,6 +23,11 @@
 #include <cuda_bf16.h>
 #endif
 
+#if defined(PADDLE_WITH_HIP) && defined(__HIPCC__)
+#define PADDLE_HIP_BF16
+#include <hip/hip_bfloat16.h>
+#endif
+
 #ifndef PADDLE_WITH_HIP
 #if !defined(_WIN32)
 #define PADDLE_ALIGN(x) __attribute__((aligned(x)))
diff --git a/paddle/phi/common/bfloat16.h b/paddle/phi/common/bfloat16.h
index 0405c8904ac49..4de66e847d309 100644
--- a/paddle/phi/common/bfloat16.h
+++ b/paddle/phi/common/bfloat16.h
@@ -82,21 +82,14 @@ struct PADDLE_ALIGN(2) bfloat16 {
   ~bfloat16() = default;
 
   HOSTDEVICE inline explicit bfloat16(float val) {
-#ifdef PADDLE_WITH_HIP
-    uint32_t res = 0;
-    uint32_t* tempRes;
-    // We should be using memcpy in order to respect the strict aliasing rule
-    // but it fails in the HIP environment.
-    tempRes = reinterpret_cast<uint32_t*>(&val);
-    res = *tempRes;
-    x = res >> 16;
-#else
 #if defined(PADDLE_CUDA_BF16)
     __nv_bfloat16 tmp = __float2bfloat16(val);
     x = *reinterpret_cast<uint16_t*>(&tmp);
+#elif defined(PADDLE_HIP_BF16)
+    hip_bfloat16 tmp(val);
+    x = tmp.data;
 #else
     x = cpu_float_to_bfloat16(val);
-#endif
 #endif
   }
 
@@ -106,6 +99,10 @@ struct PADDLE_ALIGN(2) bfloat16 {
   }
 #endif
 
+#if defined(PADDLE_HIP_BF16)
+  HOSTDEVICE inline explicit bfloat16(const hip_bfloat16& val) { x = val.data; }
+#endif
+
   template <class T>
   HOSTDEVICE inline explicit bfloat16(const T& val)
       : x(bfloat16(static_cast<float>(val)).x) {}
@@ -118,6 +115,13 @@ struct PADDLE_ALIGN(2) bfloat16 {
   }
 #endif
 
+#if defined(PADDLE_HIP_BF16)
+  HOSTDEVICE inline bfloat16& operator=(const hip_bfloat16& val) {
+    x = val.data;
+    return *this;
+  }
+#endif
+
   HOSTDEVICE inline bfloat16& operator=(bool b) {
     x = b ? 0x3f80 : 0;
     return *this;
@@ -175,26 +179,16 @@ struct PADDLE_ALIGN(2) bfloat16 {
 
   // Conversion operators
   HOSTDEVICE inline operator float() const {
-#ifdef PADDLE_WITH_HIP
-    uint32_t res = 0;
-    // We should be using memcpy in order to respect the strict aliasing rule
-    // but it fails in the HIP environment.
-    uint16_t temp = x;
-    uint16_t* temp_ptr = reinterpret_cast<uint16_t*>(&temp);
-    res = *temp_ptr;
-    // return res;
-    res = res << 16;
-    return *reinterpret_cast<float*>(&res);
-#else
-#ifdef PADDLE_CUDA_BF16
+#if defined(PADDLE_CUDA_BF16)
     return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
+#elif defined(PADDLE_HIP_BF16)
+    return static_cast<float>(to_hip_bfloat16());
 #else
     float val = 0.f;
     uint16_t temp = x;
     std::memcpy(
         reinterpret_cast<char*>(&val) + 2, reinterpret_cast<char*>(&temp), 2);
     return val;
-#endif
 #endif
   }
 
@@ -204,6 +198,14 @@ struct PADDLE_ALIGN(2) bfloat16 {
   }
 #endif
 
+#ifdef PADDLE_HIP_BF16
+  HOSTDEVICE inline hip_bfloat16 to_hip_bfloat16() const {
+    hip_bfloat16 val;
+    val.data = x;
+    return val;
+  }
+#endif
+
   HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }
 
   HOSTDEVICE inline explicit operator int8_t() const {