diff --git a/paddle/cinn/backends/hip/codegen_hip_dev.cc b/paddle/cinn/backends/hip/codegen_hip_dev.cc index 9e0a15652c963a..05568ace3ad4a1 100644 --- a/paddle/cinn/backends/hip/codegen_hip_dev.cc +++ b/paddle/cinn/backends/hip/codegen_hip_dev.cc @@ -20,7 +20,12 @@ namespace hip { const std::string CodeGenHipDevice::source_header_ = // NOLINT R"(#define CINN_WITH_HIP + #include "bfloat16.h" #include "float16.h" + using cinn::common::bfloat16; + using cinn::common::bfloat168; + using cinn::common::bfloat164; + using cinn::common::bfloat162; using cinn::common::float16; #include "cinn_hip_runtime_source.h" )"; diff --git a/paddle/cinn/common/bfloat16.h b/paddle/cinn/common/bfloat16.h index 05e27da8fb0931..a311cd33c9ae24 100644 --- a/paddle/cinn/common/bfloat16.h +++ b/paddle/cinn/common/bfloat16.h @@ -34,6 +34,15 @@ #endif // __CUDACC__ #endif // CINN_WITH_CUDA +#ifdef CINN_WITH_HIP +#include +#if defined(__HIPCC__) && HIP_VERSION >= 60100000 +#define __HIP_PLATFORM_AMD__ +#include +#define CINN_HIP_BF16 +#endif +#endif // CINN_WITH_HIP + #ifdef __cplusplus #ifndef _WIN32 @@ -80,6 +89,9 @@ struct CINN_ALIGN(2) bfloat16 { #if defined(CINN_CUDA_BF16) __nv_bfloat16 tmp = __float2bfloat16(val); x = *reinterpret_cast(&tmp); +#elif defined(CINN_HIP_BF16) + hip_bfloat16 tmp(val); + x = *reinterpret_cast(&tmp); #else std::memcpy(&x, reinterpret_cast(&val) + 2, 2); #endif @@ -91,6 +103,12 @@ struct CINN_ALIGN(2) bfloat16 { } #endif +#if defined(CINN_HIP_BF16) + __host__ __device__ inline explicit bfloat16(const hip_bfloat16& val) { + x = *reinterpret_cast(&val); + } +#endif + template __host__ __device__ inline explicit bfloat16(const T& val) : x(bfloat16(static_cast(val)).x) {} @@ -103,6 +121,13 @@ struct CINN_ALIGN(2) bfloat16 { } #endif +#if defined(CINN_HIP_BF16) + __host__ __device__ inline bfloat16& operator=(const hip_bfloat16& val) { + x = *reinterpret_cast(&val); + return *this; + } +#endif + __host__ __device__ inline bfloat16& operator=(bool b) { x = b ? 0x3f80 : 0; return *this; @@ -162,6 +187,8 @@ struct CINN_ALIGN(2) bfloat16 { __host__ __device__ inline operator float() const { #ifdef CINN_CUDA_BF16 return __bfloat162float(*reinterpret_cast(&x)); +#elif defined(CINN_HIP_BF16) + return static_cast(*reinterpret_cast(&x)); #else float val = 0.f; uint16_t temp = x; @@ -177,6 +204,12 @@ struct CINN_ALIGN(2) bfloat16 { } #endif +#ifdef CINN_HIP_BF16 + __host__ __device__ inline hip_bfloat16 to_hip_bfloat16() const { + return *reinterpret_cast(&x); + } +#endif + __host__ __device__ inline explicit operator bool() const { return (x & 0x7fff) != 0; } diff --git a/test/cpp/pir/cinn/compilation_task_test.cc b/test/cpp/pir/cinn/compilation_task_test.cc index f77d7683bf3b61..61e6a342d096fd 100644 --- a/test/cpp/pir/cinn/compilation_task_test.cc +++ b/test/cpp/pir/cinn/compilation_task_test.cc @@ -26,6 +26,9 @@ #include "paddle/cinn/hlir/framework/pir/compilation_task.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/cinn/hlir/framework/pir_compiler.h" +#ifdef PADDLE_WITH_HIP +#include "paddle/cinn/backends/hip/codegen_hip_dev.h" +#endif #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" @@ -38,6 +41,19 @@ using cinn::hlir::framework::pir::OpLoweringGroupPtr; using ProgramInfo = std::tuple, std::vector>; +#ifdef PADDLE_WITH_HIP +TEST(CodeGenHipDevice, SourceHeaderIncludesBfloat16) { + const std::string& header = + cinn::backends::hip::CodeGenHipDevice::GetSourceHeader(); + + EXPECT_NE(header.find("#include \"bfloat16.h\""), std::string::npos); + EXPECT_NE(header.find("using cinn::common::bfloat16;"), std::string::npos); + EXPECT_NE(header.find("using cinn::common::bfloat168;"), std::string::npos); + EXPECT_NE(header.find("using cinn::common::bfloat164;"), std::string::npos); + EXPECT_NE(header.find("using cinn::common::bfloat162;"), std::string::npos); +} +#endif + ProgramInfo BuildProgram(std::vector input_shape) { ::pir::IrContext* ctx = ::pir::IrContext::Instance(); ctx->GetOrRegisterDialect();