diff --git a/README.md b/README.md index 2fad7de2..def5848f 100644 --- a/README.md +++ b/README.md @@ -197,7 +197,23 @@ Tile IR operations. | Operation | Description | |-----------|-------------| | `a * b` | Matrix multiplication: `a @ b` | -| `muladd(a, b, acc)` | Matrix multiply-accumulate: `a * b + acc` | +| `muladd(a, b, acc; fast_acc=false)` | Matrix multiply-accumulate: `a * b + acc` | +| `ct.muladd_scaled(a, a_scale, b, b_scale, acc)` | Block-scaled multiply-accumulate | + +Each operation follows the shape rules of `Base.:*` / `Base.muladd` and additionally allows trailing batch dimensions (`ct.muladd_scaled` does not support vec-mat). + +`fast_acc=true` enables fast accumulation and is valid only for FP8 inputs (other input dtypes raise an error). It has an effect only on Hopper +(sm_90) and is silently ignored on other architectures; it requires Tile IR v13.3+. + +`ct.muladd_scaled` multiplies each operand by a low-precision block scale before the matmul: +each scale element covers a contiguous block of `B = K ÷ K_s` elements along the K dimension. +Requires Blackwell. The supported operand/scale/accumulator dtypes and block sizes are: + +| Input (`a`/`b`) | Scale | Acc/Output | B | +|-----------------|-------|------------|--------| +| `Float8_E4M3FN`, `Float8_E5M2` | `Float8_E8M0FNU` | `Float32` | 32 | +| `Float4_E2M1FN` | `Float8_E8M0FNU` | `Float32` | 16, 32 | +| `Float4_E2M1FN` | `Float8_E4M3FN` | `Float32` | 16 | ### Higher-Order Functions | Operation | Description | diff --git a/ext/DLFP8TypesExt.jl b/ext/DLFP8TypesExt.jl index bd226633..841b44c3 100644 --- a/ext/DLFP8TypesExt.jl +++ b/ext/DLFP8TypesExt.jl @@ -12,6 +12,16 @@ function ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E5M2}) return ct.F8E5M2(table) end +# Non-scaled `mma`/`matmul` (`cuda_tile.mmaf`) accepts f8e4m3fn and f8e5m2 +# operands with an f16 or f32 accumulator (f16 first/preferred), mirroring +# cuda-tile's mmaf type table and cutile-python's `_mma_supported_dtypes`. 
+ct.mma_allowed_acc_dtypes(::Type{Float8_E4M3FN}) = (Float16, Float32) +ct.mma_allowed_acc_dtypes(::Type{Float8_E5M2}) = (Float16, Float32) + +# `fast_acc` (lower-precision MMA accumulation) is an FP8-only throughput hint. +ct.mma_supports_fast_acc(::Type{Float8_E4M3FN}) = true +ct.mma_supports_fast_acc(::Type{Float8_E5M2}) = true + # Float ↔ FP8 scalar constructor overlays (for map/convert dispatch) const FP8Types = (Float8_E4M3FN, Float8_E5M2) const StandardFloats = (Float16, ct.BFloat16, Float32, ct.TFloat32, Float64) diff --git a/ext/MicrofloatsExt.jl b/ext/MicrofloatsExt.jl index 230cc7f9..09c7b812 100644 --- a/ext/MicrofloatsExt.jl +++ b/ext/MicrofloatsExt.jl @@ -23,6 +23,19 @@ ct.bitwidth(::Type{T}) where {T<:Microfloats.Microfloat} = Microfloats.bitwidth( # nearest-even on f32→E8M0FNU (only `zero` and `positive_inf` are valid). ct.ftof_rounding_mode(::Type{Float8_E8M0FNU}) = ct.RoundingMode.Zero +# Non-scaled `mma`/`matmul` (`cuda_tile.mmaf`) accepts f8e4m3fn and f8e5m2 +# operands with an f16 or f32 accumulator — f16 first/preferred, mirroring +# cuda-tile's mmaf type table and cutile-python's `_mma_supported_dtypes`. The +# other Microfloats formats (E8M0FNU, Float4_E2M1FN) are only valid as +# scaled-mma operands/scales, so they stay absent and fall through to `mma`'s +# unsupported-dtype error. +ct.mma_allowed_acc_dtypes(::Type{Float8_E4M3FN}) = (Float16, Float32) +ct.mma_allowed_acc_dtypes(::Type{Float8_E5M2}) = (Float16, Float32) + +# `fast_acc` (lower-precision MMA accumulation) is an FP8-only throughput hint. +ct.mma_supports_fast_acc(::Type{Float8_E4M3FN}) = true +ct.mma_supports_fast_acc(::Type{Float8_E5M2}) = true + # Float ↔ microfloat scalar constructor overlays (for map/convert dispatch). 
# Mirrors DLFP8TypesExt: route to `Intrinsics.ftof` so kernel-side conversions # lower to the FToFOp Tile IR intrinsic instead of Microfloats' Float32-fallback diff --git a/src/bytecode/encodings.jl b/src/bytecode/encodings.jl index 81946912..fddf37cc 100644 --- a/src/bytecode/encodings.jl +++ b/src/bytecode/encodings.jl @@ -98,6 +98,8 @@ module Opcode const Atan2Op = 110 # since 13.2 const PackOp = 111 # since 13.3 const UnpackOp = 112 # since 13.3 + # 113 (AllocaOp) not implemented + const MmaFScaledOp = 114 # since 13.3 end # Enums for operation attributes @@ -1310,6 +1312,30 @@ function encode_MmaIOp!(cb::CodeBuilder, result_type::TypeId, return new_op!(cb) end +""" + encode_MmaFScaledOp!(cb, result_type, lhs, rhs, acc, lhs_scale, rhs_scale) -> Value + +Block-scaled float matrix multiply-accumulate (`acc + (lhs ⊙ lhs_scale) @ +(rhs ⊙ rhs_scale)`) for low-precision (f8/f4) inputs. Each scale element scales +a contiguous block of K-dimension elements in its operand. Requires Tile IR +v13.3+. 
+Opcode: 114 +""" +function encode_MmaFScaledOp!(cb::CodeBuilder, result_type::TypeId, + lhs::Value, rhs::Value, acc::Value, + lhs_scale::Value, rhs_scale::Value) + cb.version >= v"13.3" || + throw(IRError("MmaFScaledOp requires Tile IR v13.3+, got v$(cb.version)")) + encode_varint!(cb.buf, Opcode.MmaFScaledOp) + encode_typeid!(cb.buf, result_type) + encode_operand!(cb.buf, lhs) + encode_operand!(cb.buf, rhs) + encode_operand!(cb.buf, acc) + encode_operand!(cb.buf, lhs_scale) + encode_operand!(cb.buf, rhs_scale) + return new_op!(cb) +end + #============================================================================= Integer arithmetic operations =============================================================================# diff --git a/src/compiler/intrinsics/core.jl b/src/compiler/intrinsics/core.jl index 6ebc6058..0aeddcbc 100644 --- a/src/compiler/intrinsics/core.jl +++ b/src/compiler/intrinsics/core.jl @@ -415,6 +415,17 @@ mma_allowed_acc_dtypes(::Type{TFloat32}) = (Float32,) mma_allowed_acc_dtypes(::Type{Float64}) = (Float64,) mma_allowed_acc_dtypes(@nospecialize(::Type)) = nothing +""" + mma_supports_fast_acc(input_T) -> Bool + +Whether `mma`'s `fast_acc` hint (lower-precision accumulation for throughput) +is valid for operand element type `input_T`. Only the FP8 dtypes qualify; +extensions (Microfloats/DLFP8Types) register them by overloading this, like +[`mma_allowed_acc_dtypes`](@ref). Mirrors cuTile Python's +`use_fast_acc is only supported for fp8 input dtypes` check. +""" +mma_supports_fast_acc(@nospecialize(::Type)) = false + # First (preferred) accumulator dtype for a given input dtype. 
Used by # matmul to pick `acc = zeros(first_allowed_acc(T), …)` so that the input # dtype constraint and tileiras's acc-dtype constraint stay consistent @@ -425,7 +436,7 @@ first_allowed_acc_dtype(::Type{T}) where {T} = end """ - Intrinsics.mma(a::Tile, b::Tile, acc::Tile) -> typeof(acc) + Intrinsics.mma(a::Tile, b::Tile, acc::Tile, fast_acc::Bool=false) -> typeof(acc) Matrix-multiply-accumulate computing `a*b + acc`. Dispatches at codegen based on element types: @@ -436,15 +447,23 @@ based on element types: - `i8` `a`/`b` with `i32` `acc` lower to `cuda_tile.mmai`. Per-input signedness is derived from the Julia type (`Int8` → signed, `UInt8` → unsigned); `acc` and the result are always signed `i32`. + +`fast_acc` enables fast accumulation (trading accumulator precision for +throughput). It is only valid for FP8 inputs (see +[`mma_supports_fast_acc`](@ref)) and requires Tile IR v13.3+. """ -@intrinsic mma(a::Tile, b::Tile, acc::Tile) -tfunc(𝕃, ::typeof(Intrinsics.mma), @nospecialize(a), @nospecialize(b), @nospecialize(acc)) = CC.widenconst(acc) +@intrinsic mma(a::Tile, b::Tile, acc::Tile, fast_acc::Bool=false) +tfunc(𝕃, ::typeof(Intrinsics.mma), @nospecialize(a), @nospecialize(b), @nospecialize(acc), + @nospecialize(rest...)) = CC.widenconst(acc) function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma), args) cb = ctx.cb lhs = emit_value!(ctx, args[1]) rhs = emit_value!(ctx, args[2]) acc = emit_value!(ctx, args[3]) + fast_acc = length(args) >= 4 ? + (@something get_constant(ctx, args[4]) throw(IRError("mma: fast_acc must be a compile-time constant"))) : + false (lhs === nothing || rhs === nothing || acc === nothing) && throw(IRError("Cannot resolve operands for mma()")) @@ -458,15 +477,24 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma), args) # the table in cuTile Python's mma implementation. 
lhs_elem === rhs_elem || throw(IRError("mma: float lhs and rhs must share dtype, got lhs=$lhs_elem, rhs=$rhs_elem")) - allowed_acc = mma_allowed_acc_dtypes(lhs_elem) + # `invokelatest` so extension-defined acc-dtype tables (e.g. FP8 via the + # Microfloats/DLFP8Types exts) are visible from codegen's world age, + # mirroring `lookup_bitwidth`. + allowed_acc = Base.invokelatest(mma_allowed_acc_dtypes, lhs_elem) allowed_acc === nothing && throw(IRError("mma: unsupported float input dtype $lhs_elem")) acc_elem in allowed_acc || throw(IRError("mma: acc dtype $acc_elem is not allowed for input dtype " * "$lhs_elem; tileiras requires acc ∈ $allowed_acc")) - encode_MmaFOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v) + # fast_acc is an FP8-only throughput hint; reject it elsewhere with a + # clear error rather than emitting IR tileiras would reject. + fast_acc && !Base.invokelatest(mma_supports_fast_acc, lhs_elem) && + throw(IRError("mma: fast_acc is only supported for fp8 input dtypes " * + "(f8e4m3fn, f8e5m2), got $lhs_elem")) + encode_MmaFOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v; fast_acc) elseif lhs_elem <: Union{Int8, UInt8} && rhs_elem <: Union{Int8, UInt8} && acc_elem === Int32 + fast_acc && throw(IRError("mma: fast_acc is not supported for integer (mmai) inputs")) s_lhs = lhs_elem <: Signed ? Signedness.Signed : Signedness.Unsigned s_rhs = rhs_elem <: Signed ? Signedness.Signed : Signedness.Unsigned encode_MmaIOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v; @@ -480,6 +508,57 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma), args) CGVal(result, acc.type_id, acc.jltype, acc.shape) end +""" + Intrinsics.mma_scaled(lhs, lhs_scale, rhs, rhs_scale, acc) -> typeof(acc) + +Block-scaled matrix-multiply-accumulate computing `(lhs ⊙ lhs_scale) * (rhs ⊙ +rhs_scale) + acc`, where each scale element multiplies a contiguous block of +`lhs`/`rhs` elements along the K dimension. Lowers to `cuda_tile.mmaf_scaled` +(Tile IR v13.3+). 
+ +`lhs`/`rhs` are low-precision floats (`f8e4m3fn`, `f8e5m2`, or `f4e2m1fn`), +`lhs_scale`/`rhs_scale` are `f8e8m0fnu` or `f8e4m3fn`, and `acc`/result are +`f32`. The block size `K ÷ K_s` and the exact (operand, scale) dtype pairing are +validated by tileiras (see its `mmaf_scaled` verifier for the supported table). +""" +@intrinsic mma_scaled(lhs, lhs_scale, rhs, rhs_scale, acc) +tfunc(𝕃, ::typeof(Intrinsics.mma_scaled), @nospecialize(lhs), @nospecialize(lhs_scale), + @nospecialize(rhs), @nospecialize(rhs_scale), @nospecialize(acc)) = CC.widenconst(acc) +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma_scaled), args) + cb = ctx.cb + + lhs = emit_value!(ctx, args[1]) + lhs_scale = emit_value!(ctx, args[2]) + rhs = emit_value!(ctx, args[3]) + rhs_scale = emit_value!(ctx, args[4]) + acc = emit_value!(ctx, args[5]) + + (lhs === nothing || lhs_scale === nothing || rhs === nothing || + rhs_scale === nothing || acc === nothing) && + throw(IRError("Cannot resolve operands for mma_scaled()")) + + lhs_elem = eltype(CC.widenconst(lhs.jltype)) + rhs_elem = eltype(CC.widenconst(rhs.jltype)) + lhs_scale_elem = eltype(CC.widenconst(lhs_scale.jltype)) + rhs_scale_elem = eltype(CC.widenconst(rhs_scale.jltype)) + acc_elem = eltype(CC.widenconst(acc.jltype)) + + # Structural invariants (cheap, with clear errors); the operand/scale dtype + # pairing and block-size constraints are checked by the tileiras verifier, + # whose messages are already precise. 
+ lhs_elem === rhs_elem || + throw(IRError("mma_scaled: lhs and rhs must share dtype, got lhs=$lhs_elem, rhs=$rhs_elem")) + lhs_scale_elem === rhs_scale_elem || + throw(IRError("mma_scaled: lhs_scale and rhs_scale must share dtype, " * + "got $lhs_scale_elem and $rhs_scale_elem")) + acc_elem === Float32 || + throw(IRError("mma_scaled: acc must be Float32, got $acc_elem")) + + result = encode_MmaFScaledOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v, + lhs_scale.v, rhs_scale.v) + CGVal(result, acc.type_id, acc.jltype, acc.shape) +end + # TODO: cuda_tile.module """ diff --git a/src/language/operations.jl b/src/language/operations.jl index db82f526..3a729b9c 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -1295,45 +1295,60 @@ end Matrix multiplication =============================================================================# -# Matrix multiply-accumulate: muladd(a, b, acc) = a * b + acc -# Handles 1D promotion, type promotion, and batched dims (≥3D). -# Note: SA, SB, SC type parameters required to avoid ambiguity with scalar methods during codegen -@inline function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}) where {T1, T2, T3, SA, SB, SC} - _muladd(a, b, acc, Val(ndims(a)), Val(ndims(b))) +""" + muladd(a::Tile, b::Tile, acc::Tile; fast_acc::Bool=false) -> Tile + +Matrix multiply-accumulate `a * b + acc` over tiles, lowering to +`cuda_tile.mmaf` (float) or `cuda_tile.mmai` (`i8 × i8 → i32`). + +`a`/`b` are 2-D matrices `(M, K)` × `(K, N)` → `(M, N)`; a 1-D operand is +promoted (vec-mat / mat-vec) and any trailing dimensions (≥3-D) are treated as +broadcast batch dims, lifting `Base.muladd`'s shape rules to tiles. `acc` +carries the result dtype, which must be one tileiras allows for the input dtype +(f16/f32 for f16 and f8; f32 for bf16/tf32; f64 for f64; i32 for i8). + +`fast_acc` enables fast accumulation (lower accumulator precision for +throughput); it is valid only for FP8 inputs and requires Tile IR v13.3+. 
+""" +@inline function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}; + fast_acc::Bool=false) where {T1, T2, T3, SA, SB, SC} + # SA, SB, SC type parameters avoid ambiguity with the scalar `muladd` + # methods during codegen. + _muladd(a, b, acc, Val(ndims(a)), Val(ndims(b)), fast_acc) end # 2D × 2D: MmaFOp with swapped operands for row-major Tile IR # Julia (M,K)*(K,N) → TileIR (K,M)*(N,K) → mmaf(b,a,acc) → TileIR (N,M) → Julia (M,N) -@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{2}) - Intrinsics.mma(b, a, acc) +@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{2}, fast_acc::Bool) + Intrinsics.mma(b, a, acc, fast_acc) end # Vec-mat (1D × 2D): reshape (M,) → (M, 1), MmaFOp, acc is already (M, N) -@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{1}, ::Val{2}) +@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{1}, ::Val{2}, fast_acc::Bool) a2d = reshape(a, (size(a, 1), 1)) - _muladd(a2d, b, acc, Val(2), Val(2)) + _muladd(a2d, b, acc, Val(2), Val(2), fast_acc) end # Mat-vec (2D × 1D): reshape b (K,) → (K, 1), acc (M,) → (M, 1), MmaFOp, squeeze back -@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{1}) +@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{1}, fast_acc::Bool) M, K = size(a, 1), size(b, 1) b2d = reshape(b, (K, 1)) acc2d = reshape(acc, (M, 1)) - result = _muladd(a, b2d, acc2d, Val(2), Val(2)) + result = _muladd(a, b2d, acc2d, Val(2), Val(2), fast_acc) reshape(result, (M,)) end # Vec-vec (1D × 1D): not supported -@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1}) +@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1}, ::Bool) return :(throw(ArgumentError("Vector-vector multiply-accumulate is not supported."))) end # Batched mat-vec / vec-mat (≥3D × 1D or 1D × ≥3D): not supported, unsqueeze manually -@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB} +@generated 
function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}, ::Bool) where {NB} NB >= 3 || return :(throw(ArgumentError("unreachable"))) return :(throw(ArgumentError("Batched vec-mat is not supported. Reshape the 1D operand to 2D first."))) end -@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA} +@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}, ::Bool) where {NA} NA >= 3 || return :(throw(ArgumentError("unreachable"))) return :(throw(ArgumentError("Batched mat-vec is not supported. Reshape the 1D operand to 2D first."))) end @@ -1346,7 +1361,7 @@ end # 3. MmaFOp with swapped operands: mmaf(b, a, acc) # 4. Unflatten batch dims via reshape @generated function _muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}, - ::Val{NA}, ::Val{NB}) where {T1, T2, T3, SA, SB, SC, NA, NB} + ::Val{NA}, ::Val{NB}, fast_acc::Bool) where {T1, T2, T3, SA, SB, SC, NA, NB} sa = Tuple(SA.parameters) sb = Tuple(SB.parameters) @@ -1373,12 +1388,125 @@ end b_3d = reshape(b_bc, $((K, N, B_flat))) acc_3d = reshape(acc_bc, $((M, N, B_flat))) # MmaFOp with swapped operands for row-major convention - result_3d = Intrinsics.mma(b_3d, a_3d, acc_3d) + result_3d = Intrinsics.mma(b_3d, a_3d, acc_3d, fast_acc) # Unflatten batch dims reshape(result_3d, $((M, N, batch_shape...))) end end +#============================================================================= + Block-scaled matrix multiply-accumulate +=============================================================================# + +public muladd_scaled + +""" + muladd_scaled(a, a_scale, b, b_scale, acc) -> Tile + +Block-scaled matrix multiply-accumulate `(a ⊙ a_scale) * (b ⊙ b_scale) + acc`, +lowering to `cuda_tile.mmaf_scaled` (Tile IR v13.3+, Blackwell). Each scale +element multiplies a contiguous block of `B = K ÷ K_s` elements along the K +dimension of its operand, so `a_scale`/`b_scale` match `a`/`b` in every +dimension except K, where they have `K_s ≤ K` entries. 
+ +`a`/`b` are low-precision floats (`f8e4m3fn`, `f8e5m2`, or `f4e2m1fn`), +`a_scale`/`b_scale` are `f8e8m0fnu` or `f8e4m3fn`, and `acc` is `f32`. Shapes +follow [`muladd`](@ref): 2-D `(M, K)` × `(K, N)`, mat-vec, and trailing batch +dims; vec-mat is unsupported (it would collapse K, leaving nothing to scale). +""" +@inline function muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile, b::Tile{Tb, SB}, b_scale::Tile, + acc::Tile) where {Ta, Tb, SA, SB} + _muladd_scaled(a, a_scale, b, b_scale, acc, Val(ndims(a)), Val(ndims(b))) +end + +# 2D × 2D: swap operands (and their scales) for row-major Tile IR, exactly as +# `_muladd` swaps for `mma`. +@inline function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile, + ::Val{2}, ::Val{2}) + Intrinsics.mma_scaled(b, b_scale, a, a_scale, acc) +end + +# Mat-vec (2D × 1D): the K-vector `b` (and its scale) gain a trailing N=1 dim; +# `acc` becomes (M, 1); then squeeze back to (M,). K — the scaled dimension — +# is preserved, so block scaling is well-defined. +@inline function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile, + ::Val{2}, ::Val{1}) + M, K, Ks = size(a, 1), size(b, 1), size(b_scale, 1) + b2d = reshape(b, (K, 1)) + b_scale2d = reshape(b_scale, (Ks, 1)) + acc2d = reshape(acc, (M, 1)) + result = _muladd_scaled(a, a_scale, b2d, b_scale2d, acc2d, Val(2), Val(2)) + reshape(result, (M,)) +end + +# Vec-mat (1D × 2D): promoting `a` to (M, 1) collapses K to 1, leaving no K +# dimension to block-scale. Unsupported — reshape to 2D and supply a matching +# K_s scale instead. +@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{2}) + return :(throw(ArgumentError("Scaled vec-mat is not supported (the K dimension collapses to 1, which cannot be block-scaled)."))) +end + +# Vec-vec (1D × 1D): not supported. 
+@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1}) + return :(throw(ArgumentError("Scaled vector-vector multiply-accumulate is not supported."))) +end + +# Batched mat-vec / vec-mat (≥3D × 1D or 1D × ≥3D): not supported. +@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB} + return :(throw(ArgumentError("Batched scaled vec-mat is not supported."))) +end +@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA} + return :(throw(ArgumentError("Batched scaled mat-vec is not supported."))) +end + +# Batched (≥3D × ≥3D): trailing batch dims with broadcast, mirroring `_muladd`. +# Scales carry the same batch dims as their operands; a_scale's batch must match +# a's batch (likewise b_scale/b), then both broadcast to the common batch shape. +@generated function _muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile{Tas, SAS}, + b::Tile{Tb, SB}, b_scale::Tile{Tbs, SBS}, + acc::Tile{Tc, SC}, + ::Val{NA}, ::Val{NB}) where {Ta, Tas, Tb, Tbs, Tc, + SA, SAS, SB, SBS, SC, NA, NB} + sa = Tuple(SA.parameters); sas = Tuple(SAS.parameters) + sb = Tuple(SB.parameters); sbs = Tuple(SBS.parameters) + + # Matrix dims are first two; batch dims are trailing. + M = sa[1]; K = sa[2]; N = sb[2] + Ksa = sas[2] # a_scale K_s (a_scale is (M, K_s, batch...)) + Ksb = sbs[1] # b_scale K_s (b_scale is (K_s, N, batch...)) + a_batch = sa[3:end]; b_batch = sb[3:end] + as_batch = sas[3:end]; bs_batch = sbs[3:end] + + # Broadcast batch dims (pad shorter with trailing 1s, then broadcast). + n_batch = max(length(a_batch), length(b_batch)) + a_batch_padded = (a_batch..., ntuple(Returns(1), n_batch - length(a_batch))...) + b_batch_padded = (b_batch..., ntuple(Returns(1), n_batch - length(b_batch))...) + as_batch_padded = (as_batch..., ntuple(Returns(1), n_batch - length(as_batch))...) + bs_batch_padded = (bs_batch..., ntuple(Returns(1), n_batch - length(bs_batch))...) 
+ batch_shape = map(max, a_batch_padded, b_batch_padded) + B_flat = prod(batch_shape) + + quote + # Reshape + broadcast to align batch dims (still trailing). + a_bc = broadcast_to(reshape(a, $((M, K, a_batch_padded...))), $((M, K, batch_shape...))) + b_bc = broadcast_to(reshape(b, $((K, N, b_batch_padded...))), $((K, N, batch_shape...))) + as_bc = broadcast_to(reshape(a_scale, $((M, Ksa, as_batch_padded...))), $((M, Ksa, batch_shape...))) + bs_bc = broadcast_to(reshape(b_scale, $((Ksb, N, bs_batch_padded...))), $((Ksb, N, batch_shape...))) + acc_bc = broadcast_to(acc, $((M, N, batch_shape...))) + # Flatten batch dims to one — no permute needed since row-major Tile IR + # already has batch as the leading (slowest) dimension. + a_3d = reshape(a_bc, $((M, K, B_flat))) + b_3d = reshape(b_bc, $((K, N, B_flat))) + as_3d = reshape(as_bc, $((M, Ksa, B_flat))) + bs_3d = reshape(bs_bc, $((Ksb, N, B_flat))) + acc_3d = reshape(acc_bc, $((M, N, B_flat))) + # mmaf_scaled with swapped operands for row-major convention. + result_3d = Intrinsics.mma_scaled(b_3d, bs_3d, a_3d, as_3d, acc_3d) + # Unflatten batch dims. + reshape(result_3d, $((M, N, batch_shape...))) + end +end + # Matrix multiplication: A * B = muladd(A, B, zeros) # Note: SA, SB type parameters required to avoid ambiguity with scalar*tile methods during codegen # diff --git a/test/device/integration.jl b/test/device/integration.jl index 01a7d85c..fcd50656 100644 --- a/test/device/integration.jl +++ b/test/device/integration.jl @@ -33,3 +33,40 @@ using CUDA @test c_cpu ≈ c_ref end + +@testset "i8/u8 matmul (mmai)" begin + # i8/u8 × i8/u8 → i32 lowers to `cuda_tile.mmai`. The accumulator must be + # Int32 (the `*` operator would pick the input dtype as acc and fail), so we + # use `muladd` with an explicit i32 acc. 
Per-operand signedness is derived + # from the Julia element type, so we feed values that differ under signed vs + # unsigned interpretation (negatives, and magnitudes > 127) to confirm each + # operand is interpreted correctly. K = 16, |product| ≤ 255² · 16 ≈ 1.0e6, + # well within Int32, so the result is exact. + function mmai(a::ct.TileArray{T1,2}, b::ct.TileArray{T2,2}, c::ct.TileArray{Int32,2}) where {T1,T2} + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + tc = muladd(ta, tb, zeros(Int32, (16, 16))) + ct.store(c, (1, 1), tc) + return + end + + M = K = N = 16 + @testset "signed × signed" begin + a = rand(Int8(-128):Int8(127), M, K); b = rand(Int8(-128):Int8(127), K, N) + c = CUDA.zeros(Int32, M, N) + @cuda backend=cuTile blocks=1 mmai(CuArray(a), CuArray(b), c) + @test Array(c) == Int32.(a) * Int32.(b) + end + @testset "unsigned × unsigned" begin + a = rand(UInt8(0):UInt8(255), M, K); b = rand(UInt8(0):UInt8(255), K, N) + c = CUDA.zeros(Int32, M, N) + @cuda backend=cuTile blocks=1 mmai(CuArray(a), CuArray(b), c) + @test Array(c) == Int32.(a) * Int32.(b) + end + @testset "unsigned × signed" begin + a = rand(UInt8(0):UInt8(255), M, K); b = rand(Int8(-128):Int8(127), K, N) + c = CUDA.zeros(Int32, M, N) + @cuda backend=cuTile blocks=1 mmai(CuArray(a), CuArray(b), c) + @test Array(c) == Int32.(a) * Int32.(b) + end +end diff --git a/test/extensions/DLFP8Types.jl b/test/extensions/DLFP8Types.jl index cb72d48f..f18371e4 100644 --- a/test/extensions/DLFP8Types.jl +++ b/test/extensions/DLFP8Types.jl @@ -2,6 +2,7 @@ using CUDA using DLFP8Types: Float8_E4M3FN, Float8_E5M2 spec1d = ct.ArraySpec{1}(16, true) +spec2d = ct.ArraySpec{2}(16, true) @testset "codegen" begin @@ -31,14 +32,40 @@ end end end +# Non-scaled f8 matmul lowers to `mmaf` (f8 operands, f32 accumulator). 
+@test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}) do a, b, c + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + @check "mmaf" + ct.store(c, (1, 1), muladd(ta, tb, zeros(Float32, (16, 16)))) + return + end end -# FP8 types are Blackwell-only -@testset "execution" begin -if capability(device()) >= v"10" +# `fast_acc=true` is an FP8-only hint; it still lowers to `mmaf` (13.3+). +@test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}; bytecode_version=v"13.3") do a, b, c + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + @check "mmaf" + ct.store(c, (1, 1), muladd(ta, tb, zeros(Float32, (16, 16)); fast_acc=true)) + return + end +end + +end + +# Execution kernels are plain top-level functions, each defined next to the +# test that exercises it. Kernels parametric on accumulator dtype must stay at +# top level — defining them inside a testset scope boxes them into closures. -# Round-trip Float32 → FP8 → Float32 on values exactly representable in -# the target FP8 type — result must match input bit-for-bit. +# Round-trip Float32 → FP8 → Float32 on values exactly representable in the +# target FP8 type — result must match input bit-for-bit. 
function rt_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) pid = ct.bid(1) tile = ct.load(a, pid, (16,)) @@ -51,19 +78,8 @@ function rt_e5m2(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float8_E5M2}, tile))) return end - -representable = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 8.0, - 16.0, 32.0, 64.0, 128.0, 256.0, -1.0, -2.0, -0.5] -let a = CuArray(representable), b = CUDA.zeros(Float32, length(representable)) - @cuda backend=cuTile blocks=1 rt_e4m3(a, b) - @test Array(b) == representable - @cuda backend=cuTile blocks=1 rt_e5m2(a, b) - @test Array(b) == representable -end - # FMA in FP8: load Float32, convert to FP8, multiply-add in FP8, convert back. -# Uses inputs whose products and sums also stay representable, so the result -# is exact. +# Inputs whose products and sums also stay representable, so the result is exact. function fma_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, c::ct.TileArray{Float32,1}, d::ct.TileArray{Float32,1}) pid = ct.bid(1) @@ -73,6 +89,33 @@ function fma_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, ct.store(d, pid, convert(ct.Tile{Float32}, muladd.(ta, tb, tc))) return end +# Non-scaled FP8 matmul with both allowed accumulator dtypes (f16 and f32). 
+function mma_dl_fp8(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, + C::ct.TileArray{Tacc,2}, D::ct.TileArray{Float32,2}) where {Tacc<:Union{Float16,Float32}} + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) + return +end +function mma_dl_fast(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, + C::ct.TileArray{Float32,2}, D::ct.TileArray{Float32,2}) + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c; fast_acc=true))) + return +end + +# FP8 (e4m3/e5m2) conversions and matmul need Hopper (sm_90+). +@testset "execution" begin +if capability(device()) >= v"9" + +representable = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 8.0, + 16.0, 32.0, 64.0, 128.0, 256.0, -1.0, -2.0, -0.5] +let a = CuArray(representable), b = CUDA.zeros(Float32, length(representable)) + @cuda backend=cuTile blocks=1 rt_e4m3(a, b) + @test Array(b) == representable + @cuda backend=cuTile blocks=1 rt_e5m2(a, b) + @test Array(b) == representable +end + let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, -2.0, 1.0, 0.5, 4.0], bv = Float32[2.0, 1.0, 4.0, 0.5, 2.0, 3.0, 2.0, 4.0, 1.0, 2.0, 1.0, 0.5, 2.0, 1.0, 2.0, 1.0], cv = Float32[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0] @@ -82,5 +125,30 @@ let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, - @test Array(d) == av .* bv .+ cv end +@testset "mma → $Tacc acc" for Tacc in (Float32, Float16) + M = 16 + ah = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) + bh = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) + ch = Tacc.(Float32.(rand(0:2, M, M))) + ref = Float32.(ah) * Float32.(bh) .+ Float32.(ch) + D = CUDA.zeros(Float32, M, M) + @cuda backend=cuTile blocks=1 mma_dl_fp8(CuArray(ah), CuArray(bh), 
CuArray(ch), D) + @test Array(D) == ref +end + +# fast_acc only has an effect on Hopper (sm_90); ignored elsewhere. So off +# Hopper we assert the exact result (the flag must ride through without +# perturbing the output); on Hopper we make no numeric claim. +@testset "mma fast_acc (exact off Hopper)" begin + M = 16 + ah = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) + bh = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) + ch = Float32.(rand(0:2, M, M)) + ref = Float32.(ah) * Float32.(bh) .+ ch + D = CUDA.zeros(Float32, M, M) + @cuda backend=cuTile blocks=1 mma_dl_fast(CuArray(ah), CuArray(bh), CuArray(ch), D) + @test (Array(D) == ref) || (v"9" <= capability(device()) < v"10") +end + end end diff --git a/test/extensions/Microfloats.jl b/test/extensions/Microfloats.jl deleted file mode 100644 index 13b64c0a..00000000 --- a/test/extensions/Microfloats.jl +++ /dev/null @@ -1,229 +0,0 @@ -using CUDA -using Microfloats: Float8_E4M3FN, Float8_E5M2, Float8_E8M0FNU, Float4_E2M1FN - -spec1d = ct.ArraySpec{1}(16, true) - -@testset "codegen" begin - -# Float32 -> Float8_E4M3FN (always available; 13.1+) -@test @filecheck begin - @check_label "entry" - code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b - pid = ct.bid(1) - tile = ct.load(a, pid, (16,)) - @check "ftof" - converted = convert(ct.Tile{Float8_E4M3FN}, tile) - ct.store(b, pid, convert(ct.Tile{Float32}, converted)) - return - end -end - -# Float32 -> Float8_E5M2 (always available; 13.1+) -@test @filecheck begin - @check_label "entry" - code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b - pid = ct.bid(1) - tile = ct.load(a, pid, (16,)) - @check "ftof" - converted = convert(ct.Tile{Float8_E5M2}, tile) - ct.store(b, pid, convert(ct.Tile{Float32}, converted)) - return - end -end - -# Float32 -> Float8_E8M0FNU works on bytecode 13.2+ -@test @filecheck begin - @check_label "entry" - code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, 
ct.TileArray{Float32,1,spec1d}}; - bytecode_version=v"13.2") do a, b - pid = ct.bid(1) - tile = ct.load(a, pid, (16,)) - @check "ftof" - converted = convert(ct.Tile{Float8_E8M0FNU}, tile) - ct.store(b, pid, convert(ct.Tile{Float32}, converted)) - return - end -end - -# Float8_E8M0FNU rejected on bytecode 13.1 with a clear version error -let kernel = (a, b) -> begin - pid = ct.bid(1) - tile = ct.load(a, pid, (16,)) - converted = convert(ct.Tile{Float8_E8M0FNU}, tile) - ct.store(b, pid, convert(ct.Tile{Float32}, converted)) - return - end - @test_throws "v13.2+" code_tiled(devnull, kernel, - Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}; - bytecode_version=v"13.1") -end - -# Float4_E2M1FN requires bytecode 13.3 — rejected at 13.2 with a clear error -let kernel = (a, b) -> begin - pid = ct.bid(1) - tile = ct.load(a, pid, (16,)) - converted = convert(ct.Tile{Float4_E2M1FN}, tile) - ct.store(b, pid, convert(ct.Tile{Float32}, converted)) - return - end - @test_throws "v13.3+" code_tiled(devnull, kernel, - Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}; - bytecode_version=v"13.2") -end - -# Whole-tile `reinterpret` between UInt8 and Float4_E2M1FN packs/unpacks two FP4 -# per byte: a `Tile{UInt8,Tuple{8}}` unpacks to a `Tile{Float4_E2M1FN,Tuple{16}}`, -# lowering to `cuda_tile.unpack` (13.3+). -@test @filecheck begin - @check_label "entry" - code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{Float32,1,spec1d}}; - bytecode_version=v"13.3") do a, b - pid = ct.bid(1) - bytes = ct.load(a, pid, (8,)) # Tile{UInt8,Tuple{8}} - @check "unpack" - fp4 = reinterpret(Float4_E2M1FN, bytes) # Tile{Float4_E2M1FN,Tuple{16}} - ct.store(b, pid, convert(ct.Tile{Float32}, fp4)) - return - end -end - -# And the reverse packs FP4 back into bytes via `cuda_tile.pack` (13.3+). 
-@test @filecheck begin - @check_label "entry" - code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{UInt8,1,spec1d}}; - bytecode_version=v"13.3") do a, b - pid = ct.bid(1) - vals = ct.load(a, pid, (16,)) - fp4 = convert(ct.Tile{Float4_E2M1FN}, vals) # Tile{Float4_E2M1FN,Tuple{16}} - @check "pack" - ct.store(b, pid, reinterpret(UInt8, fp4)) # Tile{UInt8,Tuple{8}} - return - end -end - -end - -# FP8 types are Blackwell-only -@testset "execution" begin -if capability(device()) >= v"10" - -# Round-trip Float32 → microfloat → Float32 on values exactly representable -# in the target type — result must match input bit-for-bit. -function rt_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) - pid = ct.bid(1) - tile = ct.load(a, pid, (16,)) - ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float8_E4M3FN}, tile))) - return -end -function rt_e5m2(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) - pid = ct.bid(1) - tile = ct.load(a, pid, (16,)) - ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float8_E5M2}, tile))) - return -end -# Float8_E4M3FN / Float8_E5M2: 13.1+, always available -representable8 = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 8.0, - 16.0, 32.0, 64.0, 128.0, 256.0, -1.0, -2.0, -0.5] -let a = CuArray(representable8), b = CUDA.zeros(Float32, length(representable8)) - @cuda backend=cuTile blocks=1 rt_e4m3(a, b) - @test Array(b) == representable8 - @cuda backend=cuTile blocks=1 rt_e5m2(a, b) - @test Array(b) == representable8 -end - -# FMA in FP8: load Float32, convert to FP8, multiply-add in FP8, convert back. -# Uses inputs whose products and sums also stay representable, so the result -# is exact. 
-function fma_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, - c::ct.TileArray{Float32,1}, d::ct.TileArray{Float32,1}) - pid = ct.bid(1) - ta = convert(ct.Tile{Float8_E4M3FN}, ct.load(a, pid, (16,))) - tb = convert(ct.Tile{Float8_E4M3FN}, ct.load(b, pid, (16,))) - tc = convert(ct.Tile{Float8_E4M3FN}, ct.load(c, pid, (16,))) - ct.store(d, pid, convert(ct.Tile{Float32}, muladd.(ta, tb, tc))) - return -end -let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, -2.0, 1.0, 0.5, 4.0], - bv = Float32[2.0, 1.0, 4.0, 0.5, 2.0, 3.0, 2.0, 4.0, 1.0, 2.0, 1.0, 0.5, 2.0, 1.0, 2.0, 1.0], - cv = Float32[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0] - a, b, c = CuArray(av), CuArray(bv), CuArray(cv) - d = CUDA.zeros(Float32, length(av)) - @cuda backend=cuTile blocks=1 fma_e4m3(a, b, c, d) - @test Array(d) == av .* bv .+ cv -end - -# Float8_E8M0FNU and Float4_E2M1FN: codegen for `ftof` lowers fine (see the -# codegen tests above) but tileiras refuses to lower a *standalone* f32 ↔ -# microfloat conversion on Blackwell — these formats only have meaningful -# hardware paths as the scale/operand dtypes of block-scaled mma, or packed -# into a byte-wide tile via `reinterpret` (below). - -# Float4_E2M1FN moves through global memory packed two-per-byte: `reinterpret` -# unpacks a `UInt8` tile into FP4 (doubling the leading dim) and packs it back. - -# Pure pack/unpack round-trip: reinterpreting UInt8 → FP4 → UInt8 only -# reinterprets bits, so the bytes must come back unchanged for any input. 
-function rt_pack(a::ct.TileArray{UInt8,1}, b::ct.TileArray{UInt8,1}) - pid = ct.bid(1) - bytes = ct.load(a, pid, (8,)) - fp4 = reinterpret(Float4_E2M1FN, bytes) # unpack: (8,) UInt8 → (16,) FP4 - ct.store(b, pid, reinterpret(UInt8, fp4)) # pack: (16,) FP4 → (8,) UInt8 - return -end -let av = UInt8[0x00, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde], - b = CUDA.zeros(UInt8, 8) - a = CuArray(av) - @cuda backend=cuTile blocks=1 rt_pack(a, b) - @test Array(b) == av -end - -# Value round-trip through FP4 stored as UInt8: convert Float32 → FP4, pack to -# UInt8 and store; then load UInt8, unpack to FP4 and convert back to Float32. -# All inputs are exactly representable in E2M1 (magnitudes 0,0.5,1,1.5,2,3,4,6), -# so the result must match bit-for-bit. -function pack_fp4(src::ct.TileArray{Float32,1}, dst::ct.TileArray{UInt8,1}) - pid = ct.bid(1) - vals = ct.load(src, pid, (16,)) - fp4 = convert(ct.Tile{Float4_E2M1FN}, vals) - ct.store(dst, pid, reinterpret(UInt8, fp4)) # pack: (16,) FP4 → (8,) UInt8 - return -end -function unpack_fp4(src::ct.TileArray{UInt8,1}, dst::ct.TileArray{Float32,1}) - pid = ct.bid(1) - bytes = ct.load(src, pid, (8,)) - fp4 = reinterpret(Float4_E2M1FN, bytes) # unpack: (8,) UInt8 → (16,) FP4 - ct.store(dst, pid, convert(ct.Tile{Float32}, fp4)) - return -end -let representable4 = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, 0.5] - src = CuArray(representable4) - packed = CUDA.zeros(UInt8, 8) - out = CUDA.zeros(Float32, 16) - @cuda backend=cuTile blocks=1 pack_fp4(src, packed) - @cuda backend=cuTile blocks=1 unpack_fp4(packed, out) - @test Array(out) == representable4 -end - -# N-D reinterpret: pack/unpack are rank-1, but whole-tile `reinterpret` flattens -# (via reshape) so it works on any rank. Round-trip a 2-D Float32 tile through -# 2-D FP4 and 2-D packed UInt8: the leading (column-major) dim absorbs the 2× / -# ½ scaling — (8,2) FP4 ↔ (4,2) UInt8. 
-function rt_fp4_2d(src::ct.TileArray{Float32,2}, dst::ct.TileArray{Float32,2}) - pid = ct.bid(1) - fp4 = convert(ct.Tile{Float4_E2M1FN}, ct.load(src, pid, (8, 2))) # (8,2) FP4 - bytes = reinterpret(UInt8, fp4) # (4,2) UInt8 - fp4b = reinterpret(Float4_E2M1FN, bytes) # (8,2) FP4 - ct.store(dst, pid, convert(ct.Tile{Float32}, fp4b)) - return -end -let m = reshape(Float32[0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -1.0, - -2.0, -3.0, -4.0, -6.0, 0.5, 1.0, 2.0, 3.0], 8, 2) - src = CuArray(m) - out = CUDA.zeros(Float32, 8, 2) - @cuda backend=cuTile blocks=1 rt_fp4_2d(src, out) - @test Array(out) == m -end - -end -end diff --git a/test/extensions/Microfloats/codegen.jl b/test/extensions/Microfloats/codegen.jl new file mode 100644 index 00000000..efbdf698 --- /dev/null +++ b/test/extensions/Microfloats/codegen.jl @@ -0,0 +1,242 @@ +using Microfloats: Float8_E4M3FN, Float8_E5M2, Float8_E8M0FNU, Float4_E2M1FN + +spec1d = ct.ArraySpec{1}(16, true) +spec2d = ct.ArraySpec{2}(16, true) + +@testset "Microfloats codegen" begin + +@testset "ftof" begin + # Float32 -> Float8_E4M3FN (always available; 13.1+) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + @check "ftof" + converted = convert(ct.Tile{Float8_E4M3FN}, tile) + ct.store(b, pid, convert(ct.Tile{Float32}, converted)) + return + end + end + + # Float32 -> Float8_E5M2 (always available; 13.1+) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + @check "ftof" + converted = convert(ct.Tile{Float8_E5M2}, tile) + ct.store(b, pid, convert(ct.Tile{Float32}, converted)) + return + end + end + + # Float32 -> Float8_E8M0FNU works on bytecode 13.2+ + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, 
ct.TileArray{Float32,1,spec1d}}; + bytecode_version=v"13.2") do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + @check "ftof" + converted = convert(ct.Tile{Float8_E8M0FNU}, tile) + ct.store(b, pid, convert(ct.Tile{Float32}, converted)) + return + end + end + + # Float8_E8M0FNU rejected on bytecode 13.1 with a clear version error + let kernel = (a, b) -> begin + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + converted = convert(ct.Tile{Float8_E8M0FNU}, tile) + ct.store(b, pid, convert(ct.Tile{Float32}, converted)) + return + end + @test_throws "v13.2+" code_tiled(devnull, kernel, + Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}; + bytecode_version=v"13.1") + end + + # Float4_E2M1FN requires bytecode 13.3 — rejected at 13.2 with a clear error + let kernel = (a, b) -> begin + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + converted = convert(ct.Tile{Float4_E2M1FN}, tile) + ct.store(b, pid, convert(ct.Tile{Float32}, converted)) + return + end + @test_throws "v13.3+" code_tiled(devnull, kernel, + Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}; + bytecode_version=v"13.2") + end +end + +@testset "reinterpret" begin + # Whole-tile `reinterpret` between UInt8 and Float4_E2M1FN packs/unpacks two + # FP4 per byte: a `Tile{UInt8,(8,)}` unpacks to a `Tile{Float4_E2M1FN,(16,)}`, + # lowering to `cuda_tile.unpack` (13.3+). + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{Float32,1,spec1d}}; + bytecode_version=v"13.3") do a, b + pid = ct.bid(1) + bytes = ct.load(a, pid, (8,)) # Tile{UInt8,(8,)} + @check "unpack" + fp4 = reinterpret(Float4_E2M1FN, bytes) # Tile{Float4_E2M1FN,(16,)} + ct.store(b, pid, convert(ct.Tile{Float32}, fp4)) + return + end + end + + # And the reverse packs FP4 back into bytes via `cuda_tile.pack` (13.3+). 
+ @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{UInt8,1,spec1d}}; + bytecode_version=v"13.3") do a, b + pid = ct.bid(1) + vals = ct.load(a, pid, (16,)) + fp4 = convert(ct.Tile{Float4_E2M1FN}, vals) # Tile{Float4_E2M1FN,(16,)} + @check "pack" + ct.store(b, pid, reinterpret(UInt8, fp4)) # Tile{UInt8,(8,)} + return + end + end +end + +@testset "mma" begin + # f8e4m3fn operands with an f32 accumulator lower to `mmaf` (13.1+). + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}) do a, b, c + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + @check "mmaf" + ct.store(c, (1, 1), muladd(ta, tb, zeros(Float32, (16, 16)))) + return + end + end + + # f8e5m2 operands with an f16 accumulator also lower to `mmaf`. + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float8_E5M2,2,spec2d}, ct.TileArray{Float8_E5M2,2,spec2d}, + ct.TileArray{Float16,2,spec2d}}) do a, b, c + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + @check "mmaf" + ct.store(c, (1, 1), muladd(ta, tb, zeros(Float16, (16, 16)))) + return + end + end + + # A disallowed accumulator dtype for f8 (only f16/f32 are valid) is rejected + # with a clear error rather than producing an mmaf op tileiras would reject. + @test_throws "tileiras requires acc" code_tiled( + Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float64,2,spec2d}}) do a, b, c + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + acc = zeros(Float64, (16, 16)) + Base.donotdelete(ct.Intrinsics.mma(ta, tb, acc)) + return + end +end + +@testset "fast_acc" begin + # `fast_acc=true` on f8 operands still lowers to `mmaf` (the hint rides on + # the op as a flag); requires bytecode 13.3. 
+ @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}; + bytecode_version=v"13.3") do a, b, c + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + @check "mmaf" + ct.store(c, (1, 1), muladd(ta, tb, zeros(Float32, (16, 16)); fast_acc=true)) + return + end + end + + # `fast_acc` is an FP8-only hint: requesting it for f16 inputs is rejected. + @test_throws "only supported for fp8" code_tiled( + Tuple{ct.TileArray{Float16,2,spec2d}, ct.TileArray{Float16,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}; bytecode_version=v"13.3") do a, b, c + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + ct.store(c, (1, 1), muladd(ta, tb, zeros(Float32, (16, 16)); fast_acc=true)) + return + end + + # `fast_acc` requires bytecode 13.3 — rejected at 13.2 with a clear error. + let kernel = (a, b, c) -> begin + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + ct.store(c, (1, 1), muladd(ta, tb, zeros(Float32, (16, 16)); fast_acc=true)) + return + end + @test_throws "13.3" code_tiled(devnull, kernel, + Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}; bytecode_version=v"13.2") + end +end + +@testset "mma_scaled" begin + # MXFP8: f8 operands (M,K)/(K,N) with f8e8m0fnu block scales (M,K_s)/(K_s,N) + # accumulate into f32. Block size B = K ÷ K_s = 64 ÷ 2 = 32. 
+ @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E8M0FNU,2,spec2d}, + ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E8M0FNU,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}; + bytecode_version=v"13.3") do x, x_scale, y, y_scale, z + xt = ct.load(x, (1, 1), (16, 64)) + xst = ct.load(x_scale, (1, 1), (16, 2)) + yt = ct.load(y, (1, 1), (64, 16)) + yst = ct.load(y_scale, (1, 1), (2, 16)) + @check "mmaf_scaled" + result = ct.muladd_scaled(xt, xst, yt, yst, zeros(Float32, (16, 16))) + ct.store(z, (1, 1), result) + return + end + end + + # NVFP4: f4e2m1fn operands with f8e4m3fn scales, B = 16. Operands enter as + # unpacked FP4 tiles; `mmaf_scaled` accepts them too. + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}; + bytecode_version=v"13.3") do x, x_scale, y, y_scale, z + xt = convert(ct.Tile{Float4_E2M1FN}, ct.load(x, (1, 1), (16, 64))) + xst = ct.load(x_scale, (1, 1), (16, 4)) + yt = convert(ct.Tile{Float4_E2M1FN}, ct.load(y, (1, 1), (64, 16))) + yst = ct.load(y_scale, (1, 1), (4, 16)) + @check "mmaf_scaled" + result = ct.muladd_scaled(xt, xst, yt, yst, zeros(Float32, (16, 16))) + ct.store(z, (1, 1), result) + return + end + end + + # mma_scaled requires bytecode 13.3 — rejected at 13.2 with a clear error. 
+ let kernel = (x, x_scale, y, y_scale, z) -> begin + xt = ct.load(x, (1, 1), (16, 64)) + xst = ct.load(x_scale, (1, 1), (16, 2)) + yt = ct.load(y, (1, 1), (64, 16)) + yst = ct.load(y_scale, (1, 1), (2, 16)) + ct.store(z, (1, 1), ct.muladd_scaled(xt, xst, yt, yst, zeros(Float32, (16, 16)))) + return + end + @test_throws "13.3" code_tiled(devnull, kernel, + Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E8M0FNU,2,spec2d}, + ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E8M0FNU,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}; bytecode_version=v"13.2") + end +end + +end diff --git a/test/extensions/Microfloats/device.jl b/test/extensions/Microfloats/device.jl new file mode 100644 index 00000000..97ede9a5 --- /dev/null +++ b/test/extensions/Microfloats/device.jl @@ -0,0 +1,370 @@ +using CUDA +using Microfloats: Float8_E4M3FN, Float8_E5M2, Float8_E8M0FNU, Float4_E2M1FN + +# Kernels are plain top-level functions (not closures), each defined next to the +# testset that exercises it. Kernels parametric on operand dtype must stay at +# top level — defining them inside a testset scope boxes them into closures. 
+ +# ftof round-trips +function rt_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float8_E4M3FN}, tile))) + return +end +function rt_e5m2(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float8_E5M2}, tile))) + return +end +function rt_e8m0(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float8_E8M0FNU}, tile))) + return +end +function rt_f4(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float4_E2M1FN}, tile))) + return +end +function fma_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, + c::ct.TileArray{Float32,1}, d::ct.TileArray{Float32,1}) + pid = ct.bid(1) + ta = convert(ct.Tile{Float8_E4M3FN}, ct.load(a, pid, (16,))) + tb = convert(ct.Tile{Float8_E4M3FN}, ct.load(b, pid, (16,))) + tc = convert(ct.Tile{Float8_E4M3FN}, ct.load(c, pid, (16,))) + ct.store(d, pid, convert(ct.Tile{Float32}, muladd.(ta, tb, tc))) + return +end + +# Standalone f32 → microfloat → f32 conversion round-trips exactly for every +# microfloat type on representable inputs. FP8 (e4m3/e5m2) needs Hopper (sm_90+); +# E8M0FNU and Float4_E2M1FN need Blackwell (sm_100+). E8M0 is exponent-only, so +# its representable values are exact powers of two. +@testset "ftof" begin +if capability(device()) >= v"9" + # Round-trip Float32 → microfloat → Float32 on exactly-representable values. 
+ representable8 = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 8.0, + 16.0, 32.0, 64.0, 128.0, 256.0, -1.0, -2.0, -0.5] + let a = CuArray(representable8), b = CUDA.zeros(Float32, length(representable8)) + @cuda backend=cuTile blocks=1 rt_e4m3(a, b) + @test Array(b) == representable8 + @cuda backend=cuTile blocks=1 rt_e5m2(a, b) + @test Array(b) == representable8 + end + + # FMA in FP8: load f32, convert to FP8, multiply-add in FP8, convert back. + # Inputs whose products and sums stay representable, so the result is exact. + let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, -2.0, 1.0, 0.5, 4.0], + bv = Float32[2.0, 1.0, 4.0, 0.5, 2.0, 3.0, 2.0, 4.0, 1.0, 2.0, 1.0, 0.5, 2.0, 1.0, 2.0, 1.0], + cv = Float32[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0] + a, b, c = CuArray(av), CuArray(bv), CuArray(cv) + d = CUDA.zeros(Float32, length(av)) + @cuda backend=cuTile blocks=1 fma_e4m3(a, b, c, d) + @test Array(d) == av .* bv .+ cv + end +end +if capability(device()) >= v"10" + # E8M0 round-trip: exponent-only, so representable values are powers of two. + representable_e8m0 = Float32(2) .^ Float32[-8, -6, -4, -2, -1, 0, 1, 2, + 3, 4, 5, 6, 7, 8, 10, 16] + let a = CuArray(representable_e8m0), b = CUDA.zeros(Float32, length(representable_e8m0)) + @cuda backend=cuTile blocks=1 rt_e8m0(a, b) + @test Array(b) == representable_e8m0 + end + + # F4 round-trip via the standalone conversion path (no pack/reinterpret). 
+ representable4 = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, 0.5] + let a = CuArray(representable4), b = CUDA.zeros(Float32, length(representable4)) + @cuda backend=cuTile blocks=1 rt_f4(a, b) + @test Array(b) == representable4 + end +end +end + +# reinterpret (pack / unpack) +function rt_pack(a::ct.TileArray{UInt8,1}, b::ct.TileArray{UInt8,1}) + pid = ct.bid(1) + bytes = ct.load(a, pid, (8,)) + fp4 = reinterpret(Float4_E2M1FN, bytes) # unpack: (8,) UInt8 → (16,) FP4 + ct.store(b, pid, reinterpret(UInt8, fp4)) # pack: (16,) FP4 → (8,) UInt8 + return +end +function pack_fp4(src::ct.TileArray{Float32,1}, dst::ct.TileArray{UInt8,1}) + pid = ct.bid(1) + vals = ct.load(src, pid, (16,)) + fp4 = convert(ct.Tile{Float4_E2M1FN}, vals) + ct.store(dst, pid, reinterpret(UInt8, fp4)) # pack: (16,) FP4 → (8,) UInt8 + return +end +function unpack_fp4(src::ct.TileArray{UInt8,1}, dst::ct.TileArray{Float32,1}) + pid = ct.bid(1) + bytes = ct.load(src, pid, (8,)) + fp4 = reinterpret(Float4_E2M1FN, bytes) # unpack: (8,) UInt8 → (16,) FP4 + ct.store(dst, pid, convert(ct.Tile{Float32}, fp4)) + return +end +function rt_fp4_2d(src::ct.TileArray{Float32,2}, dst::ct.TileArray{Float32,2}) + pid = ct.bid(1) + fp4 = convert(ct.Tile{Float4_E2M1FN}, ct.load(src, pid, (8, 2))) # (8,2) FP4 + bytes = reinterpret(UInt8, fp4) # (4,2) UInt8 + fp4b = reinterpret(Float4_E2M1FN, bytes) # (8,2) FP4 + ct.store(dst, pid, convert(ct.Tile{Float32}, fp4b)) + return +end + +# Float4_E2M1FN requires Blackwell (sm_100+). +@testset "reinterpret" begin +if capability(device()) >= v"10" + # Pure pack/unpack round-trip: UInt8 → FP4 → UInt8 must be a no-op. + let av = UInt8[0x00, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde], + b = CUDA.zeros(UInt8, 8) + a = CuArray(av) + @cuda backend=cuTile blocks=1 rt_pack(a, b) + @test Array(b) == av + end + + # Value round-trip through FP4 stored as UInt8 (all inputs representable). 
+ let representable4 = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, 0.5] + src = CuArray(representable4) + packed = CUDA.zeros(UInt8, 8) + out = CUDA.zeros(Float32, 16) + @cuda backend=cuTile blocks=1 pack_fp4(src, packed) + @cuda backend=cuTile blocks=1 unpack_fp4(packed, out) + @test Array(out) == representable4 + end + + # N-D reinterpret: whole-tile `reinterpret` flattens, so it works on any + # rank. (8,2) FP4 ↔ (4,2) UInt8 — the leading dim absorbs the 2× / ½. + let m = reshape(Float32[0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -1.0, + -2.0, -3.0, -4.0, -6.0, 0.5, 1.0, 2.0, 3.0], 8, 2) + src = CuArray(m) + out = CUDA.zeros(Float32, 8, 2) + @cuda backend=cuTile blocks=1 rt_fp4_2d(src, out) + @test Array(out) == m + end +end +end + +# non-scaled FP8 matmul (f8 × f8 with f16 or f32 accumulator) +function mma_fp8(A::ct.TileArray{T,2}, B::ct.TileArray{T,2}, + C::ct.TileArray{Tacc,2}, D::ct.TileArray{Float32,2} + ) where {T<:Union{Float8_E4M3FN,Float8_E5M2},Tacc<:Union{Float16,Float32}} + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) + return +end +function mma_e4m3_star(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, + D::ct.TileArray{Float32,2}) + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, a * b)) + return +end +function mma_e4m3_fast(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, + C::ct.TileArray{Float32,2}, D::ct.TileArray{Float32,2}) + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c; fast_acc=true))) + return +end + +# Non-scaled FP8 matmul needs only Hopper (sm_90+); block-scaled mma (below) is +# the Blackwell-only variant. 
+@testset "mma" begin +if capability(device()) >= v"9" + M, K, N = 16, 16, 16 + # f8 operands with both allowed accumulator dtypes (f16 and f32). Inputs are + # in {0, 0.5, 1} and K = 16, so every partial sum is exactly representable + # in f16 too — the f16-acc result matches the f32 reference bit-for-bit. + @testset "$T × $T → $Tacc acc" for (T, Tacc) in ( + (Float8_E4M3FN, Float32), (Float8_E4M3FN, Float16), + (Float8_E5M2, Float32), (Float8_E5M2, Float16)) + ah = T.(Float32.(rand(0:2, M, K)) ./ 2) + bh = T.(Float32.(rand(0:2, K, N)) ./ 2) + ch = Tacc.(Float32.(rand(0:2, M, N))) + ref = Float32.(ah) * Float32.(bh) .+ Float32.(ch) + D = CUDA.zeros(Float32, M, N) + @cuda backend=cuTile blocks=1 mma_fp8(CuArray(ah), CuArray(bh), CuArray(ch), D) + @test Array(D) == ref + end + + @testset "* operator (auto acc → f8)" begin + ah = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) + bh = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) + # f8 × f8 accumulates in f16, then downcasts back to f8. + ref = Float32.(Float8_E4M3FN.(Float16.(Float32.(ah) * Float32.(bh)))) + D = CUDA.zeros(Float32, M, M) + @cuda backend=cuTile blocks=1 mma_e4m3_star(CuArray(ah), CuArray(bh), D) + @test Array(D) == ref + end + + # fast_acc trades accumulator precision for throughput, but only has an + # effect on Hopper (sm_90); it is silently ignored on every other arch. So + # off Hopper we assert the exact result — the flag must ride through tileiras + # and ptxas without perturbing the output. On Hopper we make no numeric + # claim: fast accumulation may legitimately diverge there, and we have no + # fast-accum reference to compare against. 
+ @testset "fast_acc (exact off Hopper)" begin + ah = Float8_E4M3FN.(Float32.(rand(0:2, M, K)) ./ 2) + bh = Float8_E4M3FN.(Float32.(rand(0:2, K, N)) ./ 2) + ch = Float32.(rand(0:2, M, N)) + ref = Float32.(ah) * Float32.(bh) .+ ch + D = CUDA.zeros(Float32, M, N) + @cuda backend=cuTile blocks=1 mma_e4m3_fast(CuArray(ah), CuArray(bh), CuArray(ch), D) + @test (Array(D) == ref) || (v"9" <= capability(device()) < v"10") + end +end +end + +# block-scaled mma (B = K ÷ K_s; here K = 64, K_s = 2 → B = 32) +function mma_scaled_2d(X::ct.TileArray{T,2}, XS::ct.TileArray{Float8_E8M0FNU,2}, + Y::ct.TileArray{T,2}, YS::ct.TileArray{Float8_E8M0FNU,2}, + Z::ct.TileArray{Float32,2}) where {T<:Union{Float8_E4M3FN,Float8_E5M2}} + x = ct.load(X, (1, 1), (16, 64)); xs = ct.load(XS, (1, 1), (16, 2)) + y = ct.load(Y, (1, 1), (64, 16)); ys = ct.load(YS, (1, 1), (2, 16)) + ct.store(Z, (1, 1), ct.muladd_scaled(x, xs, y, ys, zeros(Float32, (16, 16)))) + return +end +function mma_scaled_matvec(X::ct.TileArray{Float8_E4M3FN,2}, XS::ct.TileArray{Float8_E8M0FNU,2}, + Y::ct.TileArray{Float8_E4M3FN,1}, YS::ct.TileArray{Float8_E8M0FNU,1}, + Z::ct.TileArray{Float32,1}) + x = ct.load(X, (1, 1), (16, 64)); xs = ct.load(XS, (1, 1), (16, 2)) + y = ct.load(Y, 1, (64,)); ys = ct.load(YS, 1, (2,)) + ct.store(Z, 1, ct.muladd_scaled(x, xs, y, ys, zeros(Float32, (16,)))) + return +end +function mma_scaled_batched(X::ct.TileArray{Float8_E4M3FN,2}, XS::ct.TileArray{Float8_E8M0FNU,2}, + Y::ct.TileArray{Float8_E4M3FN,3}, YS::ct.TileArray{Float8_E8M0FNU,3}, + Z::ct.TileArray{Float32,3}) + x = ct.load(X, (1, 1), (16, 64)); xs = ct.load(XS, (1, 1), (16, 2)) + y = ct.load(Y, (1, 1, 1), (64, 16, 2)); ys = ct.load(YS, (1, 1, 1), (2, 16, 2)) + ct.store(Z, (1, 1, 1), ct.muladd_scaled(x, xs, y, ys, zeros(Float32, (16, 16, 2)))) + return +end +# FP4 operands have no direct sub-byte load: they arrive packed two-per-byte +# along the matmul-contiguous axis (K for X, N for Y, matching the row-major +# reference). 
cuTile's `reinterpret` doubles the *leading* (column-major) dim, +# so we load each operand with its packed axis leading, unpack, then transpose +# into the (M,K) / (K,N) orientation `muladd_scaled` expects: +# X: bytes (K/2, M) → reinterpret (K, M) → transpose (M, K) +# Y: bytes (N/2, K) → reinterpret (N, K) → transpose (K, N) +# MXFP4 takes an f8e8m0fnu scale (B = 32 → K_s = 2); NVFP4 an f8e4m3fn scale +# (B = 16 → K_s = 4). Operand unpacking is identical — only the scale differs. +function mma_scaled_mxfp4(X::ct.TileArray{UInt8,2}, XS::ct.TileArray{Float8_E8M0FNU,2}, + Y::ct.TileArray{UInt8,2}, YS::ct.TileArray{Float8_E8M0FNU,2}, + Z::ct.TileArray{Float32,2}) + x = permutedims(reinterpret(Float4_E2M1FN, ct.load(X, (1, 1), (32, 16))), (2, 1)) # (M,K) + y = permutedims(reinterpret(Float4_E2M1FN, ct.load(Y, (1, 1), (8, 64))), (2, 1)) # (K,N) + xs = ct.load(XS, (1, 1), (16, 2)) + ys = ct.load(YS, (1, 1), (2, 16)) + ct.store(Z, (1, 1), ct.muladd_scaled(x, xs, y, ys, zeros(Float32, (16, 16)))) + return +end +function mma_scaled_nvfp4(X::ct.TileArray{UInt8,2}, XS::ct.TileArray{Float8_E4M3FN,2}, + Y::ct.TileArray{UInt8,2}, YS::ct.TileArray{Float8_E4M3FN,2}, + Z::ct.TileArray{Float32,2}) + x = permutedims(reinterpret(Float4_E2M1FN, ct.load(X, (1, 1), (32, 16))), (2, 1)) # (M,K) + y = permutedims(reinterpret(Float4_E2M1FN, ct.load(Y, (1, 1), (8, 64))), (2, 1)) # (K,N) + xs = ct.load(XS, (1, 1), (16, 4)) + ys = ct.load(YS, (1, 1), (4, 16)) + ct.store(Z, (1, 1), ct.muladd_scaled(x, xs, y, ys, zeros(Float32, (16, 16)))) + return +end + +# Expand a block-scale tile to per-element scales along the K dimension: each +# scale entry covers B contiguous K positions (matches Tile IR's broadcast). +expand_k(scale, B, ::Val{:cols}) = repeat(Float32.(scale), inner=(1, B)) # (M, K_s) → (M, K) +expand_k(scale, B, ::Val{:rows}) = repeat(Float32.(scale), inner=(B, 1)) # (K_s, N) → (K, N) + +# Reference: (x ⊙ x_scale) * (y ⊙ y_scale) + acc, with acc = 0. 
+function mma_scaled_ref(xh, xs, yh, ys, B) + (Float32.(xh) .* expand_k(xs, B, Val(:cols))) * + (Float32.(yh) .* expand_k(ys, B, Val(:rows))) +end + +# Low 4 bits of an FP4 value's byte storage (its E2M1 nibble). +f4_nibble(v) = reinterpret(UInt8, Float4_E2M1FN(v)) & 0x0f + +# Pack an (R, C) FP4 matrix two-per-byte along the contiguous (second) axis → +# (C/2, R) bytes: byte [c, r] carries vals[r, 2c-1] in its low nibble and +# vals[r, 2c] in its high nibble (1-based). Matches X (M,K)→(K/2,M) and Y (K,N)→(N/2,K) above. +pack_f4_bytes(vals) = UInt8[f4_nibble(vals[r, 2c-1]) | (f4_nibble(vals[r, 2c]) << 4) + for c in 1:(size(vals, 2) ÷ 2), r in 1:size(vals, 1)] + +# Signed FP4 (E2M1) values (zero plus each nonzero magnitude in both signs), for sampling exact inputs. +const F4_VALUES = Float32[0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.5, -1, -1.5, -2, -3, -4, -6] + +# Block-scaled mma is Blackwell-only (sm_100+), for both FP8 and FP4 operands. +@testset "mma_scaled" begin +if capability(device()) >= v"10" + m, n, k, ks, B = 16, 16, 64, 2, 32 + + # MXFP8: f8 operands with an f8e8m0fnu scale, B = 32. 2D × 2D, each f8 type. 
+ @testset "MXFP8 2D ($T)" for T in (Float8_E4M3FN, Float8_E5M2) + xh = T.(Float32.(rand(0:2, m, k)) ./ 2) # values in {0, 0.5, 1} + yh = T.(Float32.(rand(0:2, k, n)) ./ 2) + xs = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, m, ks)) # scales in {1, 2} + ys = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, ks, n)) + ref = mma_scaled_ref(xh, xs, yh, ys, B) + Z = CUDA.zeros(Float32, m, n) + @cuda backend=cuTile blocks=1 mma_scaled_2d(CuArray(xh), CuArray(xs), CuArray(yh), CuArray(ys), Z) + @test Array(Z) == ref + end + + @testset "MXFP8 mat-vec" begin + xh = Float8_E4M3FN.(Float32.(rand(0:2, m, k)) ./ 2) + yh = Float8_E4M3FN.(Float32.(rand(0:2, k)) ./ 2) + xs = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, m, ks)) + ys = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, ks)) + ref = (Float32.(xh) .* expand_k(xs, B, Val(:cols))) * + (Float32.(yh) .* repeat(Float32.(ys), inner=(B,))) + Z = CUDA.zeros(Float32, m) + @cuda backend=cuTile blocks=1 mma_scaled_matvec(CuArray(xh), CuArray(xs), CuArray(yh), CuArray(ys), Z) + @test Array(Z) == ref + end + + @testset "MXFP8 batched (trailing batch broadcast)" begin + bt = 2 + xh = Float8_E4M3FN.(Float32.(rand(0:2, m, k)) ./ 2) + xs = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, m, ks)) + yh = Float8_E4M3FN.(Float32.(rand(0:2, k, n, bt)) ./ 2) + ys = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, ks, n, bt)) + ref = stack(mma_scaled_ref(xh, xs, yh[:, :, bi], ys[:, :, bi], B) for bi in 1:bt) + Z = CUDA.zeros(Float32, m, n, bt) + @cuda backend=cuTile blocks=1 mma_scaled_batched(CuArray(xh), CuArray(xs), CuArray(yh), CuArray(ys), Z) + @test Array(Z) == ref + end + + # MXFP4: FP4 operands packed two-per-byte (unpacked + transposed in the + # kernel) with an f8e8m0fnu scale, B = 32 → K_s = 2. 
+ @testset "MXFP4" begin + xh = [rand(F4_VALUES) for _ in 1:m, _ in 1:k] # (M, K) + yh = [rand(F4_VALUES) for _ in 1:k, _ in 1:n] # (K, N) + xs = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, m, ks)) + ys = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, ks, n)) + ref = mma_scaled_ref(xh, xs, yh, ys, B) + Z = CUDA.zeros(Float32, m, n) + @cuda backend=cuTile blocks=1 mma_scaled_mxfp4(CuArray(pack_f4_bytes(xh)), CuArray(xs), + CuArray(pack_f4_bytes(yh)), CuArray(ys), Z) + @test Array(Z) == ref + end + + # NVFP4: same FP4 operands but an f8e4m3fn scale at B = 16 → K_s = 4. + @testset "NVFP4" begin + Bn, ksn = 16, k ÷ 16 # B = 16 → K_s = 4 + xh = [rand(F4_VALUES) for _ in 1:m, _ in 1:k] # (M, K) + yh = [rand(F4_VALUES) for _ in 1:k, _ in 1:n] # (K, N) + xs = Float8_E4M3FN.(Float32(2) .^ rand(0:1, m, ksn)) # scales in {1, 2}, exact in e4m3 + ys = Float8_E4M3FN.(Float32(2) .^ rand(0:1, ksn, n)) + ref = mma_scaled_ref(xh, xs, yh, ys, Bn) + Z = CUDA.zeros(Float32, m, n) + @cuda backend=cuTile blocks=1 mma_scaled_nvfp4(CuArray(pack_f4_bytes(xh)), CuArray(xs), + CuArray(pack_f4_bytes(yh)), CuArray(ys), Z) + @test Array(Z) == ref + end +end +end