From bda8c59a9f407a6cbcf473d7fb18d62377016b49 Mon Sep 17 00:00:00 2001 From: Dmitry Vodopyanov Date: Wed, 24 Jun 2026 17:19:38 +0200 Subject: [PATCH 1/2] [SYCL] Implement `sycl_ext_oneapi_fp4` extension --- .../oneapi/experimental/float_4bit/types.hpp | 977 ++++++++++++++++++ sycl/source/feature_test.hpp.in | 1 + .../Experimental/fp4/e2m1_cri_conversion.cpp | 159 +++ .../fp4/e2m1_x2_cri_conversion.cpp | 323 ++++++ .../Experimental/fp4/lit.local.cfg.py | 9 + sycl/unittests/Extensions/CMakeLists.txt | 1 + sycl/unittests/Extensions/fp4/CMakeLists.txt | 3 + sycl/unittests/Extensions/fp4/fp4_e2m1.cpp | 663 ++++++++++++ 8 files changed, 2136 insertions(+) create mode 100644 sycl/include/sycl/ext/oneapi/experimental/float_4bit/types.hpp create mode 100644 sycl/test-e2e/Experimental/fp4/e2m1_cri_conversion.cpp create mode 100644 sycl/test-e2e/Experimental/fp4/e2m1_x2_cri_conversion.cpp create mode 100644 sycl/test-e2e/Experimental/fp4/lit.local.cfg.py create mode 100644 sycl/unittests/Extensions/fp4/CMakeLists.txt create mode 100644 sycl/unittests/Extensions/fp4/fp4_e2m1.cpp diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_4bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_4bit/types.hpp new file mode 100644 index 0000000000000..e7060fab668d9 --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/experimental/float_4bit/types.hpp @@ -0,0 +1,977 @@ +//==----------- types.hpp - sycl_ext_oneapi_fp4 ------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#ifdef __SYCL_DEVICE_ONLY__ + +namespace sycl { +namespace detail { +using fp4_float16_vec2 = _Float16 __attribute__((ext_vector_type(2))); +using fp4_bfloat16_vec2 = __bf16 __attribute__((ext_vector_type(2))); +using fp4_uint8_vec2 = uint8_t __attribute__((ext_vector_type(2))); +} // namespace detail +} // namespace sycl + +// FP4 builtins. The SPIR-V translator maps these names to the corresponding +// SPV_INTEL_float4 / SPV_INTEL_fp_conversions instructions. Scalar builtins +// take/return a 4-bit value held in an 8-bit register; the result is in the +// low 4 bits and the upper 4 bits are unused. Vec2 builtins operate on a +// pair of values; the encode-side returns a packed pair of nibbles in a +// single 8-bit value, while the decode-side accepts a vec2 of nibbles (each +// nibble in the low bits of its lane) and returns a vec2 of floats. + +extern __DPCPP_SYCL_EXTERNAL uint8_t +__builtin_spirv_ClampConvertFP16ToE2M1INTEL(_Float16) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t + __builtin_spirv_ClampConvertFP16ToE2M1INTEL( + ::sycl::detail::fp4_float16_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t +__builtin_spirv_ClampConvertBF16ToE2M1INTEL(__bf16) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t + __builtin_spirv_ClampConvertBF16ToE2M1INTEL( + ::sycl::detail::fp4_bfloat16_vec2) noexcept; + +extern __DPCPP_SYCL_EXTERNAL _Float16 +__builtin_spirv_ConvertE2M1ToFP16INTEL(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL ::sycl::detail::fp4_float16_vec2 + __builtin_spirv_ConvertE2M1ToFP16INTEL( + ::sycl::detail::fp4_uint8_vec2) noexcept; +extern __DPCPP_SYCL_EXTERNAL __bf16 +__builtin_spirv_ConvertE2M1ToBF16INTEL(uint8_t) noexcept; +extern __DPCPP_SYCL_EXTERNAL ::sycl::detail::fp4_bfloat16_vec2 + __builtin_spirv_ConvertE2M1ToBF16INTEL( + ::sycl::detail::fp4_uint8_vec2) noexcept; + +extern __DPCPP_SYCL_EXTERNAL uint8_t +__builtin_spirv_StochasticRoundFP16ToE2M1INTEL( + _Float16, uint32_t, __attribute__((opencl_private)) uint32_t *) noexcept; +extern __DPCPP_SYCL_EXTERNAL uint8_t +__builtin_spirv_StochasticRoundBF16ToE2M1INTEL( + __bf16, uint32_t, __attribute__((opencl_private)) uint32_t *) noexcept; + +#endif // __SYCL_DEVICE_ONLY__ + +namespace sycl { +inline namespace _V1 { +namespace ext::oneapi::experimental { + +enum class rounding { + to_even, + toward_zero, +}; + +struct stochastic_seed { + explicit stochastic_seed(uint32_t *pseed) : pseed(pseed) {} + uint32_t *const pseed; +}; + +namespace detail { + +template static inline int Fp4BitWidth(T x) noexcept { + int width = 0; + while (x != 0u) { + ++width; + x >>= 1; + } + return width; +} + +template struct Fp4SourceTraits; + +template <> struct Fp4SourceTraits { + using UInt = uint32_t; + static constexpr size_t ExpBits = 8; + static constexpr size_t FracBits = 23; + static constexpr int Bias = 127; +}; + +template <> struct Fp4SourceTraits { + using UInt = uint16_t; + static constexpr size_t ExpBits = 5; + static constexpr size_t FracBits = 10; + static constexpr int Bias = 15; +}; + +template <> struct Fp4SourceTraits { + using UInt = uint16_t; + static constexpr size_t ExpBits = 8; + static constexpr size_t FracBits = 7; + static constexpr int Bias = 127; +}; + +template struct Fp4IntSourceTraits { + using UnsignedT = std::make_unsigned_t; + static constexpr bool IsSigned = std::numeric_limits::is_signed; + static constexpr int ValueBits = std::numeric_limits::digits; +}; + +// E2M1 finite format constants. +struct FP4E2M1Traits { + static constexpr int Ebits = 2; + static constexpr int Mbits = 1; + static constexpr uint8_t ExpAllOnes = 0x3u; + static constexpr uint8_t MaxFrac = 0x1u; + static constexpr int Bias = 1; + static constexpr int Emin = 1 - Bias; // 0 + // E2M1 has no Inf and no NaN; all-ones exponent is the max normal. + static constexpr uint8_t MaxFiniteExpField = ExpAllOnes; + static constexpr uint8_t MaxFiniteFracField = MaxFrac; + static constexpr uint8_t MaxFiniteCode = + static_cast((MaxFiniteExpField << Mbits) | MaxFiniteFracField); + static constexpr int MaxFiniteExp = + static_cast(MaxFiniteExpField) - Bias; // 2 + static constexpr uint64_t MinNormalMantissa = uint64_t{1} << Mbits; // 2 + static constexpr uint64_t OverflowMantissa = uint64_t{1} << (Mbits + 1); // 4 + static constexpr uint64_t MaxFiniteMantissa = + MinNormalMantissa + MaxFiniteFracField; // 3 +}; + +// CPU host conversion: signed/unsigned integer to E2M1 nibble (low 4 bits). +template > +static inline uint8_t ConvertIntToFP4_CPU(T f, rounding R) noexcept { + using UnsignedT = typename Traits::UnsignedT; + using Format = FP4E2M1Traits; + + const uint8_t sign = + (Traits::IsSigned && f < 0) ? static_cast(0x8u) : 0u; + UnsignedT magnitude = 0; + + if constexpr (Traits::IsSigned) { + const UnsignedT bits = static_cast(f); + magnitude = f < 0 ? static_cast(UnsignedT{0} - bits) : bits; + } else { + magnitude = static_cast(f); + } + + if (magnitude == 0) + return sign; + + int unbiasedExp = Fp4BitWidth(static_cast(magnitude)) - 1; + if (unbiasedExp > Format::MaxFiniteExp) + return static_cast(sign | Format::MaxFiniteCode); + + const int shift = unbiasedExp - Format::Mbits; + uint64_t mantissa = 0u; + if (shift <= 0) { + mantissa = static_cast(magnitude) << (-shift); + } else { + const uint64_t truncated = static_cast(magnitude) >> shift; + const uint64_t remainderMask = (uint64_t{1} << shift) - 1u; + const uint64_t remainder = + static_cast(magnitude) & remainderMask; + + mantissa = truncated; + if (remainder != 0u && R == rounding::to_even) { + const uint64_t half = uint64_t{1} << (shift - 1); + if (remainder > half || + (remainder == half && (truncated & uint64_t{1}) != 0u)) + ++mantissa; + } + } + + if (mantissa >= Format::OverflowMantissa) { + mantissa = Format::MinNormalMantissa; + ++unbiasedExp; + } + + if (unbiasedExp > Format::MaxFiniteExp || + (unbiasedExp == Format::MaxFiniteExp && + mantissa > Format::MaxFiniteMantissa)) + return static_cast(sign | Format::MaxFiniteCode); + + const uint8_t expField = + static_cast(unbiasedExp + Format::Bias); + const uint8_t fracField = + static_cast(mantissa - Format::MinNormalMantissa); + return static_cast( + sign | static_cast(expField << Format::Mbits) | fracField); +} + +// CPU host conversion: binary floating point to E2M1 nibble (low 4 bits). +template > +static inline uint8_t ConvertFloatToFP4_CPU(T f, rounding R) noexcept { + using UInt = typename Traits::UInt; + using Format = FP4E2M1Traits; + + constexpr UInt SignMask = UInt{1} << (Traits::ExpBits + Traits::FracBits); + constexpr UInt FracMask = (UInt{1} << Traits::FracBits) - UInt{1}; + constexpr UInt ExpMask = ((UInt{1} << Traits::ExpBits) - UInt{1}) + << Traits::FracBits; + constexpr UInt ExpAllOnes = (UInt{1} << Traits::ExpBits) - UInt{1}; + + UInt bits; + std::memcpy(&bits, &f, sizeof(bits)); + + const uint8_t sign = (bits & SignMask) ? 0x8u : 0x0u; + bits &= ~SignMask; + + const UInt exp = (bits & ExpMask) >> Traits::FracBits; + const UInt frac = bits & FracMask; + + // Inf and NaN both clamp to max normal preserving sign (E2M1 has neither). + if (exp == ExpAllOnes) + return static_cast(sign | Format::MaxFiniteCode); + + if (exp == 0u && frac == 0u) + return sign; + + uint64_t significand = 0u; + int leadingBit = 0; + int unbiasedExp = 0; + + if (exp != 0u) { + significand = + (uint64_t{1} << Traits::FracBits) | static_cast(frac); + leadingBit = static_cast(Traits::FracBits); + unbiasedExp = static_cast(exp) - Traits::Bias; + } else { + significand = static_cast(frac); + uint64_t tmp = significand; + leadingBit = -1; + while (tmp != 0u) { + ++leadingBit; + tmp >>= 1; + } + unbiasedExp = + 1 - Traits::Bias - static_cast(Traits::FracBits) + leadingBit; + } + + auto roundShiftRight = [&](uint64_t value, int shift) -> uint64_t { + if (shift <= 0) + return value; + if (shift >= 64) + return 0u; + + const uint64_t truncated = value >> shift; + const uint64_t remainderMask = (uint64_t{1} << shift) - 1u; + const uint64_t remainder = value & remainderMask; + + if (remainder == 0u || R == rounding::toward_zero) + return truncated; + + const uint64_t half = uint64_t{1} << (shift - 1); + if (remainder > half) + return truncated + 1u; + if (remainder < half) + return truncated; + return (truncated & 1u) != 0u ? truncated + 1u : truncated; + }; + + if (unbiasedExp > Format::MaxFiniteExp) + return static_cast(sign | Format::MaxFiniteCode); + + if (unbiasedExp == Format::MaxFiniteExp) { + const uint64_t lhs = significand << Format::Mbits; + const uint64_t rhs = Format::MaxFiniteMantissa << leadingBit; + if (lhs > rhs) + return static_cast(sign | Format::MaxFiniteCode); + } + + if (unbiasedExp < Format::Emin) { + const int shift = + leadingBit - unbiasedExp - Format::Bias - Format::Mbits + 1; + uint64_t mantissa = shift > 0 ? roundShiftRight(significand, shift) + : (significand << (-shift)); + + if (mantissa == 0u) + return sign; + + if (mantissa >= Format::MinNormalMantissa) + return static_cast(sign | + (uint8_t{1} << Format::Mbits)); + + return static_cast(sign | static_cast(mantissa)); + } + + const int shift = leadingBit - Format::Mbits; + uint64_t mantissa = shift > 0 ? roundShiftRight(significand, shift) + : (significand << (-shift)); + + if (mantissa >= Format::OverflowMantissa) { + mantissa = Format::MinNormalMantissa; + ++unbiasedExp; + } + + if (unbiasedExp > Format::MaxFiniteExp || + (unbiasedExp == Format::MaxFiniteExp && + mantissa > Format::MaxFiniteMantissa)) + return static_cast(sign | Format::MaxFiniteCode); + + const uint8_t expField = + static_cast(unbiasedExp + Format::Bias); + const uint8_t fracField = + static_cast(mantissa - Format::MinNormalMantissa); + return static_cast( + sign | static_cast(expField << Format::Mbits) | fracField); +} + +template +struct Fp4HasFloatTraits : std::false_type {}; +template +struct Fp4HasFloatTraits::ExpBits), + decltype(Fp4SourceTraits::FracBits), + decltype(Fp4SourceTraits::Bias)>> + : std::true_type {}; + +// CPU host conversion: E2M1 nibble (low 4 bits of `code`) to ToT. +template +static inline ToT ConvertFromFP4ToBinaryFloat_CPU(uint8_t code, + rounding R) noexcept { + using Format = FP4E2M1Traits; + constexpr uint8_t SignBit = 0x8u; + constexpr uint8_t ExpAllOnes = Format::ExpAllOnes; + constexpr uint8_t FracMask = Format::MaxFrac; + + const bool negative = (code & SignBit) != 0u; + const uint8_t exp = static_cast((code >> Format::Mbits) & ExpAllOnes); + const uint8_t frac = static_cast(code & FracMask); + + uint32_t significand = 0u; + int exp2 = 0; + + if (exp == 0u) { + if (frac == 0u) { + significand = 0u; + } else { + significand = frac; + exp2 = Format::Emin; + } + } else { + significand = + static_cast((1u << Format::Mbits) + frac); + exp2 = static_cast(exp) - Format::Bias; + } + + if constexpr (Fp4HasFloatTraits::value) { + using Traits = Fp4SourceTraits; + using UInt = typename Traits::UInt; + + constexpr UInt ExpAllOnesDst = ((UInt{1} << Traits::ExpBits) - UInt{1}) + << Traits::FracBits; + constexpr UInt FracMaskDst = (UInt{1} << Traits::FracBits) - UInt{1}; + + UInt bits = 0; + if (significand == 0u) { + bits = + negative ? (UInt{1} << (Traits::ExpBits + Traits::FracBits)) : 0u; + } else { + const int sigBits = Fp4BitWidth(significand); + const int unbiasedExp = exp2 + sigBits - 1 - Format::Mbits; + const UInt signBit = + negative ? (UInt{1} << (Traits::ExpBits + Traits::FracBits)) : 0u; + + const int shift = static_cast(Traits::FracBits) - (sigBits - 1); + const UInt aligned = static_cast(significand) << shift; + const UInt expField = static_cast(unbiasedExp + Traits::Bias) + << Traits::FracBits; + bits = signBit | expField | (aligned & FracMaskDst); + } + + (void)R; + (void)ExpAllOnesDst; + return __builtin_bit_cast(ToT, bits); + } else if constexpr (std::is_integral_v) { + using Traits = Fp4IntSourceTraits; + using UnsignedT = typename Traits::UnsignedT; + + if (significand == 0u) + return ToT{}; + + const int shift = exp2 - Format::Mbits; + uint64_t magnitude = 0u; + + if (shift >= 0) { + magnitude = static_cast(significand) << shift; + } else { + const int rshift = -shift; + magnitude = static_cast(significand) >> rshift; + // rounding::toward_zero: discard remainder bits. + (void)R; + } + + if (magnitude == 0u) + return ToT{}; + + if (Fp4BitWidth(magnitude) > Traits::ValueBits) { + if constexpr (Traits::IsSigned) + return negative ? std::numeric_limits::min() + : std::numeric_limits::max(); + else + return negative ? ToT{0} : std::numeric_limits::max(); + } + + const UnsignedT narrowed = static_cast(magnitude); + if constexpr (Traits::IsSigned) + return static_cast(negative ? -static_cast(narrowed) + : static_cast(narrowed)); + return static_cast(narrowed); + } else { + (void)R; + return ToT{}; + } +} + +// Pack two nibbles into a single byte: lo in low 4 bits, hi in high 4 bits. +static inline uint8_t Fp4Pack(uint8_t lo, uint8_t hi) noexcept { + return static_cast((lo & 0x0Fu) | ((hi & 0x0Fu) << 4)); +} + +// Extract element i (0 or 1) from packed byte. +static inline uint8_t Fp4Extract(uint8_t packed, size_t i) noexcept { + return static_cast((packed >> (i * 4)) & 0x0Fu); +} + +} // namespace detail + +template class fp4_e2m1_x { + static_assert(N == 1 || N == 2, + "fp4_e2m1_x: Template argument N must be 1 or 2"); + + template >>> + uint8_t ConvertToFP4(T h) { +#ifdef __SYCL_DEVICE_ONLY__ + if constexpr (std::is_same_v, char> || + std::is_same_v, signed char> || + std::is_same_v, unsigned char>) { + const _Float16 v = static_cast<_Float16>(h); + return __builtin_spirv_ClampConvertFP16ToE2M1INTEL(v); + } + return detail::ConvertIntToFP4_CPU(h, rounding::to_even); +#else + return detail::ConvertIntToFP4_CPU(h, rounding::to_even); +#endif + } + + uint8_t ConvertToFP4(sycl::half h) { +#ifdef __SYCL_DEVICE_ONLY__ + const _Float16 v = sycl::bit_cast<_Float16>(h); + return __builtin_spirv_ClampConvertFP16ToE2M1INTEL(v); +#else + return detail::ConvertFloatToFP4_CPU(h, rounding::to_even); +#endif + } + +#ifdef __SYCL_DEVICE_ONLY__ + uint8_t ConvertToFP4_Vec2(::sycl::detail::fp4_float16_vec2 h) { + return __builtin_spirv_ClampConvertFP16ToE2M1INTEL(h); + } +#endif + + uint8_t ConvertToFP4(float h) { + return detail::ConvertFloatToFP4_CPU(h, rounding::to_even); + } + +#ifdef __SYCL_DEVICE_ONLY__ + uint8_t ConvertBF16ToFP4_Vec2(::sycl::detail::fp4_bfloat16_vec2 h) { + return __builtin_spirv_ClampConvertBF16ToE2M1INTEL(h); + } +#endif + + uint8_t ConvertBF16ToFP4(bfloat16 h) { +#ifdef __SYCL_DEVICE_ONLY__ + return __builtin_spirv_ClampConvertBF16ToE2M1INTEL( + sycl::bit_cast<__bf16>(h)); +#else + return detail::ConvertFloatToFP4_CPU(h, rounding::to_even); +#endif + } + + template T ConvertFromFP4(uint8_t v) const { +#ifdef __SYCL_DEVICE_ONLY__ + sycl::half hi = __builtin_spirv_ConvertE2M1ToFP16INTEL(v); + return static_cast(hi); +#else + return detail::ConvertFromFP4ToBinaryFloat_CPU(v, rounding::toward_zero); +#endif + } + + template T ConvertFromFP4Int(uint8_t v) const { + return detail::ConvertFromFP4ToBinaryFloat_CPU(v, rounding::toward_zero); + } + + void ConvertFromFP4_Vec2(sycl::marray &ret) const { +#ifdef __SYCL_DEVICE_ONLY__ + const ::sycl::detail::fp4_uint8_vec2 packed{ + detail::Fp4Extract(vals[0], 0), detail::Fp4Extract(vals[0], 1)}; + ::sycl::detail::fp4_float16_vec2 hi = + __builtin_spirv_ConvertE2M1ToFP16INTEL(packed); + ret[0] = sycl::bit_cast(hi[0]); + ret[1] = sycl::bit_cast(hi[1]); +#else + for (size_t i = 0; i < 2; ++i) + ret[i] = detail::ConvertFromFP4ToBinaryFloat_CPU( + detail::Fp4Extract(vals[0], i), rounding::toward_zero); +#endif + } + + bfloat16 ConvertBF16FromFP4(uint8_t v) const { +#ifdef __SYCL_DEVICE_ONLY__ + return sycl::bit_cast(__builtin_spirv_ConvertE2M1ToBF16INTEL(v)); +#else + return detail::ConvertFromFP4ToBinaryFloat_CPU( + v, rounding::toward_zero); +#endif + } + + void ConvertBF16FromFP4_Vec2(sycl::marray &ret) const { +#ifdef __SYCL_DEVICE_ONLY__ + const ::sycl::detail::fp4_uint8_vec2 packed{ + detail::Fp4Extract(vals[0], 0), detail::Fp4Extract(vals[0], 1)}; + ::sycl::detail::fp4_bfloat16_vec2 hi = + __builtin_spirv_ConvertE2M1ToBF16INTEL(packed); + ret[0] = sycl::bit_cast(hi[0]); + ret[1] = sycl::bit_cast(hi[1]); +#else + for (size_t i = 0; i < 2; ++i) + ret[i] = detail::ConvertFromFP4ToBinaryFloat_CPU( + detail::Fp4Extract(vals[0], i), rounding::toward_zero); +#endif + } + + void CheckConstraints(rounding r) const { + assert(r == rounding::to_even && + "fp4_e2m1_x: only rounding::to_even is supported"); + } + + // Store one nibble at element index i (0 or 1). + void StoreNibble(size_t i, uint8_t nibble) { + if (i == 0) + vals[0] = static_cast((vals[0] & 0xF0u) | (nibble & 0x0Fu)); + else + vals[0] = static_cast((vals[0] & 0x0Fu) | + (static_cast(nibble & 0x0Fu) + << 4)); + } + +#ifdef __SYCL_DEVICE_ONLY__ +#define CONVERT_TO_FP4(VecType, CastType, in, Prefix) \ + if constexpr (N == 1) { \ + vals[0] = Convert##Prefix##ToFP4(in[0]); \ + } else { \ + const VecType vec{sycl::bit_cast(in[0]), \ + sycl::bit_cast(in[1])}; \ + vals[0] = Convert##Prefix##ToFP4_Vec2(vec); \ + } +#else +#define CONVERT_TO_FP4(VecType, CastType, in, Prefix) \ + if constexpr (N == 1) { \ + vals[0] = Convert##Prefix##ToFP4(in[0]); \ + } else { \ + const uint8_t lo = Convert##Prefix##ToFP4(in[0]); \ + const uint8_t hi = Convert##Prefix##ToFP4(in[1]); \ + vals[0] = detail::Fp4Pack(lo, hi); \ + } +#endif + +public: + fp4_e2m1_x() = default; + fp4_e2m1_x(const fp4_e2m1_x &) = default; + ~fp4_e2m1_x() = default; + fp4_e2m1_x &operator=(const fp4_e2m1_x &) = default; + + // Construct from pack of half, bfloat16, float. + // Available only when the size of the pack is equal to N. + template , half>) && ...) || + ((std::is_same_v, bfloat16>) && ...) || + ((std::is_same_v, float>) && ...))>> + explicit fp4_e2m1_x(Types... v) { + if constexpr (((std::is_same_v, bfloat16>) && ...)) { + const bfloat16 in[N] = {v...}; + CONVERT_TO_FP4(::sycl::detail::fp4_bfloat16_vec2, __bf16, in, BF16); + } else if constexpr (((std::is_same_v, half>) && ...)) { + const sycl::half in[N] = {v...}; + CONVERT_TO_FP4(::sycl::detail::fp4_float16_vec2, _Float16, in, ); + } else { + const float in[N] = {v...}; + if constexpr (N == 1) { + vals[0] = ConvertToFP4(in[0]); + } else { + const uint8_t lo = ConvertToFP4(in[0]); + const uint8_t hi = ConvertToFP4(in[1]); + vals[0] = detail::Fp4Pack(lo, hi); + } + } + } + + // Construct from an array of half, bfloat16, float. + explicit fp4_e2m1_x(sycl::half const (&v)[N], + rounding r = rounding::to_even) { + CheckConstraints(r); + CONVERT_TO_FP4(::sycl::detail::fp4_float16_vec2, _Float16, v, ); + } + + explicit fp4_e2m1_x(bfloat16 const (&v)[N], rounding r = rounding::to_even) { + CheckConstraints(r); + CONVERT_TO_FP4(::sycl::detail::fp4_bfloat16_vec2, __bf16, v, BF16); + } + + explicit fp4_e2m1_x(float const (&v)[N], rounding r = rounding::to_even) { + CheckConstraints(r); + if constexpr (N == 1) { + vals[0] = ConvertToFP4(v[0]); + } else { + const uint8_t lo = ConvertToFP4(v[0]); + const uint8_t hi = ConvertToFP4(v[1]); + vals[0] = detail::Fp4Pack(lo, hi); + } + } + + // Construct from an marray of half, bfloat16, float. + explicit fp4_e2m1_x(const sycl::marray &v, + rounding r = rounding::to_even) { + CheckConstraints(r); + CONVERT_TO_FP4(::sycl::detail::fp4_float16_vec2, _Float16, v, ); + } + + explicit fp4_e2m1_x(const sycl::marray &v, + rounding r = rounding::to_even) { + CheckConstraints(r); + CONVERT_TO_FP4(::sycl::detail::fp4_bfloat16_vec2, __bf16, v, BF16); + } + + explicit fp4_e2m1_x(const sycl::marray &v, + rounding r = rounding::to_even) { + CheckConstraints(r); + if constexpr (N == 1) { + vals[0] = ConvertToFP4(v[0]); + } else { + const uint8_t lo = ConvertToFP4(v[0]); + const uint8_t hi = ConvertToFP4(v[1]); + vals[0] = detail::Fp4Pack(lo, hi); + } + } + + // Construct with stochastic rounding from an array of half, bfloat16. + explicit fp4_e2m1_x([[maybe_unused]] half const (&in)[N], + [[maybe_unused]] const stochastic_seed &seed) { +#ifdef __SYCL_DEVICE_ONLY__ + uint32_t current_seed = *seed.pseed; + uint32_t next_seed = 0; + uint8_t nibbles[2] = {0, 0}; + for (size_t i = 0; i < N; ++i) { + const _Float16 v = sycl::bit_cast<_Float16>(in[i]); + nibbles[i] = __builtin_spirv_StochasticRoundFP16ToE2M1INTEL( + v, current_seed, + sycl::khr::static_addrspace_cast< + sycl::access::address_space::private_space>(&next_seed) + .get_decorated()); + current_seed = next_seed; + next_seed = 0; + } + if constexpr (N == 1) + vals[0] = static_cast(nibbles[0] & 0x0Fu); + else + vals[0] = detail::Fp4Pack(nibbles[0], nibbles[1]); +#else + throw std::runtime_error( + "stochastic rounding constructors are not supported on host"); +#endif + } + + explicit fp4_e2m1_x([[maybe_unused]] bfloat16 const (&in)[N], + [[maybe_unused]] const stochastic_seed &seed) { +#ifdef __SYCL_DEVICE_ONLY__ + uint32_t current_seed = *seed.pseed; + uint32_t next_seed = 0; + uint8_t nibbles[2] = {0, 0}; + for (size_t i = 0; i < N; ++i) { + nibbles[i] = __builtin_spirv_StochasticRoundBF16ToE2M1INTEL( + sycl::bit_cast<__bf16>(in[i]), current_seed, + sycl::khr::static_addrspace_cast< + sycl::access::address_space::private_space>(&next_seed) + .get_decorated()); + current_seed = next_seed; + next_seed = 0; + } + if constexpr (N == 1) + vals[0] = static_cast(nibbles[0] & 0x0Fu); + else + vals[0] = detail::Fp4Pack(nibbles[0], nibbles[1]); +#else + throw std::runtime_error( + "stochastic rounding constructors are not supported on host"); +#endif + } + + // Construct with stochastic rounding from an marray of half, bfloat16. + explicit fp4_e2m1_x([[maybe_unused]] const sycl::marray &in, + [[maybe_unused]] const stochastic_seed &seed) { +#ifdef __SYCL_DEVICE_ONLY__ + uint32_t current_seed = *seed.pseed; + uint32_t next_seed = 0; + uint8_t nibbles[2] = {0, 0}; + for (size_t i = 0; i < N; ++i) { + const _Float16 v = sycl::bit_cast<_Float16>(in[i]); + nibbles[i] = __builtin_spirv_StochasticRoundFP16ToE2M1INTEL( + v, current_seed, + sycl::khr::static_addrspace_cast< + sycl::access::address_space::private_space>(&next_seed) + .get_decorated()); + current_seed = next_seed; + next_seed = 0; + } + if constexpr (N == 1) + vals[0] = static_cast(nibbles[0] & 0x0Fu); + else + vals[0] = detail::Fp4Pack(nibbles[0], nibbles[1]); +#else + throw std::runtime_error( + "stochastic rounding constructors are not supported on host"); +#endif + } + + explicit fp4_e2m1_x([[maybe_unused]] const sycl::marray &in, + [[maybe_unused]] const stochastic_seed &seed) { +#ifdef __SYCL_DEVICE_ONLY__ + uint32_t current_seed = *seed.pseed; + uint32_t next_seed = 0; + uint8_t nibbles[2] = {0, 0}; + for (size_t i = 0; i < N; ++i) { + nibbles[i] = __builtin_spirv_StochasticRoundBF16ToE2M1INTEL( + sycl::bit_cast<__bf16>(in[i]), current_seed, + sycl::khr::static_addrspace_cast< + sycl::access::address_space::private_space>(&next_seed) + .get_decorated()); + current_seed = next_seed; + next_seed = 0; + } + if constexpr (N == 1) + vals[0] = static_cast(nibbles[0] & 0x0Fu); + else + vals[0] = detail::Fp4Pack(nibbles[0], nibbles[1]); +#else + throw std::runtime_error( + "stochastic rounding constructors are not supported on host"); +#endif + } + + // Construct from integer types. Available only when N==1. + template > + explicit fp4_e2m1_x(short val) { + vals[0] = ConvertToFP4(val); + } + template > + explicit fp4_e2m1_x(int val) { + vals[0] = ConvertToFP4(val); + } + template > + explicit fp4_e2m1_x(long val) { + vals[0] = ConvertToFP4(val); + } + template > + explicit fp4_e2m1_x(long long val) { + vals[0] = ConvertToFP4(val); + } + template > + explicit fp4_e2m1_x(unsigned short val) { + vals[0] = ConvertToFP4(val); + } + template > + explicit fp4_e2m1_x(unsigned int val) { + vals[0] = ConvertToFP4(val); + } + template > + explicit fp4_e2m1_x(unsigned long val) { + vals[0] = ConvertToFP4(val); + } + template > + explicit fp4_e2m1_x(unsigned long long val) { + vals[0] = ConvertToFP4(val); + } + + // Assign (operator) from half, bfloat16, float, and integer types. + // Available only when N==1. + template > + fp4_e2m1_x &operator=(sycl::half val) { + vals[0] = ConvertToFP4(val); + return *this; + } + template > + fp4_e2m1_x &operator=(bfloat16 val) { + vals[0] = ConvertBF16ToFP4(val); + return *this; + } + template > + fp4_e2m1_x &operator=(float val) { + vals[0] = ConvertToFP4(val); + return *this; + } + template > + fp4_e2m1_x &operator=(short val) { + vals[0] = ConvertToFP4(val); + return *this; + } + template > + fp4_e2m1_x &operator=(int val) { + vals[0] = ConvertToFP4(val); + return *this; + } + template > + fp4_e2m1_x &operator=(long val) { + vals[0] = ConvertToFP4(val); + return *this; + } + template > + fp4_e2m1_x &operator=(long long val) { + vals[0] = ConvertToFP4(val); + return *this; + } + template > + fp4_e2m1_x &operator=(unsigned short val) { + vals[0] = ConvertToFP4(val); + return *this; + } + template > + fp4_e2m1_x &operator=(unsigned int val) { + vals[0] = ConvertToFP4(val); + return *this; + } + template > + fp4_e2m1_x &operator=(unsigned long val) { + vals[0] = ConvertToFP4(val); + return *this; + } + template > + fp4_e2m1_x &operator=(unsigned long long val) { + vals[0] = ConvertToFP4(val); + return *this; + } + + // Convert to half, bfloat16, float. Available only when N==1. + template > + explicit operator half() const { + return ConvertFromFP4(vals[0] & 0x0Fu); + } + template > + explicit operator bfloat16() const { + return ConvertBF16FromFP4(vals[0] & 0x0Fu); + } + template > + explicit operator float() const { + return ConvertFromFP4(vals[0] & 0x0Fu); + } + + // Convert to integer types. Available only when N==1. + template > + explicit operator char() const { + return ConvertFromFP4Int(vals[0] & 0x0Fu); + } + template > + explicit operator signed char() const { + return ConvertFromFP4Int(vals[0] & 0x0Fu); + } + template > + explicit operator short() const { + return ConvertFromFP4Int(vals[0] & 0x0Fu); + } + template > + explicit operator int() const { + return ConvertFromFP4Int(vals[0] & 0x0Fu); + } + template > + explicit operator long() const { + return ConvertFromFP4Int(vals[0] & 0x0Fu); + } + template > + explicit operator long long() const { + return ConvertFromFP4Int(vals[0] & 0x0Fu); + } + template > + explicit operator unsigned char() const { + return ConvertFromFP4Int(vals[0] & 0x0Fu); + } + template > + explicit operator unsigned short() const { + return ConvertFromFP4Int(vals[0] & 0x0Fu); + } + template > + explicit operator unsigned int() const { + return ConvertFromFP4Int(vals[0] & 0x0Fu); + } + template > + explicit operator unsigned long() const { + return ConvertFromFP4Int(vals[0] & 0x0Fu); + } + template > + explicit operator unsigned long long() const { + return ConvertFromFP4Int(vals[0] & 0x0Fu); + } + + // Convert to bool. Available only when N==1. + // false iff +0 or -0; otherwise true. + template > + explicit operator bool() const { + const uint8_t low = vals[0] & 0x0Fu; + return low != 0x0u && low != 0x8u; + } + + // Convert to marray of half, bfloat16, float. + explicit operator sycl::marray() const { + sycl::marray ret; + if constexpr (N == 1) + ret[0] = ConvertFromFP4(vals[0] & 0x0Fu); + else + ConvertFromFP4_Vec2(ret); + return ret; + } + + explicit operator sycl::marray() const { + sycl::marray ret; + if constexpr (N == 1) + ret[0] = ConvertBF16FromFP4(vals[0] & 0x0Fu); + else + ConvertBF16FromFP4_Vec2(ret); + return ret; + } + + explicit operator sycl::marray() const { + sycl::marray ret; + for (size_t i = 0; i < N; ++i) + ret[i] = detail::ConvertFromFP4ToBinaryFloat_CPU( + detail::Fp4Extract(vals[0], i), rounding::toward_zero); + return ret; + } + + // Intentionally public to allow access to the raw values. + // Element 0 is in the low 4 bits of vals[0]. + // Element 1 (if it exists) is in the high 4 bits of vals[0]. + uint8_t vals[(N + 1) / 2]; +#undef CONVERT_TO_FP4 +}; + +// Deduction guide. Available only when the size of the pack is greater than +// zero. +template fp4_e2m1_x(Ts...) -> fp4_e2m1_x; + +using fp4_e2m1 = fp4_e2m1_x<1>; +using fp4_e2m1_x2 = fp4_e2m1_x<2>; + +} // namespace ext::oneapi::experimental +} // namespace _V1 +} // namespace sycl diff --git a/sycl/source/feature_test.hpp.in b/sycl/source/feature_test.hpp.in index 35db9c3a97c89..7cf62c1a60933 100644 --- a/sycl/source/feature_test.hpp.in +++ b/sycl/source/feature_test.hpp.in @@ -126,6 +126,7 @@ inline namespace _V1 { #define SYCL_KHR_WORK_ITEM_QUERIES 1 #define SYCL_KHR_GROUP_INTERFACE 1 #define SYCL_EXT_ONEAPI_FP8 1 +#define SYCL_EXT_ONEAPI_FP4 1 // Unfinished KHR extensions. These extensions are only available if the // __DPCPP_ENABLE_UNFINISHED_KHR_EXTENSIONS macro is defined. diff --git a/sycl/test-e2e/Experimental/fp4/e2m1_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp4/e2m1_cri_conversion.cpp new file mode 100644 index 0000000000000..d618fd0e856c0 --- /dev/null +++ b/sycl/test-e2e/Experimental/fp4/e2m1_cri_conversion.cpp @@ -0,0 +1,159 @@ +// REQUIRES: intel_feature_gpu_cri +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_INTEL_float4,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{run} SYCL_UR_TRACE=1 %t.out + +// UNSUPPORTED: target-nvidia, target-amd, spirv-backend +// UNSUPPORTED-INTENDED: only supported by backends with CRI driver, and the +// SPIR-V backend does not support the required SPIR-V extensions + +#include + +#include +#include +#include +#include +#include + +using namespace sycl::ext::oneapi::experimental; + +template int test_fp4_simple_type_conversion(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp4_e2m1(static_cast(1.5)); + + queue.single_task([=]() { + fp4_e2m1 value = data[0]; + T f = static_cast(value); + f += static_cast(1.0f); + data[0] = fp4_e2m1(f); + }); + queue.wait_and_throw(); + + // E2M1 representable values: 0, 0.5, 1, 1.5, 2, 3, 4, 6 (and negatives). + // 1.5 + 1.0 = 2.5 -> rounds to either 2 (to_even) or 3. + // We compare via an fp4 round-trip so the expected is well-defined. + fp4_e2m1 expected(static_cast(static_cast(static_cast(1.5)) + + 1.0f)); + T out = static_cast(data[0]); + T expected_out = static_cast(expected); + + sycl::free(data, queue); + if (std::fabs(out - expected_out) > 0.0f) + return 1; + + return 0; +} + +int test_boolean_conversion(sycl::queue &queue, float test_value, + bool expected) { + auto *data = sycl::malloc_shared(1, queue); + auto *res = sycl::malloc_shared(1, queue); + data[0] = fp4_e2m1(test_value); + queue.single_task([=]() { + fp4_e2m1 value = data[0]; + res[0] = static_cast(value); + }); + queue.wait_and_throw(); + int ret = res[0] == expected ? 0 : 1; + sycl::free(data, queue); + sycl::free(res, queue); + return ret; +} + +template +int test_single_element_carray_constructor(sycl::queue &queue) { + T input[1] = {static_cast(1.5f)}; + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp4_e2m1(input); + + queue.single_task([=]() { + fp4_e2m1 value = data[0]; + T output[1] = {static_cast(value) + static_cast(1.0f)}; + data[0] = fp4_e2m1(output); + }); + queue.wait_and_throw(); + + // 1.5 + 1.0 = 2.5; the closest representable values are 2.0 and 3.0, + // round-to-even resolves the tie to 2.0 (frac=0). + fp4_e2m1 expected(static_cast(2.0f)); + T out = static_cast(data[0]); + T expected_out = static_cast(expected); + + sycl::free(data, queue); + if (std::fabs(static_cast(out) - static_cast(expected_out)) > + 0.0f) + return 1; + return 0; +} + +template int test_marray_conversion(sycl::queue &queue) { + sycl::marray input(static_cast(1.5f)); + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp4_e2m1(input); + + queue.single_task([=]() { + fp4_e2m1 value = data[0]; + sycl::marray f = static_cast>(value); + f[0] += static_cast(1.0f); + data[0] = fp4_e2m1(f); + }); + queue.wait_and_throw(); + + sycl::marray expected_input(static_cast(2.0f)); + fp4_e2m1 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + if (std::fabs(out[0] - expected_out[0]) > 0.0f) + return 1; + + return 0; +} + +int main() { + auto async_handler = [](sycl::exception_list exceptions) { + for (const std::exception_ptr &e : exceptions) { + try { + std::rethrow_exception(e); + } catch (const sycl::exception &ex) { + std::cerr << "Async SYCL exception: " << ex.what() << '\n'; + std::terminate(); + } + } + }; + + sycl::queue queue{async_handler}; + + int ret = test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + + // bool conversion: only +0/-0 -> false; everything else -> true (E2M1 has + // no NaN representation). + ret |= test_boolean_conversion(queue, 0.0f, false); + ret |= test_boolean_conversion(queue, -0.0f, false); + ret |= test_boolean_conversion(queue, 1.0f, true); + ret |= test_boolean_conversion(queue, -1.0f, true); + ret |= test_boolean_conversion(queue, 0.5f, true); + + ret |= test_single_element_carray_constructor(queue); + ret |= test_single_element_carray_constructor(queue); + ret |= test_single_element_carray_constructor( + queue); + + ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); + return ret; +} diff --git a/sycl/test-e2e/Experimental/fp4/e2m1_x2_cri_conversion.cpp b/sycl/test-e2e/Experimental/fp4/e2m1_x2_cri_conversion.cpp new file mode 100644 index 0000000000000..0c31e8d456825 --- /dev/null +++ b/sycl/test-e2e/Experimental/fp4/e2m1_x2_cri_conversion.cpp @@ -0,0 +1,323 @@ +// REQUIRES: intel_feature_gpu_cri +// RUN: %{build} -Xclang -freg-struct-return -Xspirv-translator=spir64 --spirv-ext=+SPV_INTEL_fp_conversions,+SPV_INTEL_float4,+SPV_KHR_bfloat16 -o %t.out +// RUN: %{run} SYCL_UR_TRACE=1 %t.out + +// UNSUPPORTED: target-nvidia, target-amd, spirv-backend +// UNSUPPORTED-INTENDED: only supported by backends with CRI driver, and the +// SPIR-V backend does not support the required SPIR-V extensions + +#include + +#include +#include +#include +#include +#include + +using namespace sycl::ext::oneapi::experimental; + +bool equal_with_zero_sign(float actual, float expected) { + if (actual != expected) + return false; + if (expected == 0.0f) + return std::signbit(actual) == std::signbit(expected); + return true; +} + +template +int test_explicit_to_even_carray_constructor(sycl::queue &queue) { + // E2M1 exactly representable: 0.5 and -3.0 (positive subnormal & normal). + T input[2] = {static_cast(0.5f), static_cast(-3.0f)}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp4_e2m1_x2(input, rounding::to_even); + + queue.single_task([=]() { + sycl::marray unpacked = static_cast>(data[0]); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (static_cast(out[0]) != 0.5f) + ret = 1; + if (static_cast(out[1]) != -3.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +template +int test_explicit_to_even_marray_constructor(sycl::queue &queue) { + sycl::marray input(static_cast(1.5f), static_cast(-4.0f)); + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp4_e2m1_x2(input, rounding::to_even); + + queue.single_task([=]() { + sycl::marray unpacked = static_cast>(data[0]); + out[0] = unpacked[0]; + out[1] = unpacked[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (static_cast(out[0]) != 1.5f) + ret = 1; + if (static_cast(out[1]) != -4.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_negative_zero(sycl::queue &queue) { + const float input[2] = {-0.0f, 6.0f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp4_e2m1_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp4_e2m1_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + data[0] = fp4_e2m1_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (!equal_with_zero_sign(out[0], -0.0f)) + ret = 1; + if (out[1] != 6.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_subnormals(sycl::queue &queue) { + // The only positive subnormal is 0.5; -0.5 is the only negative subnormal. + const float input[2] = {0.5f, -0.5f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp4_e2m1_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp4_e2m1_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + data[0] = fp4_e2m1_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 0.5f) + ret = 1; + if (out[1] != -0.5f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_exact_normals(sycl::queue &queue) { + const float input[2] = {6.0f, 1.0f}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp4_e2m1_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp4_e2m1_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + data[0] = fp4_e2m1_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 6.0f) + ret = 1; + if (out[1] != 1.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +int test_boundary_round_trip_saturation_and_infinity_clamp(sycl::queue &queue) { + // 100.0 saturates to +6.0; -inf saturates to -6.0 (no Inf in E2M1). + const float input[2] = {100.0f, -std::numeric_limits::infinity()}; + auto *data = sycl::malloc_shared(1, queue); + auto *out = sycl::malloc_shared(2, queue); + data[0] = fp4_e2m1_x2(input, rounding::to_even); + + queue.single_task([=]() { + fp4_e2m1_x2 value = data[0]; + sycl::marray unpacked = + static_cast>(value); + data[0] = fp4_e2m1_x2(unpacked, rounding::to_even); + sycl::marray round_tripped = + static_cast>(data[0]); + out[0] = round_tripped[0]; + out[1] = round_tripped[1]; + }); + queue.wait_and_throw(); + + int ret = 0; + if (out[0] != 6.0f) + ret = 1; + if (out[1] != -6.0f) + ret = 1; + + sycl::free(data, queue); + sycl::free(out, queue); + return ret; +} + +template int test_fp4_simple_type_conversion(sycl::queue &queue) { + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp4_e2m1_x2(static_cast(1.5f), static_cast(2.0f)); + + queue.single_task([=]() { + fp4_e2m1_x2 value = data[0]; + sycl::marray f = static_cast>(value); + f[0] += static_cast(1.5f); // 1.5 + 1.5 = 3.0 (exactly representable) + f[1] += static_cast(2.0f); // 2.0 + 2.0 = 4.0 (exactly representable) + data[0] = fp4_e2m1_x2(f); + }); + queue.wait_and_throw(); + + sycl::marray expected_input(static_cast(3.0f), static_cast(4.0f)); + fp4_e2m1_x2 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + for (size_t i = 0; i < 2; ++i) { + if (std::fabs(static_cast(out[i]) - + static_cast(expected_out[i])) > 0.0f) + return 1; + } + + return 0; +} + +template int test_marray_conversion(sycl::queue &queue) { + sycl::marray input(static_cast(1.0f), static_cast(2.0f)); + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp4_e2m1_x2(input); + + queue.single_task([=]() { + fp4_e2m1_x2 value = data[0]; + sycl::marray f = static_cast>(value); + f[0] += static_cast(2.0f); // 1+2=3 (exact) + f[1] += static_cast(2.0f); // 2+2=4 (exact) + data[0] = fp4_e2m1_x2(f); + }); + queue.wait_and_throw(); + sycl::marray expected_input(static_cast(3.0f), static_cast(4.0f)); + fp4_e2m1_x2 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + for (size_t i = 0; i < 2; ++i) { + if (std::fabs(static_cast(out[i]) - + static_cast(expected_out[i])) > 0.0f) + return 1; + } + return 0; +} + +template int test_carray_conversion(sycl::queue &queue) { + T input[2] = {static_cast(1.0f), static_cast(2.0f)}; + auto *data = sycl::malloc_shared(1, queue); + data[0] = fp4_e2m1_x2(input); + + queue.single_task([=]() { + fp4_e2m1_x2 value = data[0]; + sycl::marray unpacked = static_cast>(value); + T output[2] = {unpacked[0] + static_cast(2.0f), + unpacked[1] + static_cast(2.0f)}; + data[0] = fp4_e2m1_x2(output); + }); + queue.wait_and_throw(); + + T expected_input[2] = {static_cast(3.0f), static_cast(4.0f)}; + fp4_e2m1_x2 expected(expected_input); + sycl::marray out = static_cast>(data[0]); + sycl::marray expected_out = static_cast>(expected); + + sycl::free(data, queue); + for (size_t i = 0; i < 2; ++i) { + if (std::fabs(static_cast(out[i]) - + static_cast(expected_out[i])) > 0.0f) + return 1; + } + + return 0; +} + +int main() { + auto async_handler = [](sycl::exception_list exceptions) { + for (const std::exception_ptr &e : exceptions) { + try { + std::rethrow_exception(e); + } catch (const sycl::exception &ex) { + std::cerr << "Async SYCL exception: " << ex.what() << '\n'; + std::terminate(); + } + } + }; + + sycl::queue queue{async_handler}; + + // fp4_e2m1_x2 only supports packed conversions through marray, + // marray, and marray. + int ret = test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + ret |= test_fp4_simple_type_conversion(queue); + + ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); + ret |= test_marray_conversion(queue); + + ret |= test_carray_conversion(queue); + ret |= test_carray_conversion(queue); + ret |= test_carray_conversion(queue); + + ret |= test_explicit_to_even_carray_constructor(queue); + ret |= test_explicit_to_even_carray_constructor(queue); + ret |= test_explicit_to_even_carray_constructor( + queue); + + ret |= test_explicit_to_even_marray_constructor(queue); + ret |= test_explicit_to_even_marray_constructor(queue); + ret |= test_explicit_to_even_marray_constructor( + queue); + + ret |= test_boundary_round_trip_negative_zero(queue); + ret |= test_boundary_round_trip_subnormals(queue); + ret |= test_boundary_round_trip_exact_normals(queue); + ret |= test_boundary_round_trip_saturation_and_infinity_clamp(queue); + return ret; +} diff --git a/sycl/test-e2e/Experimental/fp4/lit.local.cfg.py b/sycl/test-e2e/Experimental/fp4/lit.local.cfg.py new file mode 100644 index 0000000000000..060ad61974858 --- /dev/null +++ b/sycl/test-e2e/Experimental/fp4/lit.local.cfg.py @@ -0,0 +1,9 @@ +config.environment["NEOReadDebugKeys"] = "1" +config.environment["ProductFamilyOverride"] = "cri" +config.environment["HardwareInfoOverride"] = "1x8x8" +config.environment["SetCommandStreamReceiver"] = "2" +config.environment["RebuildPrecompiledKernels"] = "1" +config.environment["EnableDirectSubmission"] = "0" +config.environment["EnableBlitterOperationsSupport"] = "1" +config.environment["BlitterEnableMaskOverride"] = "6" +config.environment["Enable64BitAddressing"] = "1" diff --git a/sycl/unittests/Extensions/CMakeLists.txt b/sycl/unittests/Extensions/CMakeLists.txt index 63730d56ed088..d358e8a6eb853 100644 --- a/sycl/unittests/Extensions/CMakeLists.txt +++ b/sycl/unittests/Extensions/CMakeLists.txt @@ -37,3 +37,4 @@ add_subdirectory(KernelQueries) add_subdirectory(InterProcessCommunication) add_subdirectory(DeviceIndex) add_subdirectory(fp8) +add_subdirectory(fp4) diff --git a/sycl/unittests/Extensions/fp4/CMakeLists.txt b/sycl/unittests/Extensions/fp4/CMakeLists.txt new file mode 100644 index 0000000000000..5afdb35eb720d --- /dev/null +++ b/sycl/unittests/Extensions/fp4/CMakeLists.txt @@ -0,0 +1,3 @@ +add_sycl_unittest(FP4TypesTests OBJECT + fp4_e2m1.cpp +) diff --git a/sycl/unittests/Extensions/fp4/fp4_e2m1.cpp b/sycl/unittests/Extensions/fp4/fp4_e2m1.cpp new file mode 100644 index 0000000000000..0d2a768473db8 --- /dev/null +++ b/sycl/unittests/Extensions/fp4/fp4_e2m1.cpp @@ -0,0 +1,663 @@ +#include +#include + +#include +#include +#include +#include + +/* +Unit tests check only CPU versions. Most of the constraints related to device +code thus unit tests check only API +*/ + +using namespace sycl::ext::oneapi::experimental; + +// E2M1 encoding (S.EE.M, bias=1, no Inf/NaN): +// 0x0 = +0, 0x1 = +0.5, 0x2 = +1.0, 0x3 = +1.5, +// 0x4 = +2.0, 0x5 = +3.0, 0x6 = +4.0, 0x7 = +6.0 (max finite), +// negatives: bit 3 set. +// +// Packed x2 storage (one byte): element 0 in low 4 bits, element 1 in high 4 +// bits. + +TEST(FP4E2M1Test, DeductionGuide) { + fp4_e2m1_x one(1.0f); + fp4_e2m1_x pair(1.0f, 2.0f); + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(FP4E2M1Test, TrivialSpecialMembers) { + EXPECT_TRUE((std::is_trivially_default_constructible_v)); + EXPECT_TRUE((std::is_trivially_copy_constructible_v)); + EXPECT_TRUE((std::is_trivially_destructible_v)); + EXPECT_TRUE((std::is_trivially_copy_assignable_v)); + + EXPECT_TRUE((std::is_trivially_default_constructible_v)); + EXPECT_TRUE((std::is_trivially_copy_constructible_v)); + EXPECT_TRUE((std::is_trivially_destructible_v)); + EXPECT_TRUE((std::is_trivially_copy_assignable_v)); + + fp4_e2m1 source(1.0f); + fp4_e2m1 copy(source); + fp4_e2m1 assigned; + assigned = source; + + EXPECT_EQ(copy.vals[0], source.vals[0]); + EXPECT_EQ(assigned.vals[0], source.vals[0]); +} + +TEST(FP4E2M1Test, StorageSize) { + EXPECT_EQ(sizeof(fp4_e2m1{}.vals), 1u); + EXPECT_EQ(sizeof(fp4_e2m1_x2{}.vals), 1u); +} + +TEST(FP4E2M1Test, VariadicHalf) { + fp4_e2m1_x2 a(sycl::half(1.0f), sycl::half(2.0f)); + // element 0 -> low nibble = 0x2; element 1 -> high nibble = 0x4 + EXPECT_EQ(a.vals[0], 0x42); + + fp4_e2m1 b(sycl::half(1.5f)); + EXPECT_EQ(b.vals[0], 0x3); +} + +TEST(FP4E2M1Test, VariadicBFloat16) { + fp4_e2m1_x2 a(sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)); + EXPECT_EQ(a.vals[0], 0x42); + + fp4_e2m1 b(sycl::ext::oneapi::bfloat16(1.5f)); + EXPECT_EQ(b.vals[0], 0x3); +} + +TEST(FP4E2M1Test, VariadicFloat) { + fp4_e2m1_x2 a(1.0f, 2.0f); + EXPECT_EQ(a.vals[0], 0x42); + + fp4_e2m1 b(1.5f); + EXPECT_EQ(b.vals[0], 0x3); +} + +TEST(FP4E2M1Test, VariadicBoundaryEncodingsFloat) { + // Boundaries: max normal, min normal, max/min subnormal, +/-0. + fp4_e2m1_x2 a(6.0f, // max finite -> 0x7 + 1.0f); // min positive normal -> 0x2 + fp4_e2m1_x2 b(0.5f, // only positive subnormal -> 0x1 + 0.0f); // +0 -> 0x0 + fp4_e2m1_x2 c(0.0f, // +0 + -0.0f); + + EXPECT_EQ(a.vals[0], static_cast(0x7 | (0x2 << 4))); + EXPECT_EQ(b.vals[0], static_cast(0x1 | (0x0 << 4))); + EXPECT_EQ(c.vals[0], static_cast(0x0 | (0x8 << 4))); +} + +TEST(FP4E2M1Test, ScalarInfinityClampsToMaxNormalPreservingSign) { + // Spec: non-stochastic conversion clamps Infinity to max normal preserving + // sign (E2M1 has no Inf representation). + fp4_e2m1 pos(std::numeric_limits::infinity()); + fp4_e2m1 neg(-std::numeric_limits::infinity()); + + EXPECT_EQ(pos.vals[0], 0x7); // +6.0 + EXPECT_EQ(neg.vals[0], 0xF); // -6.0 + + EXPECT_EQ(static_cast(pos), 6.0f); + EXPECT_EQ(static_cast(neg), -6.0f); +} + +TEST(FP4E2M1Test, X2InfinityClampsToMaxNormalPreservingSign) { + const float in[2] = {std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}; + fp4_e2m1_x2 value(in); + + EXPECT_EQ(value.vals[0], static_cast(0x7 | (0xF << 4))); + + sycl::marray out = static_cast>(value); + EXPECT_EQ(out[0], 6.0f); + EXPECT_EQ(out[1], -6.0f); +} + +TEST(FP4E2M1Test, ScalarFiniteOverflowClampsToMaxNormalPreservingSign) { + fp4_e2m1 pos(1000.0f); + fp4_e2m1 neg(-1000.0f); + + EXPECT_EQ(pos.vals[0], 0x7); + EXPECT_EQ(neg.vals[0], 0xF); + + EXPECT_EQ(static_cast(pos), 6.0f); + EXPECT_EQ(static_cast(neg), -6.0f); +} + +TEST(FP4E2M1Test, X2FiniteOverflowClampsToMaxNormalPreservingSign) { + const float in[2] = {1000.0f, -1000.0f}; + fp4_e2m1_x2 value(in); + + EXPECT_EQ(value.vals[0], static_cast(0x7 | (0xF << 4))); + + sycl::marray out = static_cast>(value); + EXPECT_EQ(out[0], 6.0f); + EXPECT_EQ(out[1], -6.0f); +} + +TEST(FP4E2M1Test, NaNClampsToMaxNormalPreservingSign) { + // E2M1 has no NaN representation; non-stochastic conversion clamps. + float pos_nan = std::numeric_limits::quiet_NaN(); + float neg_nan = std::copysign(pos_nan, -1.0f); + + fp4_e2m1_x2 a(pos_nan, neg_nan); + EXPECT_EQ(a.vals[0], static_cast(0x7 | (0xF << 4))); +} + +TEST(FP4E2M1Test, IntegerToEvenAndSize) { + // Integer constructors: to_even (CPU host). + fp4_e2m1 a0(0); + fp4_e2m1 a1(1); + fp4_e2m1 a2(2); + fp4_e2m1 an1(-1); + fp4_e2m1 an2(-2); + + EXPECT_EQ(a0.vals[0], 0x0); // +0 + EXPECT_EQ(a1.vals[0], 0x2); // +1.0 + EXPECT_EQ(a2.vals[0], 0x4); // +2.0 + EXPECT_EQ(an1.vals[0], 0xA); // -1.0 + EXPECT_EQ(an2.vals[0], 0xC); // -2.0 +} + +TEST(FP4E2M1Test, IntegerOverflowClampsToMaxNormal) { + fp4_e2m1 big(100); + fp4_e2m1 nbig(-100); + + EXPECT_EQ(big.vals[0], 0x7); + EXPECT_EQ(nbig.vals[0], 0xF); +} + +TEST(FP4E2M1Test, AssignmentOperatorToEvenAndSize) { + fp4_e2m1 a(0.0f); + EXPECT_EQ(a.vals[0], 0x0); + + a = 1.0f; + EXPECT_EQ(a.vals[0], 0x2); + + a = -2.0f; + EXPECT_EQ(a.vals[0], 0xC); + + a = 0.5f; // only positive subnormal + EXPECT_EQ(a.vals[0], 0x1); +} + +TEST(FP4E2M1Test, AssignmentOperatorsAllScalarWidths) { + fp4_e2m1 value(1.0f); + + EXPECT_EQ(&(value = sycl::half(1.0f)), &value); + EXPECT_EQ(static_cast(value), 1.0f); + + EXPECT_EQ(&(value = sycl::ext::oneapi::bfloat16(-3.0f)), &value); + EXPECT_EQ(static_cast(value), -3.0f); + + EXPECT_EQ(&(value = 4.0f), &value); + EXPECT_EQ(static_cast(value), 4.0f); + + EXPECT_EQ(&(value = static_cast(2)), &value); + EXPECT_EQ(static_cast(value), 2.0f); + + EXPECT_EQ(&(value = -3), &value); + EXPECT_EQ(static_cast(value), -3.0f); + + EXPECT_EQ(&(value = 4L), &value); + EXPECT_EQ(static_cast(value), 4.0f); + + EXPECT_EQ(&(value = -2LL), &value); + EXPECT_EQ(static_cast(value), -2.0f); + + EXPECT_EQ(&(value = static_cast(3)), &value); + EXPECT_EQ(static_cast(value), 3.0f); + + EXPECT_EQ(&(value = 2U), &value); + EXPECT_EQ(static_cast(value), 2.0f); + + EXPECT_EQ(&(value = 4UL), &value); + EXPECT_EQ(static_cast(value), 4.0f); + + EXPECT_EQ(&(value = 1ULL), &value); + EXPECT_EQ(static_cast(value), 1.0f); +} + +TEST(FP4E2M1Test, FloatingPointConversionOperators) { + fp4_e2m1 one(1.0f); + fp4_e2m1 zero_pos(0.0f); + fp4_e2m1 zero_neg(-0.0f); + fp4_e2m1 sub(0.5f); // only positive subnormal + + EXPECT_EQ(one.vals[0], 0x2); + EXPECT_EQ(zero_pos.vals[0], 0x0); + EXPECT_EQ(zero_neg.vals[0], 0x8); + EXPECT_EQ(sub.vals[0], 0x1); + + EXPECT_EQ(static_cast(one), 1.0f); + EXPECT_EQ(static_cast(zero_pos), 0.0f); + + float fnz = static_cast(zero_neg); + EXPECT_EQ(fnz, 0.0f); + EXPECT_TRUE(std::signbit(fnz)); + + EXPECT_EQ(static_cast(sub), 0.5f); +} + +TEST(FP4E2M1Test, IntegerConversionOperatorsTowardZero) { + fp4_e2m1 p(1.5f); + fp4_e2m1 n(-1.5f); + + EXPECT_EQ(p.vals[0], 0x3); + EXPECT_EQ(n.vals[0], 0xB); + + int ip = static_cast(p); + int in = static_cast(n); + + EXPECT_EQ(ip, 1); + EXPECT_EQ(in, -1); +} + +TEST(FP4E2M1Test, BoolOperatorZeroRules) { + fp4_e2m1 zp(0.0f); + fp4_e2m1 zn(-0.0f); + fp4_e2m1 one(1.0f); + fp4_e2m1 sub(0.5f); + + EXPECT_FALSE(static_cast(zp)); + EXPECT_FALSE(static_cast(zn)); + EXPECT_TRUE(static_cast(one)); + EXPECT_TRUE(static_cast(sub)); +} + +TEST(FP4E2M1Test, CArrayFloatHostToEvenSaturating) { + const float in[2] = {1.0f, 1.25f}; + const float in1[2] = {1.0625f, 1000.0f}; + const float in2[2] = {-0.0f, 0.0f}; + fp4_e2m1_x2 a(in); + fp4_e2m1_x2 a1(in1); + fp4_e2m1_x2 a2(in2); + + // 1.25 is exactly between 1.0 (frac=0, even) and 1.5 (frac=1, odd). + // round-to-even -> 1.0 -> 0x2. + EXPECT_EQ(a.vals[0], static_cast(0x2 | (0x2 << 4))); + + // 1.0625 closer to 1.0 -> 0x2; 1000.0 -> +6.0 -> 0x7. + EXPECT_EQ(a1.vals[0], static_cast(0x2 | (0x7 << 4))); + + // -0, +0 + EXPECT_EQ(a2.vals[0], static_cast(0x8 | (0x0 << 4))); +} + +TEST(FP4E2M1Test, CArrayHalfHostToEvenSaturating) { + const sycl::half in[2] = {sycl::half(6.0f), sycl::half(7.0f)}; + const sycl::half in1[2] = {sycl::half(1.0f), sycl::half(0.5f)}; + const sycl::half in2[2] = {sycl::half(-0.0f), sycl::half(0.0f)}; + + fp4_e2m1_x2 a(in); + fp4_e2m1_x2 a1(in1); + fp4_e2m1_x2 a2(in2); + + EXPECT_EQ(a.vals[0], static_cast(0x7 | (0x7 << 4))); + EXPECT_EQ(a1.vals[0], static_cast(0x2 | (0x1 << 4))); + EXPECT_EQ(a2.vals[0], static_cast(0x8 | (0x0 << 4))); +} + +TEST(FP4E2M1Test, CArrayBFloat16HostToEvenSaturating) { + const sycl::ext::oneapi::bfloat16 in[2] = { + sycl::ext::oneapi::bfloat16(6.0f), sycl::ext::oneapi::bfloat16(7.0f)}; + const sycl::ext::oneapi::bfloat16 in1[2] = { + sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(0.5f)}; + const sycl::ext::oneapi::bfloat16 in2[2] = { + sycl::ext::oneapi::bfloat16(-0.0f), sycl::ext::oneapi::bfloat16(0.0f)}; + + fp4_e2m1_x2 a(in); + fp4_e2m1_x2 a1(in1); + fp4_e2m1_x2 a2(in2); + + EXPECT_EQ(a.vals[0], static_cast(0x7 | (0x7 << 4))); + EXPECT_EQ(a1.vals[0], static_cast(0x2 | (0x1 << 4))); + EXPECT_EQ(a2.vals[0], static_cast(0x8 | (0x0 << 4))); +} + +TEST(FP4E2M1Test, MarrayAndOperatorsHostAllN) { + sycl::marray in = {1.0f, 2.0f}; + sycl::marray in1 = {0.0f, -0.0f}; + sycl::marray in2 = {6.0f, 1000.0f}; + sycl::marray in3 = {0.5f, -1.5f}; + + fp4_e2m1_x2 a(in); + fp4_e2m1_x2 a1(in1); + fp4_e2m1_x2 a2(in2); + fp4_e2m1_x2 a3(in3); + + EXPECT_EQ(a.vals[0], static_cast(0x2 | (0x4 << 4))); + EXPECT_EQ(a1.vals[0], static_cast(0x0 | (0x8 << 4))); + EXPECT_EQ(a2.vals[0], static_cast(0x7 | (0x7 << 4))); + EXPECT_EQ(a3.vals[0], static_cast(0x1 | (0xB << 4))); + + sycl::marray out = static_cast>(a); + sycl::marray out1 = static_cast>(a1); + sycl::marray out2 = static_cast>(a2); + sycl::marray out3 = static_cast>(a3); + EXPECT_EQ(out[0], 1.0f); + EXPECT_EQ(out[1], 2.0f); + EXPECT_EQ(out1[0], 0.0f); + EXPECT_EQ(out1[1], 0.0f); + EXPECT_TRUE(std::signbit(out1[1])); + EXPECT_EQ(out2[0], 6.0f); + EXPECT_EQ(out2[1], 6.0f); + EXPECT_EQ(out3[0], 0.5f); + EXPECT_EQ(out3[1], -1.5f); +} + +TEST(FP4E2M1Test, FloatingPointConversionOperatorsMoreTypes) { + fp4_e2m1 a(1.0f); + fp4_e2m1 b(0.5f); + + sycl::half ha = static_cast(a); + sycl::ext::oneapi::bfloat16 ba = static_cast(a); + + EXPECT_EQ(static_cast(ha), 1.0f); + EXPECT_EQ(static_cast(ba), 1.0f); + EXPECT_EQ(static_cast(b), 0.5f); +} + +TEST(FP4E2M1Test, MarrayConversionOperatorsHalfNumericValues) { + fp4_e2m1_x2 a(4.0f, -3.0f); + + sycl::marray out = static_cast>(a); + + EXPECT_EQ(static_cast(out[0]), 4.0f); + EXPECT_EQ(static_cast(out[1]), -3.0f); +} + +TEST(FP4E2M1Test, MarrayConversionOperatorsBFloat16NumericValues) { + fp4_e2m1_x2 a(-2.0f, 1.5f); + + sycl::marray out = + static_cast>(a); + + EXPECT_EQ(static_cast(out[0]), -2.0f); + EXPECT_EQ(static_cast(out[1]), 1.5f); +} + +TEST(FP4E2M1Test, IntegerConversionOperatorsMultipleWidthsTowardZero) { + fp4_e2m1 p(1.5f); + fp4_e2m1 n(-1.5f); + + int i = static_cast(p); + short s = static_cast(n); + long l = static_cast(p); + long long ll = static_cast(n); + + EXPECT_EQ(i, 1); + EXPECT_EQ(s, -1); + EXPECT_EQ(l, 1); + EXPECT_EQ(ll, -1); +} + +TEST(FP4E2M1Test, IntegerConversionOperatorsRemainingWidthsTowardZero) { + fp4_e2m1 pos_char(3.0f); + fp4_e2m1 neg_schar(-2.0f); + fp4_e2m1 pos_uchar(4.0f); + fp4_e2m1 pos_ushort(6.0f); + fp4_e2m1 pos_uint(2.0f); + fp4_e2m1 pos_ulong(4.0f); + fp4_e2m1 pos_ull(6.0f); + + char c = static_cast(pos_char); + signed char sc = static_cast(neg_schar); + unsigned char uc = static_cast(pos_uchar); + unsigned short us = static_cast(pos_ushort); + unsigned int ui = static_cast(pos_uint); + unsigned long ul = static_cast(pos_ulong); + unsigned long long ull = static_cast(pos_ull); + + EXPECT_EQ(c, static_cast(3)); + EXPECT_EQ(sc, static_cast(-2)); + EXPECT_EQ(uc, static_cast(4)); + EXPECT_EQ(us, static_cast(6)); + EXPECT_EQ(ui, 2u); + EXPECT_EQ(ul, 4ul); + EXPECT_EQ(ull, 6ull); +} + +TEST(FP4E2M1Test, CArrayFloatRoundingToEven) { + // 1.25 ties between 1.0 (even) and 1.5 (odd) -> to_even = 1.0. + // 1000 saturates to +6.0. + const float in[2] = {1.25f, 1000.0f}; + fp4_e2m1_x2 a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], static_cast(0x2 | (0x7 << 4))); +} + +TEST(FP4E2M1Test, CArrayHalfRoundingToEven) { + const sycl::half in[2] = {sycl::half(1.25f), sycl::half(1000.0f)}; + fp4_e2m1_x2 a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], static_cast(0x2 | (0x7 << 4))); +} + +TEST(FP4E2M1Test, CArrayBFloat16RoundingToEven) { + const sycl::ext::oneapi::bfloat16 in[2] = { + sycl::ext::oneapi::bfloat16(1.25f), + sycl::ext::oneapi::bfloat16(1000.0f)}; + fp4_e2m1_x2 a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], static_cast(0x2 | (0x7 << 4))); +} + +TEST(FP4E2M1Test, MarrayHalfRoundingToEven) { + const sycl::marray in = {sycl::half(1.25f), sycl::half(2.5f)}; + fp4_e2m1_x2 a(in, rounding::to_even); + + // 2.5 ties between 2.0 (even) and 3.0 (odd) -> to_even = 2.0 -> 0x4. + EXPECT_EQ(a.vals[0], static_cast(0x2 | (0x4 << 4))); +} + +TEST(FP4E2M1Test, MarrayBFloat16RoundingToEven) { + const sycl::marray in = { + sycl::ext::oneapi::bfloat16(1.25f), sycl::ext::oneapi::bfloat16(2.5f)}; + fp4_e2m1_x2 a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], static_cast(0x2 | (0x4 << 4))); +} + +TEST(FP4E2M1Test, MarrayFloatRoundingToEven) { + const sycl::marray in = {1.25f, 2.5f}; + fp4_e2m1_x2 a(in, rounding::to_even); + + EXPECT_EQ(a.vals[0], static_cast(0x2 | (0x4 << 4))); +} + +TEST(FP4E2M1Test, VariadicRejectsMixedTypes) { + EXPECT_FALSE((std::is_constructible_v)); + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP4E2M1Test, X2NotConstructibleFromSingleShort) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP4E2M1Test, X2NotConstructibleFromSingleInt) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP4E2M1Test, X2NotConstructibleFromSingleLong) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP4E2M1Test, X2NotConstructibleFromSingleLL) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP4E2M1Test, X2NotConstructibleFromSingleUShort) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP4E2M1Test, X2NotConstructibleFromSingleUInt) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP4E2M1Test, X2NotConstructibleFromSingleUL) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP4E2M1Test, X2NotConstructibleFromSingleULL) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP4E2M1Test, X2NotConstructibleFromSingleFloat) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP4E2M1Test, X2NotConstructibleFromSingleBFloat16) { + EXPECT_FALSE( + (std::is_constructible_v)); +} + +TEST(FP4E2M1Test, X2NotConstructibleFromSingleHalf) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP4E2M1Test, X2NotConstructibleFromSingleChar) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP4E2M1Test, X2NotConstructibleFromSingleUChar) { + EXPECT_FALSE((std::is_constructible_v)); +} + +TEST(FP4E2M1Test, X2NotAssignableFromSingleHalf) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP4E2M1Test, X2NotAssignableFromSingleBFloat16) { + EXPECT_FALSE( + (std::is_assignable_v)); +} + +TEST(FP4E2M1Test, X2NotAssignableFromSingleFloat) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP4E2M1Test, X2NotAssignableFromSingleChar) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP4E2M1Test, X2NotAssignableFromSingleSignedChar) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP4E2M1Test, X2NotAssignableFromSingleUChar) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP4E2M1Test, X2NotAssignableFromSingleShort) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP4E2M1Test, X2NotAssignableFromSingleInt) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP4E2M1Test, X2NotAssignableFromSingleLong) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP4E2M1Test, X2NotAssignableFromSingleLL) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP4E2M1Test, X2NotAssignableFromSingleUShort) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP4E2M1Test, X2NotAssignableFromSingleUInt) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP4E2M1Test, X2NotAssignableFromSingleUL) { + EXPECT_FALSE((std::is_assignable_v)); +} + +TEST(FP4E2M1Test, X2NotAssignableFromSingleULL) { + EXPECT_FALSE((std::is_assignable_v)); +} + +#if LLVM_ENABLE_ASSERTIONS +TEST(FP4E2M1Test, CArrayHalfRejectsTowardZeroRounding) { + const sycl::half in[2] = {sycl::half(1.0f), sycl::half(2.0f)}; + EXPECT_DEATH( + { + fp4_e2m1_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp4_e2m1_x: only rounding::to_even is supported"); +} + +TEST(FP4E2M1Test, CArrayBFloat16RejectsTowardZeroRounding) { + const sycl::ext::oneapi::bfloat16 in[2] = {sycl::ext::oneapi::bfloat16(1.0f), + sycl::ext::oneapi::bfloat16(2.0f)}; + EXPECT_DEATH( + { + fp4_e2m1_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp4_e2m1_x: only rounding::to_even is supported"); +} + +TEST(FP4E2M1Test, CArrayFloatRejectsTowardZeroRounding) { + const float in[2] = {1.0f, 2.0f}; + EXPECT_DEATH( + { + fp4_e2m1_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp4_e2m1_x: only rounding::to_even is supported"); +} + +TEST(FP4E2M1Test, MarrayHalfRejectsTowardZeroRounding) { + const sycl::marray in = {sycl::half(1.0f), sycl::half(2.0f)}; + EXPECT_DEATH( + { + fp4_e2m1_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp4_e2m1_x: only rounding::to_even is supported"); +} + +TEST(FP4E2M1Test, MarrayBFloat16RejectsTowardZeroRounding) { + const sycl::marray in = { + sycl::ext::oneapi::bfloat16(1.0f), sycl::ext::oneapi::bfloat16(2.0f)}; + EXPECT_DEATH( + { + fp4_e2m1_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp4_e2m1_x: only rounding::to_even is supported"); +} + +TEST(FP4E2M1Test, MarrayFloatRejectsTowardZeroRounding) { + const sycl::marray in = {1.0f, 2.0f}; + EXPECT_DEATH( + { + fp4_e2m1_x2 value(in, rounding::toward_zero); + (void)value; + }, + "fp4_e2m1_x: only rounding::to_even is supported"); +} +#endif // LLVM_ENABLE_ASSERTIONS + +TEST(FP4E2M1Test, VariadicFloatReferences) { + float x = 1.0f; + float y = 2.0f; + float &xf = x; + float &yf = y; + + fp4_e2m1_x2 a(xf, yf); + + EXPECT_EQ(a.vals[0], static_cast(0x2 | (0x4 << 4))); +} From c7d6244694a5168bc8d0b67860258d0eac8c41ad Mon Sep 17 00:00:00 2001 From: Dmitry Vodopyanov Date: Thu, 25 Jun 2026 16:37:02 +0200 Subject: [PATCH 2/2] Update --- .../oneapi/experimental/float_4bit/types.hpp | 281 ++++++++---------- 1 file changed, 117 insertions(+), 164 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/float_4bit/types.hpp b/sycl/include/sycl/ext/oneapi/experimental/float_4bit/types.hpp index e7060fab668d9..d4f19da0aa97e 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/float_4bit/types.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/float_4bit/types.hpp @@ -8,19 +8,20 @@ #pragma once +#ifdef __SYCL_DEVICE_ONLY__ #include +#include #include +#endif +#include #include -#include #include #include #include #include -#include #include -#include #include #ifdef __SYCL_DEVICE_ONLY__ @@ -220,8 +221,7 @@ static inline uint8_t ConvertFloatToFP4_CPU(T f, rounding R) noexcept { << Traits::FracBits; constexpr UInt ExpAllOnes = (UInt{1} << Traits::ExpBits) - UInt{1}; - UInt bits; - std::memcpy(&bits, &f, sizeof(bits)); + UInt bits = sycl::bit_cast(f); const uint8_t sign = (bits & SignMask) ? 0x8u : 0x0u; bits &= ~SignMask; @@ -335,64 +335,48 @@ struct Fp4HasFloatTraits::Bias)>> : std::true_type {}; +// E2M1 has only 16 representable values; a lookup table is faster and +// simpler than reconstructing the destination bits for each call. Index is +// the E2M1 nibble (sign bit at 0x8, exponent at 0x6, fraction at 0x1). +static constexpr float kFp4ToFloatTable[16] = { + 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, + -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f, +}; + // CPU host conversion: E2M1 nibble (low 4 bits of `code`) to ToT. template static inline ToT ConvertFromFP4ToBinaryFloat_CPU(uint8_t code, rounding R) noexcept { using Format = FP4E2M1Traits; - constexpr uint8_t SignBit = 0x8u; - constexpr uint8_t ExpAllOnes = Format::ExpAllOnes; - constexpr uint8_t FracMask = Format::MaxFrac; - - const bool negative = (code & SignBit) != 0u; - const uint8_t exp = static_cast((code >> Format::Mbits) & ExpAllOnes); - const uint8_t frac = static_cast(code & FracMask); - - uint32_t significand = 0u; - int exp2 = 0; - - if (exp == 0u) { - if (frac == 0u) { - significand = 0u; - } else { - significand = frac; - exp2 = Format::Emin; - } - } else { - significand = - static_cast((1u << Format::Mbits) + frac); - exp2 = static_cast(exp) - Format::Bias; - } if constexpr (Fp4HasFloatTraits::value) { - using Traits = Fp4SourceTraits; - using UInt = typename Traits::UInt; + (void)R; + return static_cast(kFp4ToFloatTable[code & 0x0Fu]); + } else if constexpr (std::is_integral_v) { + constexpr uint8_t SignBit = 0x8u; + constexpr uint8_t ExpAllOnes = Format::ExpAllOnes; + constexpr uint8_t FracMask = Format::MaxFrac; - constexpr UInt ExpAllOnesDst = ((UInt{1} << Traits::ExpBits) - UInt{1}) - << Traits::FracBits; - constexpr UInt FracMaskDst = (UInt{1} << Traits::FracBits) - UInt{1}; + const bool negative = (code & SignBit) != 0u; + const uint8_t exp = + static_cast((code >> Format::Mbits) & ExpAllOnes); + const uint8_t frac = static_cast(code & FracMask); - UInt bits = 0; - if (significand == 0u) { - bits = - negative ? (UInt{1} << (Traits::ExpBits + Traits::FracBits)) : 0u; + uint32_t significand = 0u; + int exp2 = 0; + + if (exp == 0u) { + if (frac == 0u) { + significand = 0u; + } else { + significand = frac; + exp2 = Format::Emin; + } } else { - const int sigBits = Fp4BitWidth(significand); - const int unbiasedExp = exp2 + sigBits - 1 - Format::Mbits; - const UInt signBit = - negative ? (UInt{1} << (Traits::ExpBits + Traits::FracBits)) : 0u; - - const int shift = static_cast(Traits::FracBits) - (sigBits - 1); - const UInt aligned = static_cast(significand) << shift; - const UInt expField = static_cast(unbiasedExp + Traits::Bias) - << Traits::FracBits; - bits = signBit | expField | (aligned & FracMaskDst); + significand = static_cast((1u << Format::Mbits) + frac); + exp2 = static_cast(exp) - Format::Bias; } - (void)R; - (void)ExpAllOnesDst; - return __builtin_bit_cast(ToT, bits); - } else if constexpr (std::is_integral_v) { using Traits = Fp4IntSourceTraits; using UnsignedT = typename Traits::UnsignedT; @@ -443,6 +427,9 @@ static inline uint8_t Fp4Extract(uint8_t packed, size_t i) noexcept { return static_cast((packed >> (i * 4)) & 0x0Fu); } +// Always-false used to defer static_assert until template instantiation. +template inline constexpr bool kFp4StochasticHostFalse = false; + } // namespace detail template class fp4_e2m1_x { @@ -452,17 +439,7 @@ template class fp4_e2m1_x { template >>> uint8_t ConvertToFP4(T h) { -#ifdef __SYCL_DEVICE_ONLY__ - if constexpr (std::is_same_v, char> || - std::is_same_v, signed char> || - std::is_same_v, unsigned char>) { - const _Float16 v = static_cast<_Float16>(h); - return __builtin_spirv_ClampConvertFP16ToE2M1INTEL(v); - } - return detail::ConvertIntToFP4_CPU(h, rounding::to_even); -#else return detail::ConvertIntToFP4_CPU(h, rounding::to_even); -#endif } uint8_t ConvertToFP4(sycl::half h) { @@ -499,7 +476,12 @@ template class fp4_e2m1_x { #endif } + // Decode an E2M1 nibble to half or float. Only `half` and `float` use this; + // bfloat16 has its own `ConvertBF16FromFP4` so it can take the BF16 builtin + // directly instead of routing through FP16. template T ConvertFromFP4(uint8_t v) const { + static_assert(std::is_same_v || std::is_same_v, + "ConvertFromFP4: T must be sycl::half or float"); #ifdef __SYCL_DEVICE_ONLY__ sycl::half hi = __builtin_spirv_ConvertE2M1ToFP16INTEL(v); return static_cast(hi); @@ -551,21 +533,72 @@ template class fp4_e2m1_x { #endif } - void CheckConstraints(rounding r) const { + static constexpr void CheckConstraints(rounding r) { + // `rounding` is a compile-time enum; ctors always pass a literal. + // The runtime check is left in place for the unlikely caller that + // computes the value, but only `to_even` is supported. assert(r == rounding::to_even && "fp4_e2m1_x: only rounding::to_even is supported"); } - // Store one nibble at element index i (0 or 1). - void StoreNibble(size_t i, uint8_t nibble) { - if (i == 0) - vals[0] = static_cast((vals[0] & 0xF0u) | (nibble & 0x0Fu)); + // Pack two scalar floats (or marray elements) into vals[0]. + // Used by the float-only ctors which cannot use the FP16/BF16 vec2 + // builtins. + void PackFloats(const float *v) { + if constexpr (N == 1) { + vals[0] = ConvertToFP4(v[0]); + } else { + const uint8_t lo = ConvertToFP4(v[0]); + const uint8_t hi = ConvertToFP4(v[1]); + vals[0] = detail::Fp4Pack(lo, hi); + } + } + +#ifdef __SYCL_DEVICE_ONLY__ + // Stochastic rounding loops shared by the array and marray ctors. + // The seed referenced by `seed.pseed` is updated to the final next-seed + // value as required by the spec. + void StochasticFromHalf(const half *in, const stochastic_seed &seed) { + uint32_t current_seed = *seed.pseed; + uint32_t next_seed; + uint8_t nibbles[2] = {0, 0}; + for (size_t i = 0; i < N; ++i) { + const _Float16 v = sycl::bit_cast<_Float16>(in[i]); + nibbles[i] = __builtin_spirv_StochasticRoundFP16ToE2M1INTEL( + v, current_seed, + sycl::khr::static_addrspace_cast< + sycl::access::address_space::private_space>(&next_seed) + .get_decorated()); + current_seed = next_seed; + } + *seed.pseed = current_seed; + if constexpr (N == 1) + vals[0] = static_cast(nibbles[0] & 0x0Fu); else - vals[0] = static_cast((vals[0] & 0x0Fu) | - (static_cast(nibble & 0x0Fu) - << 4)); + vals[0] = detail::Fp4Pack(nibbles[0], nibbles[1]); } + void StochasticFromBFloat16(const bfloat16 *in, + const stochastic_seed &seed) { + uint32_t current_seed = *seed.pseed; + uint32_t next_seed; + uint8_t nibbles[2] = {0, 0}; + for (size_t i = 0; i < N; ++i) { + nibbles[i] = __builtin_spirv_StochasticRoundBF16ToE2M1INTEL( + sycl::bit_cast<__bf16>(in[i]), current_seed, + sycl::khr::static_addrspace_cast< + sycl::access::address_space::private_space>(&next_seed) + .get_decorated()); + current_seed = next_seed; + } + *seed.pseed = current_seed; + if constexpr (N == 1) + vals[0] = static_cast(nibbles[0] & 0x0Fu); + else + vals[0] = detail::Fp4Pack(nibbles[0], nibbles[1]); + } +#endif // __SYCL_DEVICE_ONLY__ + #ifdef __SYCL_DEVICE_ONLY__ #define CONVERT_TO_FP4(VecType, CastType, in, Prefix) \ if constexpr (N == 1) { \ @@ -609,13 +642,7 @@ template class fp4_e2m1_x { CONVERT_TO_FP4(::sycl::detail::fp4_float16_vec2, _Float16, in, ); } else { const float in[N] = {v...}; - if constexpr (N == 1) { - vals[0] = ConvertToFP4(in[0]); - } else { - const uint8_t lo = ConvertToFP4(in[0]); - const uint8_t hi = ConvertToFP4(in[1]); - vals[0] = detail::Fp4Pack(lo, hi); - } + PackFloats(in); } } @@ -633,13 +660,7 @@ template class fp4_e2m1_x { explicit fp4_e2m1_x(float const (&v)[N], rounding r = rounding::to_even) { CheckConstraints(r); - if constexpr (N == 1) { - vals[0] = ConvertToFP4(v[0]); - } else { - const uint8_t lo = ConvertToFP4(v[0]); - const uint8_t hi = ConvertToFP4(v[1]); - vals[0] = detail::Fp4Pack(lo, hi); - } + PackFloats(&v[0]); } // Construct from an marray of half, bfloat16, float. @@ -658,64 +679,27 @@ template class fp4_e2m1_x { explicit fp4_e2m1_x(const sycl::marray &v, rounding r = rounding::to_even) { CheckConstraints(r); - if constexpr (N == 1) { - vals[0] = ConvertToFP4(v[0]); - } else { - const uint8_t lo = ConvertToFP4(v[0]); - const uint8_t hi = ConvertToFP4(v[1]); - vals[0] = detail::Fp4Pack(lo, hi); - } + PackFloats(&v[0]); } // Construct with stochastic rounding from an array of half, bfloat16. explicit fp4_e2m1_x([[maybe_unused]] half const (&in)[N], [[maybe_unused]] const stochastic_seed &seed) { #ifdef __SYCL_DEVICE_ONLY__ - uint32_t current_seed = *seed.pseed; - uint32_t next_seed = 0; - uint8_t nibbles[2] = {0, 0}; - for (size_t i = 0; i < N; ++i) { - const _Float16 v = sycl::bit_cast<_Float16>(in[i]); - nibbles[i] = __builtin_spirv_StochasticRoundFP16ToE2M1INTEL( - v, current_seed, - sycl::khr::static_addrspace_cast< - sycl::access::address_space::private_space>(&next_seed) - .get_decorated()); - current_seed = next_seed; - next_seed = 0; - } - if constexpr (N == 1) - vals[0] = static_cast(nibbles[0] & 0x0Fu); - else - vals[0] = detail::Fp4Pack(nibbles[0], nibbles[1]); + StochasticFromHalf(&in[0], seed); #else - throw std::runtime_error( - "stochastic rounding constructors are not supported on host"); + static_assert(detail::kFp4StochasticHostFalse, + "stochastic rounding constructors are not supported on host"); #endif } explicit fp4_e2m1_x([[maybe_unused]] bfloat16 const (&in)[N], [[maybe_unused]] const stochastic_seed &seed) { #ifdef __SYCL_DEVICE_ONLY__ - uint32_t current_seed = *seed.pseed; - uint32_t next_seed = 0; - uint8_t nibbles[2] = {0, 0}; - for (size_t i = 0; i < N; ++i) { - nibbles[i] = __builtin_spirv_StochasticRoundBF16ToE2M1INTEL( - sycl::bit_cast<__bf16>(in[i]), current_seed, - sycl::khr::static_addrspace_cast< - sycl::access::address_space::private_space>(&next_seed) - .get_decorated()); - current_seed = next_seed; - next_seed = 0; - } - if constexpr (N == 1) - vals[0] = static_cast(nibbles[0] & 0x0Fu); - else - vals[0] = detail::Fp4Pack(nibbles[0], nibbles[1]); + StochasticFromBFloat16(&in[0], seed); #else - throw std::runtime_error( - "stochastic rounding constructors are not supported on host"); + static_assert(detail::kFp4StochasticHostFalse, + "stochastic rounding constructors are not supported on host"); #endif } @@ -723,51 +707,20 @@ template class fp4_e2m1_x { explicit fp4_e2m1_x([[maybe_unused]] const sycl::marray &in, [[maybe_unused]] const stochastic_seed &seed) { #ifdef __SYCL_DEVICE_ONLY__ - uint32_t current_seed = *seed.pseed; - uint32_t next_seed = 0; - uint8_t nibbles[2] = {0, 0}; - for (size_t i = 0; i < N; ++i) { - const _Float16 v = sycl::bit_cast<_Float16>(in[i]); - nibbles[i] = __builtin_spirv_StochasticRoundFP16ToE2M1INTEL( - v, current_seed, - sycl::khr::static_addrspace_cast< - sycl::access::address_space::private_space>(&next_seed) - .get_decorated()); - current_seed = next_seed; - next_seed = 0; - } - if constexpr (N == 1) - vals[0] = static_cast(nibbles[0] & 0x0Fu); - else - vals[0] = detail::Fp4Pack(nibbles[0], nibbles[1]); + StochasticFromHalf(&in[0], seed); #else - throw std::runtime_error( - "stochastic rounding constructors are not supported on host"); + static_assert(detail::kFp4StochasticHostFalse, + "stochastic rounding constructors are not supported on host"); #endif } explicit fp4_e2m1_x([[maybe_unused]] const sycl::marray &in, [[maybe_unused]] const stochastic_seed &seed) { #ifdef __SYCL_DEVICE_ONLY__ - uint32_t current_seed = *seed.pseed; - uint32_t next_seed = 0; - uint8_t nibbles[2] = {0, 0}; - for (size_t i = 0; i < N; ++i) { - nibbles[i] = __builtin_spirv_StochasticRoundBF16ToE2M1INTEL( - sycl::bit_cast<__bf16>(in[i]), current_seed, - sycl::khr::static_addrspace_cast< - sycl::access::address_space::private_space>(&next_seed) - .get_decorated()); - current_seed = next_seed; - next_seed = 0; - } - if constexpr (N == 1) - vals[0] = static_cast(nibbles[0] & 0x0Fu); - else - vals[0] = detail::Fp4Pack(nibbles[0], nibbles[1]); + StochasticFromBFloat16(&in[0], seed); #else - throw std::runtime_error( - "stochastic rounding constructors are not supported on host"); + static_assert(detail::kFp4StochasticHostFalse, + "stochastic rounding constructors are not supported on host"); #endif }