Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/cinn/backends/hip/codegen_hip_dev.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ namespace hip {

const std::string CodeGenHipDevice::source_header_ = // NOLINT
R"(#define CINN_WITH_HIP
#include "bfloat16.h"
#include "float16.h"
using cinn::common::bfloat16;
using cinn::common::bfloat168;
using cinn::common::bfloat164;
using cinn::common::bfloat162;
using cinn::common::float16;
#include "cinn_hip_runtime_source.h"
)";
Expand Down
33 changes: 33 additions & 0 deletions paddle/cinn/common/bfloat16.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@
#endif // __CUDACC__
#endif // CINN_WITH_CUDA

#ifdef CINN_WITH_HIP
#include <hip/hip_runtime.h>
#if defined(__HIPCC__) && HIP_VERSION >= 60100000
#define __HIP_PLATFORM_AMD__
#include <hip/hip_bfloat16.h>
#define CINN_HIP_BF16
#endif
#endif // CINN_WITH_HIP

#ifdef __cplusplus

#ifndef _WIN32
Expand Down Expand Up @@ -80,6 +89,9 @@ struct CINN_ALIGN(2) bfloat16 {
#if defined(CINN_CUDA_BF16)
__nv_bfloat16 tmp = __float2bfloat16(val);
x = *reinterpret_cast<uint16_t*>(&tmp);
#elif defined(CINN_HIP_BF16)
hip_bfloat16 tmp(val);
x = *reinterpret_cast<uint16_t*>(&tmp);
#else
std::memcpy(&x, reinterpret_cast<char*>(&val) + 2, 2);
#endif
Expand All @@ -91,6 +103,12 @@ struct CINN_ALIGN(2) bfloat16 {
}
#endif

#if defined(CINN_HIP_BF16)
__host__ __device__ inline explicit bfloat16(const hip_bfloat16& val) {
x = *reinterpret_cast<const uint16_t*>(&val);
}
#endif

template <class T>
__host__ __device__ inline explicit bfloat16(const T& val)
: x(bfloat16(static_cast<float>(val)).x) {}
Expand All @@ -103,6 +121,13 @@ struct CINN_ALIGN(2) bfloat16 {
}
#endif

#if defined(CINN_HIP_BF16)
__host__ __device__ inline bfloat16& operator=(const hip_bfloat16& val) {
x = *reinterpret_cast<const uint16_t*>(&val);
return *this;
}
#endif

__host__ __device__ inline bfloat16& operator=(bool b) {
x = b ? 0x3f80 : 0;
return *this;
Expand Down Expand Up @@ -162,6 +187,8 @@ struct CINN_ALIGN(2) bfloat16 {
__host__ __device__ inline operator float() const {
#ifdef CINN_CUDA_BF16
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
#elif defined(CINN_HIP_BF16)
return static_cast<float>(*reinterpret_cast<const hip_bfloat16*>(&x));
#else
float val = 0.f;
uint16_t temp = x;
Expand All @@ -177,6 +204,12 @@ struct CINN_ALIGN(2) bfloat16 {
}
#endif

#ifdef CINN_HIP_BF16
__host__ __device__ inline hip_bfloat16 to_hip_bfloat16() const {
return *reinterpret_cast<const hip_bfloat16*>(&x);
}
#endif

__host__ __device__ inline explicit operator bool() const {
return (x & 0x7fff) != 0;
}
Expand Down
16 changes: 16 additions & 0 deletions test/cpp/pir/cinn/compilation_task_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
#include "paddle/cinn/hlir/framework/pir/compilation_task.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/hlir/framework/pir_compiler.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/cinn/backends/hip/codegen_hip_dev.h"
#endif
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
Expand All @@ -38,6 +41,19 @@ using cinn::hlir::framework::pir::OpLoweringGroupPtr;

using ProgramInfo = std::tuple<std::shared_ptr<::pir::Program>,
std::vector<OpLoweringGroupPtr>>;
#ifdef PADDLE_WITH_HIP
TEST(CodeGenHipDevice, SourceHeaderIncludesBfloat16) {
const std::string& header =
cinn::backends::hip::CodeGenHipDevice::GetSourceHeader();

EXPECT_NE(header.find("#include \"bfloat16.h\""), std::string::npos);
EXPECT_NE(header.find("using cinn::common::bfloat16;"), std::string::npos);
EXPECT_NE(header.find("using cinn::common::bfloat168;"), std::string::npos);
EXPECT_NE(header.find("using cinn::common::bfloat164;"), std::string::npos);
EXPECT_NE(header.find("using cinn::common::bfloat162;"), std::string::npos);
}
#endif

ProgramInfo BuildProgram(std::vector<int64_t> input_shape) {
::pir::IrContext* ctx = ::pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
Expand Down