From 53980901b4a439a9ebb98d510a029b102c912e0a Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 29 May 2026 02:11:54 +0200 Subject: [PATCH 1/6] Add `reinterpret` as interface to `bitcast`, `pack`, and `unpack` --- src/compiler/intrinsics/conversions.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/compiler/intrinsics/conversions.jl b/src/compiler/intrinsics/conversions.jl index 7ec78af0..6d162b37 100644 --- a/src/compiler/intrinsics/conversions.jl +++ b/src/compiler/intrinsics/conversions.jl @@ -74,7 +74,8 @@ function tfunc(๐•ƒ, ::typeof(Intrinsics.pack), @nospecialize(x)) length(dims) == 1 || return nothing n = dims[1]::Int bs = lookup_bitwidth(S) - return Tile{UInt8, Tuple{fld(n * bs, 8)}} + (n * bs) % 8 == 0 || return nothing + return Tile{UInt8, Tuple{(n * bs) รท 8}} end function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.pack), args) cb = ctx.cb @@ -122,7 +123,8 @@ function tfunc(๐•ƒ, ::typeof(Intrinsics.unpack), @nospecialize(x), @nospecializ length(dims) == 1 || return nothing n = dims[1]::Int bt = lookup_bitwidth(T) - return Tile{T, Tuple{fld(n * 8, bt)}} + (n * 8) % bt == 0 || return nothing + return Tile{T, Tuple{(n * 8) รท bt}} end function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.unpack), args) cb = ctx.cb From 3ef82a8029c44638652d1cee79cacb0d3cb088ac Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 29 May 2026 02:46:09 +0200 Subject: [PATCH 2/6] reinterpret: centralize validation in emit so invalid casts fail cleanly The shape helpers and pack/unpack tfuncs ran inside the kernel-inferred path, where two failure modes produced confusing errors: - A tfunc returning `nothing` on an indivisible width left the result untypable, surfacing downstream as `internal error: invalid terminators`. - A `throw(ArgumentError(...))` in a shape helper became an unsupported `String` in kernel IR (`format_string`/`unsupported String` error), masking the intended message. Make both layers total: pack/unpack tfuncs always return a concrete type (via `fld`), and the shape helpers are pure arithmetic. Validation now lives solely in the pack/unpack/reshape emit, which throws a clear `IRError` (e.g. "unpack: 1 bytes do not evenly divide into Float32"). Valid reinterprets are unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/compiler/intrinsics/conversions.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/compiler/intrinsics/conversions.jl b/src/compiler/intrinsics/conversions.jl index 6d162b37..7ec78af0 100644 --- a/src/compiler/intrinsics/conversions.jl +++ b/src/compiler/intrinsics/conversions.jl @@ -74,8 +74,7 @@ function tfunc(๐•ƒ, ::typeof(Intrinsics.pack), @nospecialize(x)) length(dims) == 1 || return nothing n = dims[1]::Int bs = lookup_bitwidth(S) - (n * bs) % 8 == 0 || return nothing - return Tile{UInt8, Tuple{(n * bs) รท 8}} + return Tile{UInt8, Tuple{fld(n * bs, 8)}} end function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.pack), args) cb = ctx.cb @@ -123,8 +122,7 @@ function tfunc(๐•ƒ, ::typeof(Intrinsics.unpack), @nospecialize(x), @nospecializ length(dims) == 1 || return nothing n = dims[1]::Int bt = lookup_bitwidth(T) - (n * 8) % bt == 0 || return nothing - return Tile{T, Tuple{(n * 8) รท bt}} + return Tile{T, Tuple{fld(n * 8, bt)}} end function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.unpack), args) cb = ctx.cb From 2407be5104bfbdd73214e743957aaab24652f505 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 29 May 2026 11:39:48 +0200 Subject: [PATCH 3/6] Add `muladd_scaled` --- ext/DLFP8TypesExt.jl | 6 + ext/MicrofloatsExt.jl | 9 + src/bytecode/encodings.jl | 26 ++ src/compiler/intrinsics/core.jl | 56 ++++- src/language/operations.jl | 128 +++++++++- test/extensions/DLFP8Types.jl | 42 +++- test/extensions/Microfloats.jl | 229 ------------------ test/extensions/Microfloats/codegen.jl | 203 ++++++++++++++++ test/extensions/Microfloats/device.jl | 323 +++++++++++++++++++++++++ 9 files changed, 787 insertions(+), 235 deletions(-) delete mode 100644 test/extensions/Microfloats.jl create mode 100644 test/extensions/Microfloats/codegen.jl create mode 100644 test/extensions/Microfloats/device.jl diff --git a/ext/DLFP8TypesExt.jl b/ext/DLFP8TypesExt.jl index bd226633..462636b5 100644 --- a/ext/DLFP8TypesExt.jl +++ b/ext/DLFP8TypesExt.jl @@ -12,6 +12,12 @@ function ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E5M2}) return ct.F8E5M2(table) end +# Non-scaled `mma`/`matmul` (`cuda_tile.mmaf`) accepts f8e4m3fn and f8e5m2 +# operands with an f16 or f32 accumulator (f16 first/preferred), mirroring +# cuda-tile's mmaf type table and cutile-python's `_mma_supported_dtypes`. +ct.mma_allowed_acc_dtypes(::Type{Float8_E4M3FN}) = (Float16, Float32) +ct.mma_allowed_acc_dtypes(::Type{Float8_E5M2}) = (Float16, Float32) + # Float โ†” FP8 scalar constructor overlays (for map/convert dispatch) const FP8Types = (Float8_E4M3FN, Float8_E5M2) const StandardFloats = (Float16, ct.BFloat16, Float32, ct.TFloat32, Float64) diff --git a/ext/MicrofloatsExt.jl b/ext/MicrofloatsExt.jl index 230cc7f9..25f131aa 100644 --- a/ext/MicrofloatsExt.jl +++ b/ext/MicrofloatsExt.jl @@ -23,6 +23,15 @@ ct.bitwidth(::Type{T}) where {T<:Microfloats.Microfloat} = Microfloats.bitwidth( # nearest-even on f32โ†’E8M0FNU (only `zero` and `positive_inf` are valid). ct.ftof_rounding_mode(::Type{Float8_E8M0FNU}) = ct.RoundingMode.Zero +# Non-scaled `mma`/`matmul` (`cuda_tile.mmaf`) accepts f8e4m3fn and f8e5m2 +# operands with an f16 or f32 accumulator โ€” f16 first/preferred, mirroring +# cuda-tile's mmaf type table and cutile-python's `_mma_supported_dtypes`. The +# other Microfloats formats (E8M0FNU, Float4_E2M1FN) are only valid as +# scaled-mma operands/scales, so they stay absent and fall through to `mma`'s +# unsupported-dtype error. +ct.mma_allowed_acc_dtypes(::Type{Float8_E4M3FN}) = (Float16, Float32) +ct.mma_allowed_acc_dtypes(::Type{Float8_E5M2}) = (Float16, Float32) + # Float โ†” microfloat scalar constructor overlays (for map/convert dispatch). # Mirrors DLFP8TypesExt: route to `Intrinsics.ftof` so kernel-side conversions # lower to the FToFOp Tile IR intrinsic instead of Microfloats' Float32-fallback diff --git a/src/bytecode/encodings.jl b/src/bytecode/encodings.jl index 81946912..fddf37cc 100644 --- a/src/bytecode/encodings.jl +++ b/src/bytecode/encodings.jl @@ -98,6 +98,8 @@ module Opcode const Atan2Op = 110 # since 13.2 const PackOp = 111 # since 13.3 const UnpackOp = 112 # since 13.3 + # 113 (AllocaOp) not implemented + const MmaFScaledOp = 114 # since 13.3 end # Enums for operation attributes @@ -1310,6 +1312,30 @@ function encode_MmaIOp!(cb::CodeBuilder, result_type::TypeId, return new_op!(cb) end +""" + encode_MmaFScaledOp!(cb, result_type, lhs, rhs, acc, lhs_scale, rhs_scale) -> Value + +Block-scaled float matrix multiply-accumulate (`acc + (lhs โŠ™ lhs_scale) @ +(rhs โŠ™ rhs_scale)`) for low-precision (f8/f4) inputs. Each scale element scales +a contiguous block of K-dimension elements in its operand. Requires Tile IR +v13.3+. +Opcode: 114 +""" +function encode_MmaFScaledOp!(cb::CodeBuilder, result_type::TypeId, + lhs::Value, rhs::Value, acc::Value, + lhs_scale::Value, rhs_scale::Value) + cb.version >= v"13.3" || + throw(IRError("MmaFScaledOp requires Tile IR v13.3+, got v$(cb.version)")) + encode_varint!(cb.buf, Opcode.MmaFScaledOp) + encode_typeid!(cb.buf, result_type) + encode_operand!(cb.buf, lhs) + encode_operand!(cb.buf, rhs) + encode_operand!(cb.buf, acc) + encode_operand!(cb.buf, lhs_scale) + encode_operand!(cb.buf, rhs_scale) + return new_op!(cb) +end + #============================================================================= Integer arithmetic operations =============================================================================# diff --git a/src/compiler/intrinsics/core.jl b/src/compiler/intrinsics/core.jl index 6ebc6058..940c08f4 100644 --- a/src/compiler/intrinsics/core.jl +++ b/src/compiler/intrinsics/core.jl @@ -458,7 +458,10 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma), args) # the table in cuTile Python's mma implementation. lhs_elem === rhs_elem || throw(IRError("mma: float lhs and rhs must share dtype, got lhs=$lhs_elem, rhs=$rhs_elem")) - allowed_acc = mma_allowed_acc_dtypes(lhs_elem) + # `invokelatest` so extension-defined acc-dtype tables (e.g. FP8 via the + # Microfloats/DLFP8Types exts) are visible from codegen's world age, + # mirroring `lookup_bitwidth`. + allowed_acc = Base.invokelatest(mma_allowed_acc_dtypes, lhs_elem) allowed_acc === nothing && throw(IRError("mma: unsupported float input dtype $lhs_elem")) acc_elem in allowed_acc || @@ -480,6 +483,57 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma), args) CGVal(result, acc.type_id, acc.jltype, acc.shape) end +""" + Intrinsics.mma_scaled(lhs, lhs_scale, rhs, rhs_scale, acc) -> typeof(acc) + +Block-scaled matrix-multiply-accumulate computing `(lhs โŠ™ lhs_scale) * (rhs โŠ™ +rhs_scale) + acc`, where each scale element multiplies a contiguous block of +`lhs`/`rhs` elements along the K dimension. Lowers to `cuda_tile.mmaf_scaled` +(Tile IR v13.3+). + +`lhs`/`rhs` are low-precision floats (`f8e4m3fn`, `f8e5m2`, or `f4e2m1fn`), +`lhs_scale`/`rhs_scale` are `f8e8m0fnu` or `f8e4m3fn`, and `acc`/result are +`f32`. The block size `K รท K_s` and the exact (operand, scale) dtype pairing are +validated by tileiras (see its `mmaf_scaled` verifier for the supported table). +""" +@intrinsic mma_scaled(lhs, lhs_scale, rhs, rhs_scale, acc) +tfunc(๐•ƒ, ::typeof(Intrinsics.mma_scaled), @nospecialize(lhs), @nospecialize(lhs_scale), + @nospecialize(rhs), @nospecialize(rhs_scale), @nospecialize(acc)) = CC.widenconst(acc) +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma_scaled), args) + cb = ctx.cb + + lhs = emit_value!(ctx, args[1]) + lhs_scale = emit_value!(ctx, args[2]) + rhs = emit_value!(ctx, args[3]) + rhs_scale = emit_value!(ctx, args[4]) + acc = emit_value!(ctx, args[5]) + + (lhs === nothing || lhs_scale === nothing || rhs === nothing || + rhs_scale === nothing || acc === nothing) && + throw(IRError("Cannot resolve operands for mma_scaled()")) + + lhs_elem = eltype(CC.widenconst(lhs.jltype)) + rhs_elem = eltype(CC.widenconst(rhs.jltype)) + lhs_scale_elem = eltype(CC.widenconst(lhs_scale.jltype)) + rhs_scale_elem = eltype(CC.widenconst(rhs_scale.jltype)) + acc_elem = eltype(CC.widenconst(acc.jltype)) + + # Structural invariants (cheap, with clear errors); the operand/scale dtype + # pairing and block-size constraints are checked by the tileiras verifier, + # whose messages are already precise. + lhs_elem === rhs_elem || + throw(IRError("mma_scaled: lhs and rhs must share dtype, got lhs=$lhs_elem, rhs=$rhs_elem")) + lhs_scale_elem === rhs_scale_elem || + throw(IRError("mma_scaled: lhs_scale and rhs_scale must share dtype, " * + "got $lhs_scale_elem and $rhs_scale_elem")) + acc_elem === Float32 || + throw(IRError("mma_scaled: acc must be Float32, got $acc_elem")) + + result = encode_MmaFScaledOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v, + lhs_scale.v, rhs_scale.v) + CGVal(result, acc.type_id, acc.jltype, acc.shape) +end + # TODO: cuda_tile.module """ diff --git a/src/language/operations.jl b/src/language/operations.jl index db82f526..e89ac245 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -1295,10 +1295,21 @@ end Matrix multiplication =============================================================================# -# Matrix multiply-accumulate: muladd(a, b, acc) = a * b + acc -# Handles 1D promotion, type promotion, and batched dims (โ‰ฅ3D). -# Note: SA, SB, SC type parameters required to avoid ambiguity with scalar methods during codegen +""" + muladd(a::Tile, b::Tile, acc::Tile) -> Tile + +Matrix multiply-accumulate `a * b + acc` over tiles, lowering to +`cuda_tile.mmaf` (float) or `cuda_tile.mmai` (`i8 ร— i8 โ†’ i32`). + +`a`/`b` are 2-D matrices `(M, K)` ร— `(K, N)` โ†’ `(M, N)`; a 1-D operand is +promoted (vec-mat / mat-vec) and any trailing dimensions (โ‰ฅ3-D) are treated as +broadcast batch dims, lifting `Base.muladd`'s shape rules to tiles. `acc` +carries the result dtype, which must be one tileiras allows for the input dtype +(f16/f32 for f16 and f8; f32 for bf16/tf32; f64 for f64; i32 for i8). +""" @inline function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}) where {T1, T2, T3, SA, SB, SC} + # SA, SB, SC type parameters avoid ambiguity with the scalar `muladd` + # methods during codegen. _muladd(a, b, acc, Val(ndims(a)), Val(ndims(b))) end @@ -1379,6 +1390,117 @@ end end end +#============================================================================= + Block-scaled matrix multiply-accumulate +=============================================================================# + +""" + muladd_scaled(a, a_scale, b, b_scale, acc) -> Tile + +Block-scaled matrix multiply-accumulate `(a โŠ™ a_scale) * (b โŠ™ b_scale) + acc`, +lowering to `cuda_tile.mmaf_scaled` (Tile IR v13.3+, Blackwell). Each scale +element multiplies a contiguous block of `B = K รท K_s` elements along the K +dimension of its operand, so `a_scale`/`b_scale` match `a`/`b` in every +dimension except K, where they have `K_s โ‰ค K` entries. + +`a`/`b` are low-precision floats (`f8e4m3fn`, `f8e5m2`, or `f4e2m1fn`), +`a_scale`/`b_scale` are `f8e8m0fnu` or `f8e4m3fn`, and `acc` is `f32`. Shapes +follow [`muladd`](@ref): 2-D `(M, K)` ร— `(K, N)`, mat-vec, and trailing batch +dims; vec-mat is unsupported (it would collapse K, leaving nothing to scale). +""" +@inline function muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile, b::Tile{Tb, SB}, b_scale::Tile, + acc::Tile) where {Ta, Tb, SA, SB} + _muladd_scaled(a, a_scale, b, b_scale, acc, Val(ndims(a)), Val(ndims(b))) +end + +# 2D ร— 2D: swap operands (and their scales) for row-major Tile IR, exactly as +# `_muladd` swaps for `mma`. +@inline function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile, + ::Val{2}, ::Val{2}) + Intrinsics.mma_scaled(b, b_scale, a, a_scale, acc) +end + +# Mat-vec (2D ร— 1D): the K-vector `b` (and its scale) gain a trailing N=1 dim; +# `acc` becomes (M, 1); then squeeze back to (M,). K โ€” the scaled dimension โ€” +# is preserved, so block scaling is well-defined. +@inline function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile, + ::Val{2}, ::Val{1}) + M, K, Ks = size(a, 1), size(b, 1), size(b_scale, 1) + b2d = reshape(b, (K, 1)) + b_scale2d = reshape(b_scale, (Ks, 1)) + acc2d = reshape(acc, (M, 1)) + result = _muladd_scaled(a, a_scale, b2d, b_scale2d, acc2d, Val(2), Val(2)) + reshape(result, (M,)) +end + +# Vec-mat (1D ร— 2D): promoting `a` to (M, 1) collapses K to 1, leaving no K +# dimension to block-scale. Unsupported โ€” reshape to 2D and supply a matching +# K_s scale instead. +@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{2}) + return :(throw(ArgumentError("Scaled vec-mat is not supported (the K dimension collapses to 1, which cannot be block-scaled)."))) +end + +# Vec-vec (1D ร— 1D): not supported. +@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1}) + return :(throw(ArgumentError("Scaled vector-vector multiply-accumulate is not supported."))) +end + +# Batched mat-vec / vec-mat (โ‰ฅ3D ร— 1D or 1D ร— โ‰ฅ3D): not supported. +@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB} + return :(throw(ArgumentError("Batched scaled vec-mat is not supported."))) +end +@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA} + return :(throw(ArgumentError("Batched scaled mat-vec is not supported."))) +end + +# Batched (โ‰ฅ3D ร— โ‰ฅ3D): trailing batch dims with broadcast, mirroring `_muladd`. +# Scales carry the same batch dims as their operands; a_scale's batch must match +# a's batch (likewise b_scale/b), then both broadcast to the common batch shape. +@generated function _muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile{Tas, SAS}, + b::Tile{Tb, SB}, b_scale::Tile{Tbs, SBS}, + acc::Tile{Tc, SC}, + ::Val{NA}, ::Val{NB}) where {Ta, Tas, Tb, Tbs, Tc, + SA, SAS, SB, SBS, SC, NA, NB} + sa = Tuple(SA.parameters); sas = Tuple(SAS.parameters) + sb = Tuple(SB.parameters); sbs = Tuple(SBS.parameters) + + # Matrix dims are first two; batch dims are trailing. + M = sa[1]; K = sa[2]; N = sb[2] + Ksa = sas[2] # a_scale K_s (a_scale is (M, K_s, batch...)) + Ksb = sbs[1] # b_scale K_s (b_scale is (K_s, N, batch...)) + a_batch = sa[3:end]; b_batch = sb[3:end] + as_batch = sas[3:end]; bs_batch = sbs[3:end] + + # Broadcast batch dims (pad shorter with trailing 1s, then broadcast). + n_batch = max(length(a_batch), length(b_batch)) + a_batch_padded = (a_batch..., ntuple(Returns(1), n_batch - length(a_batch))...) + b_batch_padded = (b_batch..., ntuple(Returns(1), n_batch - length(b_batch))...) + as_batch_padded = (as_batch..., ntuple(Returns(1), n_batch - length(as_batch))...) + bs_batch_padded = (bs_batch..., ntuple(Returns(1), n_batch - length(bs_batch))...) + batch_shape = map(max, a_batch_padded, b_batch_padded) + B_flat = prod(batch_shape) + + quote + # Reshape + broadcast to align batch dims (still trailing). + a_bc = broadcast_to(reshape(a, $((M, K, a_batch_padded...))), $((M, K, batch_shape...))) + b_bc = broadcast_to(reshape(b, $((K, N, b_batch_padded...))), $((K, N, batch_shape...))) + as_bc = broadcast_to(reshape(a_scale, $((M, Ksa, as_batch_padded...))), $((M, Ksa, batch_shape...))) + bs_bc = broadcast_to(reshape(b_scale, $((Ksb, N, bs_batch_padded...))), $((Ksb, N, batch_shape...))) + acc_bc = broadcast_to(acc, $((M, N, batch_shape...))) + # Flatten batch dims to one โ€” no permute needed since row-major Tile IR + # already has batch as the leading (slowest) dimension. + a_3d = reshape(a_bc, $((M, K, B_flat))) + b_3d = reshape(b_bc, $((K, N, B_flat))) + as_3d = reshape(as_bc, $((M, Ksa, B_flat))) + bs_3d = reshape(bs_bc, $((Ksb, N, B_flat))) + acc_3d = reshape(acc_bc, $((M, N, B_flat))) + # mmaf_scaled with swapped operands for row-major convention. + result_3d = Intrinsics.mma_scaled(b_3d, bs_3d, a_3d, as_3d, acc_3d) + # Unflatten batch dims. + reshape(result_3d, $((M, N, batch_shape...))) + end +end + # Matrix multiplication: A * B = muladd(A, B, zeros) # Note: SA, SB type parameters required to avoid ambiguity with scalar*tile methods during codegen # diff --git a/test/extensions/DLFP8Types.jl b/test/extensions/DLFP8Types.jl index cb72d48f..2f0fb0a2 100644 --- a/test/extensions/DLFP8Types.jl +++ b/test/extensions/DLFP8Types.jl @@ -2,6 +2,7 @@ using CUDA using DLFP8Types: Float8_E4M3FN, Float8_E5M2 spec1d = ct.ArraySpec{1}(16, true) +spec2d = ct.ArraySpec{2}(16, true) @testset "codegen" begin @@ -31,11 +32,24 @@ end end end +# Non-scaled f8 matmul lowers to `mmaf` (f8 operands, f32 accumulator). +@test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}) do a, b, c + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + @check "mmaf" + ct.store(c, (1, 1), muladd(ta, tb, zeros(Float32, (16, 16)))) + return + end +end + end -# FP8 types are Blackwell-only +# FP8 (e4m3/e5m2) conversions and matmul need Hopper (sm_90+). @testset "execution" begin -if capability(device()) >= v"10" +if capability(device()) >= v"9" # Round-trip Float32 โ†’ FP8 โ†’ Float32 on values exactly representable in # the target FP8 type โ€” result must match input bit-for-bit. @@ -82,5 +96,29 @@ let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, - @test Array(d) == av .* bv .+ cv end +# Non-scaled FP8 matmul with both allowed accumulator dtypes (f16 and f32). +function mma_dl_f32(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, + C::ct.TileArray{Float32,2}, D::ct.TileArray{Float32,2}) + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) + return +end +function mma_dl_f16(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, + C::ct.TileArray{Float16,2}, D::ct.TileArray{Float32,2}) + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) + return +end +@testset "mma โ†’ $Tacc acc" for (Tacc, kern) in ((Float32, mma_dl_f32), (Float16, mma_dl_f16)) + M = 16 + ah = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) + bh = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) + ch = Tacc.(Float32.(rand(0:2, M, M))) + ref = Float32.(ah) * Float32.(bh) .+ Float32.(ch) + D = CUDA.zeros(Float32, M, M) + @cuda backend=cuTile blocks=1 kern(CuArray(ah), CuArray(bh), CuArray(ch), D) + @test Array(D) == ref +end + end end diff --git a/test/extensions/Microfloats.jl b/test/extensions/Microfloats.jl deleted file mode 100644 index 13b64c0a..00000000 --- a/test/extensions/Microfloats.jl +++ /dev/null @@ -1,229 +0,0 @@ -using CUDA -using Microfloats: Float8_E4M3FN, Float8_E5M2, Float8_E8M0FNU, Float4_E2M1FN - -spec1d = ct.ArraySpec{1}(16, true) - -@testset "codegen" begin - -# Float32 -> Float8_E4M3FN (always available; 13.1+) -@test @filecheck begin - @check_label "entry" - code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b - pid = ct.bid(1) - tile = ct.load(a, pid, (16,)) - @check "ftof" - converted = convert(ct.Tile{Float8_E4M3FN}, tile) - ct.store(b, pid, convert(ct.Tile{Float32}, converted)) - return - end -end - -# Float32 -> Float8_E5M2 (always available; 13.1+) -@test @filecheck begin - @check_label "entry" - code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b - pid = ct.bid(1) - tile = ct.load(a, pid, (16,)) - @check "ftof" - converted = convert(ct.Tile{Float8_E5M2}, tile) - ct.store(b, pid, convert(ct.Tile{Float32}, converted)) - return - end -end - -# Float32 -> Float8_E8M0FNU works on bytecode 13.2+ -@test @filecheck begin - @check_label "entry" - code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}; - bytecode_version=v"13.2") do a, b - pid = ct.bid(1) - tile = ct.load(a, pid, (16,)) - @check "ftof" - converted = convert(ct.Tile{Float8_E8M0FNU}, tile) - ct.store(b, pid, convert(ct.Tile{Float32}, converted)) - return - end -end - -# Float8_E8M0FNU rejected on bytecode 13.1 with a clear version error -let kernel = (a, b) -> begin - pid = ct.bid(1) - tile = ct.load(a, pid, (16,)) - converted = convert(ct.Tile{Float8_E8M0FNU}, tile) - ct.store(b, pid, convert(ct.Tile{Float32}, converted)) - return - end - @test_throws "v13.2+" code_tiled(devnull, kernel, - Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}; - bytecode_version=v"13.1") -end - -# Float4_E2M1FN requires bytecode 13.3 โ€” rejected at 13.2 with a clear error -let kernel = (a, b) -> begin - pid = ct.bid(1) - tile = ct.load(a, pid, (16,)) - converted = convert(ct.Tile{Float4_E2M1FN}, tile) - ct.store(b, pid, convert(ct.Tile{Float32}, converted)) - return - end - @test_throws "v13.3+" code_tiled(devnull, kernel, - Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}; - bytecode_version=v"13.2") -end - -# Whole-tile `reinterpret` between UInt8 and Float4_E2M1FN packs/unpacks two FP4 -# per byte: a `Tile{UInt8,Tuple{8}}` unpacks to a `Tile{Float4_E2M1FN,Tuple{16}}`, -# lowering to `cuda_tile.unpack` (13.3+). -@test @filecheck begin - @check_label "entry" - code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{Float32,1,spec1d}}; - bytecode_version=v"13.3") do a, b - pid = ct.bid(1) - bytes = ct.load(a, pid, (8,)) # Tile{UInt8,Tuple{8}} - @check "unpack" - fp4 = reinterpret(Float4_E2M1FN, bytes) # Tile{Float4_E2M1FN,Tuple{16}} - ct.store(b, pid, convert(ct.Tile{Float32}, fp4)) - return - end -end - -# And the reverse packs FP4 back into bytes via `cuda_tile.pack` (13.3+). -@test @filecheck begin - @check_label "entry" - code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{UInt8,1,spec1d}}; - bytecode_version=v"13.3") do a, b - pid = ct.bid(1) - vals = ct.load(a, pid, (16,)) - fp4 = convert(ct.Tile{Float4_E2M1FN}, vals) # Tile{Float4_E2M1FN,Tuple{16}} - @check "pack" - ct.store(b, pid, reinterpret(UInt8, fp4)) # Tile{UInt8,Tuple{8}} - return - end -end - -end - -# FP8 types are Blackwell-only -@testset "execution" begin -if capability(device()) >= v"10" - -# Round-trip Float32 โ†’ microfloat โ†’ Float32 on values exactly representable -# in the target type โ€” result must match input bit-for-bit. -function rt_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) - pid = ct.bid(1) - tile = ct.load(a, pid, (16,)) - ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float8_E4M3FN}, tile))) - return -end -function rt_e5m2(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) - pid = ct.bid(1) - tile = ct.load(a, pid, (16,)) - ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float8_E5M2}, tile))) - return -end -# Float8_E4M3FN / Float8_E5M2: 13.1+, always available -representable8 = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 8.0, - 16.0, 32.0, 64.0, 128.0, 256.0, -1.0, -2.0, -0.5] -let a = CuArray(representable8), b = CUDA.zeros(Float32, length(representable8)) - @cuda backend=cuTile blocks=1 rt_e4m3(a, b) - @test Array(b) == representable8 - @cuda backend=cuTile blocks=1 rt_e5m2(a, b) - @test Array(b) == representable8 -end - -# FMA in FP8: load Float32, convert to FP8, multiply-add in FP8, convert back. -# Uses inputs whose products and sums also stay representable, so the result -# is exact. -function fma_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, - c::ct.TileArray{Float32,1}, d::ct.TileArray{Float32,1}) - pid = ct.bid(1) - ta = convert(ct.Tile{Float8_E4M3FN}, ct.load(a, pid, (16,))) - tb = convert(ct.Tile{Float8_E4M3FN}, ct.load(b, pid, (16,))) - tc = convert(ct.Tile{Float8_E4M3FN}, ct.load(c, pid, (16,))) - ct.store(d, pid, convert(ct.Tile{Float32}, muladd.(ta, tb, tc))) - return -end -let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, -2.0, 1.0, 0.5, 4.0], - bv = Float32[2.0, 1.0, 4.0, 0.5, 2.0, 3.0, 2.0, 4.0, 1.0, 2.0, 1.0, 0.5, 2.0, 1.0, 2.0, 1.0], - cv = Float32[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0] - a, b, c = CuArray(av), CuArray(bv), CuArray(cv) - d = CUDA.zeros(Float32, length(av)) - @cuda backend=cuTile blocks=1 fma_e4m3(a, b, c, d) - @test Array(d) == av .* bv .+ cv -end - -# Float8_E8M0FNU and Float4_E2M1FN: codegen for `ftof` lowers fine (see the -# codegen tests above) but tileiras refuses to lower a *standalone* f32 โ†” -# microfloat conversion on Blackwell โ€” these formats only have meaningful -# hardware paths as the scale/operand dtypes of block-scaled mma, or packed -# into a byte-wide tile via `reinterpret` (below). - -# Float4_E2M1FN moves through global memory packed two-per-byte: `reinterpret` -# unpacks a `UInt8` tile into FP4 (doubling the leading dim) and packs it back. - -# Pure pack/unpack round-trip: reinterpreting UInt8 โ†’ FP4 โ†’ UInt8 only -# reinterprets bits, so the bytes must come back unchanged for any input. -function rt_pack(a::ct.TileArray{UInt8,1}, b::ct.TileArray{UInt8,1}) - pid = ct.bid(1) - bytes = ct.load(a, pid, (8,)) - fp4 = reinterpret(Float4_E2M1FN, bytes) # unpack: (8,) UInt8 โ†’ (16,) FP4 - ct.store(b, pid, reinterpret(UInt8, fp4)) # pack: (16,) FP4 โ†’ (8,) UInt8 - return -end -let av = UInt8[0x00, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde], - b = CUDA.zeros(UInt8, 8) - a = CuArray(av) - @cuda backend=cuTile blocks=1 rt_pack(a, b) - @test Array(b) == av -end - -# Value round-trip through FP4 stored as UInt8: convert Float32 โ†’ FP4, pack to -# UInt8 and store; then load UInt8, unpack to FP4 and convert back to Float32. -# All inputs are exactly representable in E2M1 (magnitudes 0,0.5,1,1.5,2,3,4,6), -# so the result must match bit-for-bit. -function pack_fp4(src::ct.TileArray{Float32,1}, dst::ct.TileArray{UInt8,1}) - pid = ct.bid(1) - vals = ct.load(src, pid, (16,)) - fp4 = convert(ct.Tile{Float4_E2M1FN}, vals) - ct.store(dst, pid, reinterpret(UInt8, fp4)) # pack: (16,) FP4 โ†’ (8,) UInt8 - return -end -function unpack_fp4(src::ct.TileArray{UInt8,1}, dst::ct.TileArray{Float32,1}) - pid = ct.bid(1) - bytes = ct.load(src, pid, (8,)) - fp4 = reinterpret(Float4_E2M1FN, bytes) # unpack: (8,) UInt8 โ†’ (16,) FP4 - ct.store(dst, pid, convert(ct.Tile{Float32}, fp4)) - return -end -let representable4 = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, 0.5] - src = CuArray(representable4) - packed = CUDA.zeros(UInt8, 8) - out = CUDA.zeros(Float32, 16) - @cuda backend=cuTile blocks=1 pack_fp4(src, packed) - @cuda backend=cuTile blocks=1 unpack_fp4(packed, out) - @test Array(out) == representable4 -end - -# N-D reinterpret: pack/unpack are rank-1, but whole-tile `reinterpret` flattens -# (via reshape) so it works on any rank. Round-trip a 2-D Float32 tile through -# 2-D FP4 and 2-D packed UInt8: the leading (column-major) dim absorbs the 2ร— / -# ยฝ scaling โ€” (8,2) FP4 โ†” (4,2) UInt8. -function rt_fp4_2d(src::ct.TileArray{Float32,2}, dst::ct.TileArray{Float32,2}) - pid = ct.bid(1) - fp4 = convert(ct.Tile{Float4_E2M1FN}, ct.load(src, pid, (8, 2))) # (8,2) FP4 - bytes = reinterpret(UInt8, fp4) # (4,2) UInt8 - fp4b = reinterpret(Float4_E2M1FN, bytes) # (8,2) FP4 - ct.store(dst, pid, convert(ct.Tile{Float32}, fp4b)) - return -end -let m = reshape(Float32[0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -1.0, - -2.0, -3.0, -4.0, -6.0, 0.5, 1.0, 2.0, 3.0], 8, 2) - src = CuArray(m) - out = CUDA.zeros(Float32, 8, 2) - @cuda backend=cuTile blocks=1 rt_fp4_2d(src, out) - @test Array(out) == m -end - -end -end diff --git a/test/extensions/Microfloats/codegen.jl b/test/extensions/Microfloats/codegen.jl new file mode 100644 index 00000000..ce9793cf --- /dev/null +++ b/test/extensions/Microfloats/codegen.jl @@ -0,0 +1,203 @@ +using Microfloats: Float8_E4M3FN, Float8_E5M2, Float8_E8M0FNU, Float4_E2M1FN + +spec1d = ct.ArraySpec{1}(16, true) +spec2d = ct.ArraySpec{2}(16, true) + +@testset "Microfloats codegen" begin + +@testset "ftof" begin + # Float32 -> Float8_E4M3FN (always available; 13.1+) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + @check "ftof" + converted = convert(ct.Tile{Float8_E4M3FN}, tile) + ct.store(b, pid, convert(ct.Tile{Float32}, converted)) + return + end + end + + # Float32 -> Float8_E5M2 (always available; 13.1+) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + @check "ftof" + converted = convert(ct.Tile{Float8_E5M2}, tile) + ct.store(b, pid, convert(ct.Tile{Float32}, converted)) + return + end + end + + # Float32 -> Float8_E8M0FNU works on bytecode 13.2+ + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}; + bytecode_version=v"13.2") do a, b + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + @check "ftof" + converted = convert(ct.Tile{Float8_E8M0FNU}, tile) + ct.store(b, pid, convert(ct.Tile{Float32}, converted)) + return + end + end + + # Float8_E8M0FNU rejected on bytecode 13.1 with a clear version error + let kernel = (a, b) -> begin + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + converted = convert(ct.Tile{Float8_E8M0FNU}, tile) + ct.store(b, pid, convert(ct.Tile{Float32}, converted)) + return + end + @test_throws "v13.2+" code_tiled(devnull, kernel, + Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}; + bytecode_version=v"13.1") + end + + # Float4_E2M1FN requires bytecode 13.3 โ€” rejected at 13.2 with a clear error + let kernel = (a, b) -> begin + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + converted = convert(ct.Tile{Float4_E2M1FN}, tile) + ct.store(b, pid, convert(ct.Tile{Float32}, converted)) + return + end + @test_throws "v13.3+" code_tiled(devnull, kernel, + Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}; + bytecode_version=v"13.2") + end +end + +@testset "reinterpret" begin + # Whole-tile `reinterpret` between UInt8 and Float4_E2M1FN packs/unpacks two + # FP4 per byte: a `Tile{UInt8,(8,)}` unpacks to a `Tile{Float4_E2M1FN,(16,)}`, + # lowering to `cuda_tile.unpack` (13.3+). + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{Float32,1,spec1d}}; + bytecode_version=v"13.3") do a, b + pid = ct.bid(1) + bytes = ct.load(a, pid, (8,)) # Tile{UInt8,(8,)} + @check "unpack" + fp4 = reinterpret(Float4_E2M1FN, bytes) # Tile{Float4_E2M1FN,(16,)} + ct.store(b, pid, convert(ct.Tile{Float32}, fp4)) + return + end + end + + # And the reverse packs FP4 back into bytes via `cuda_tile.pack` (13.3+). + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{UInt8,1,spec1d}}; + bytecode_version=v"13.3") do a, b + pid = ct.bid(1) + vals = ct.load(a, pid, (16,)) + fp4 = convert(ct.Tile{Float4_E2M1FN}, vals) # Tile{Float4_E2M1FN,(16,)} + @check "pack" + ct.store(b, pid, reinterpret(UInt8, fp4)) # Tile{UInt8,(8,)} + return + end + end +end + +@testset "mma" begin + # f8e4m3fn operands with an f32 accumulator lower to `mmaf` (13.1+). + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}) do a, b, c + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + @check "mmaf" + ct.store(c, (1, 1), muladd(ta, tb, zeros(Float32, (16, 16)))) + return + end + end + + # f8e5m2 operands with an f16 accumulator also lower to `mmaf`. + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float8_E5M2,2,spec2d}, ct.TileArray{Float8_E5M2,2,spec2d}, + ct.TileArray{Float16,2,spec2d}}) do a, b, c + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + @check "mmaf" + ct.store(c, (1, 1), muladd(ta, tb, zeros(Float16, (16, 16)))) + return + end + end + + # A disallowed accumulator dtype for f8 (only f16/f32 are valid) is rejected + # with a clear error rather than producing an mmaf op tileiras would reject. + @test_throws "tileiras requires acc" code_tiled( + Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float64,2,spec2d}}) do a, b, c + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + acc = zeros(Float64, (16, 16)) + Base.donotdelete(ct.Intrinsics.mma(ta, tb, acc)) + return + end +end + +@testset "mma_scaled" begin + # f8 operands (M,K)/(K,N) with f8e8m0fnu block scales (M,K_s)/(K_s,N) + # accumulate into f32. Block size B = K รท K_s = 64 รท 2 = 32. + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E8M0FNU,2,spec2d}, + ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E8M0FNU,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}; + bytecode_version=v"13.3") do x, x_scale, y, y_scale, z + xt = ct.load(x, (1, 1), (16, 64)) + xst = ct.load(x_scale, (1, 1), (16, 2)) + yt = ct.load(y, (1, 1), (64, 16)) + yst = ct.load(y_scale, (1, 1), (2, 16)) + @check "mmaf_scaled" + result = ct.muladd_scaled(xt, xst, yt, yst, zeros(Float32, (16, 16))) + ct.store(z, (1, 1), result) + return + end + end + + # f4e2m1fn operands are accepted by `mmaf_scaled` too (here with f8e4m3fn + # scales, B = 16). Operands enter as unpacked FP4 tiles. + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}; + bytecode_version=v"13.3") do x, x_scale, y, y_scale, z + xt = convert(ct.Tile{Float4_E2M1FN}, ct.load(x, (1, 1), (16, 64))) + xst = ct.load(x_scale, (1, 1), (16, 4)) + yt = convert(ct.Tile{Float4_E2M1FN}, ct.load(y, (1, 1), (64, 16))) + yst = ct.load(y_scale, (1, 1), (4, 16)) + @check "mmaf_scaled" + result = ct.muladd_scaled(xt, xst, yt, yst, zeros(Float32, (16, 16))) + ct.store(z, (1, 1), result) + return + end + end + + # mma_scaled requires bytecode 13.3 โ€” rejected at 13.2 with a clear error. + let kernel = (x, x_scale, y, y_scale, z) -> begin + xt = ct.load(x, (1, 1), (16, 64)) + xst = ct.load(x_scale, (1, 1), (16, 2)) + yt = ct.load(y, (1, 1), (64, 16)) + yst = ct.load(y_scale, (1, 1), (2, 16)) + ct.store(z, (1, 1), ct.muladd_scaled(xt, xst, yt, yst, zeros(Float32, (16, 16)))) + return + end + @test_throws "13.3" code_tiled(devnull, kernel, + Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E8M0FNU,2,spec2d}, + ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E8M0FNU,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}; bytecode_version=v"13.2") + end +end + +end diff --git a/test/extensions/Microfloats/device.jl b/test/extensions/Microfloats/device.jl new file mode 100644 index 00000000..a8327a9d --- /dev/null +++ b/test/extensions/Microfloats/device.jl @@ -0,0 +1,323 @@ +using CUDA +using Microfloats: Float8_E4M3FN, Float8_E5M2, Float8_E8M0FNU, Float4_E2M1FN + +# Kernels are plain top-level functions (not closures). + +# ftof round-trips +function rt_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float8_E4M3FN}, tile))) + return +end +function rt_e5m2(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float8_E5M2}, tile))) + return +end +function fma_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, + c::ct.TileArray{Float32,1}, d::ct.TileArray{Float32,1}) + pid = ct.bid(1) + ta = convert(ct.Tile{Float8_E4M3FN}, ct.load(a, pid, (16,))) + tb = convert(ct.Tile{Float8_E4M3FN}, ct.load(b, pid, (16,))) + tc = convert(ct.Tile{Float8_E4M3FN}, ct.load(c, pid, (16,))) + ct.store(d, pid, convert(ct.Tile{Float32}, muladd.(ta, tb, tc))) + return +end + +# reinterpret (pack / unpack) +function rt_pack(a::ct.TileArray{UInt8,1}, b::ct.TileArray{UInt8,1}) + pid = ct.bid(1) + bytes = ct.load(a, pid, (8,)) + fp4 = reinterpret(Float4_E2M1FN, bytes) # unpack: (8,) UInt8 โ†’ (16,) FP4 + ct.store(b, pid, reinterpret(UInt8, fp4)) # pack: (16,) FP4 โ†’ (8,) UInt8 + return +end +function pack_fp4(src::ct.TileArray{Float32,1}, dst::ct.TileArray{UInt8,1}) + pid = ct.bid(1) + vals = ct.load(src, pid, (16,)) + fp4 = convert(ct.Tile{Float4_E2M1FN}, vals) + ct.store(dst, pid, reinterpret(UInt8, fp4)) # pack: (16,) FP4 โ†’ (8,) UInt8 + return +end +function unpack_fp4(src::ct.TileArray{UInt8,1}, dst::ct.TileArray{Float32,1}) + pid = ct.bid(1) + bytes = ct.load(src, pid, (8,)) + fp4 = reinterpret(Float4_E2M1FN, bytes) # unpack: (8,) UInt8 โ†’ (16,) FP4 + ct.store(dst, pid, convert(ct.Tile{Float32}, fp4)) + return +end +function rt_fp4_2d(src::ct.TileArray{Float32,2}, dst::ct.TileArray{Float32,2}) + pid = ct.bid(1) + fp4 = convert(ct.Tile{Float4_E2M1FN}, ct.load(src, pid, (8, 2))) # (8,2) FP4 + bytes = reinterpret(UInt8, fp4) # (4,2) UInt8 + fp4b = reinterpret(Float4_E2M1FN, bytes) # (8,2) FP4 + ct.store(dst, pid, convert(ct.Tile{Float32}, fp4b)) + return +end + +# non-scaled FP8 matmul (f8 ร— f8 with f16 or f32 accumulator) +function mma_e4m3_f32(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, + C::ct.TileArray{Float32,2}, D::ct.TileArray{Float32,2}) + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) + return +end +function mma_e4m3_f16(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, + C::ct.TileArray{Float16,2}, D::ct.TileArray{Float32,2}) + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) + return +end +function mma_e5m2_f32(A::ct.TileArray{Float8_E5M2,2}, B::ct.TileArray{Float8_E5M2,2}, + C::ct.TileArray{Float32,2}, D::ct.TileArray{Float32,2}) + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) + return +end +function mma_e5m2_f16(A::ct.TileArray{Float8_E5M2,2}, B::ct.TileArray{Float8_E5M2,2}, + C::ct.TileArray{Float16,2}, D::ct.TileArray{Float32,2}) + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) + return +end +function mma_e4m3_star(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, + D::ct.TileArray{Float32,2}) + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, a * b)) + return +end + +# block-scaled mma (B = K รท K_s; here K = 64, K_s = 2 โ†’ B = 32) +function mma_scaled_2d_e4m3(X::ct.TileArray{Float8_E4M3FN,2}, XS::ct.TileArray{Float8_E8M0FNU,2}, + Y::ct.TileArray{Float8_E4M3FN,2}, YS::ct.TileArray{Float8_E8M0FNU,2}, + Z::ct.TileArray{Float32,2}) + x = ct.load(X, (1, 1), (16, 64)); xs = ct.load(XS, (1, 1), (16, 2)) + y = ct.load(Y, (1, 1), (64, 16)); ys = ct.load(YS, (1, 1), (2, 16)) + ct.store(Z, (1, 1), ct.muladd_scaled(x, xs, y, ys, zeros(Float32, (16, 16)))) + return +end +function mma_scaled_2d_e5m2(X::ct.TileArray{Float8_E5M2,2}, XS::ct.TileArray{Float8_E8M0FNU,2}, + Y::ct.TileArray{Float8_E5M2,2}, YS::ct.TileArray{Float8_E8M0FNU,2}, + Z::ct.TileArray{Float32,2}) + x = ct.load(X, (1, 1), (16, 64)); xs = ct.load(XS, (1, 1), (16, 2)) + y = ct.load(Y, (1, 1), (64, 16)); ys = ct.load(YS, (1, 1), (2, 16)) + ct.store(Z, (1, 1), ct.muladd_scaled(x, xs, y, ys, zeros(Float32, (16, 16)))) + return +end +function mma_scaled_matvec(X::ct.TileArray{Float8_E4M3FN,2}, XS::ct.TileArray{Float8_E8M0FNU,2}, + Y::ct.TileArray{Float8_E4M3FN,1}, YS::ct.TileArray{Float8_E8M0FNU,1}, + Z::ct.TileArray{Float32,1}) + x = ct.load(X, (1, 1), (16, 64)); xs = ct.load(XS, (1, 1), (16, 2)) + y = ct.load(Y, 1, (64,)); ys = ct.load(YS, 1, (2,)) + ct.store(Z, 1, ct.muladd_scaled(x, xs, y, ys, zeros(Float32, (16,)))) + return +end +function mma_scaled_batched(X::ct.TileArray{Float8_E4M3FN,2}, XS::ct.TileArray{Float8_E8M0FNU,2}, + Y::ct.TileArray{Float8_E4M3FN,3}, YS::ct.TileArray{Float8_E8M0FNU,3}, + Z::ct.TileArray{Float32,3}) + x = ct.load(X, (1, 1), (16, 64)); xs = ct.load(XS, (1, 1), (16, 2)) + y = ct.load(Y, (1, 1, 1), (64, 16, 2)); ys = ct.load(YS, (1, 1, 1), (2, 16, 2)) + ct.store(Z, (1, 1, 1), ct.muladd_scaled(x, xs, y, ys, zeros(Float32, (16, 16, 2)))) + return +end +# FP4 operands have no direct sub-byte load: they arrive packed two-per-byte +# along the matmul-contiguous axis (K for X, N for Y, matching the row-major +# reference). cuTile's `reinterpret` doubles the *leading* (column-major) dim, +# so we load each operand with its packed axis leading, unpack, then transpose +# into the (M,K) / (K,N) orientation `muladd_scaled` expects: +# X: bytes (K/2, M) โ†’ reinterpret (K, M) โ†’ transpose (M, K) +# Y: bytes (N/2, K) โ†’ reinterpret (N, K) โ†’ transpose (K, N) +# Scales are f8 (one per byte) and load directly. +function mma_scaled_fp4(X::ct.TileArray{UInt8,2}, XS::ct.TileArray{Float8_E8M0FNU,2}, + Y::ct.TileArray{UInt8,2}, YS::ct.TileArray{Float8_E8M0FNU,2}, + Z::ct.TileArray{Float32,2}) + x = permutedims(reinterpret(Float4_E2M1FN, ct.load(X, (1, 1), (32, 16))), (2, 1)) # (M,K) + y = permutedims(reinterpret(Float4_E2M1FN, ct.load(Y, (1, 1), (8, 64))), (2, 1)) # (K,N) + xs = ct.load(XS, (1, 1), (16, 2)) + ys = ct.load(YS, (1, 1), (2, 16)) + ct.store(Z, (1, 1), ct.muladd_scaled(x, xs, y, ys, zeros(Float32, (16, 16)))) + return +end + +# Expand a block-scale tile to per-element scales along the K dimension: each +# scale entry covers B contiguous K positions (matches Tile IR's broadcast). +expand_k(scale, B, ::Val{:cols}) = repeat(Float32.(scale), inner=(1, B)) # (M, K_s) โ†’ (M, K) +expand_k(scale, B, ::Val{:rows}) = repeat(Float32.(scale), inner=(B, 1)) # (K_s, N) โ†’ (K, N) + +# Reference: (x โŠ™ x_scale) * (y โŠ™ y_scale) + acc, with acc = 0. +function mma_scaled_ref(xh, xs, yh, ys, B) + (Float32.(xh) .* expand_k(xs, B, Val(:cols))) * + (Float32.(yh) .* expand_k(ys, B, Val(:rows))) +end + +# Low 4 bits of an FP4 value's byte storage (its E2M1 nibble). +f4_nibble(v) = reinterpret(UInt8, Float4_E2M1FN(v)) & 0x0f + +# All FP4-representable magnitudes, for sampling exact inputs. +const F4_VALUES = Float32[0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.5, -1, -1.5, -2, -3, -4, -6] + +@testset "Microfloats device" begin + +# FP8 (e4m3/e5m2) conversions need Hopper (sm_90+). E8M0FNU/Float4_E2M1FN have +# no standalone f32 conversion path on any arch (codegen lowers fine, but +# tileiras rejects a standalone f32 โ†” E8M0/F4 conversion), so they are only +# meaningful as packed/scaled operands โ€” exercised by the testsets below. +@testset "ftof" begin +if capability(device()) >= v"9" + # Round-trip Float32 โ†’ microfloat โ†’ Float32 on exactly-representable values. + representable8 = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 8.0, + 16.0, 32.0, 64.0, 128.0, 256.0, -1.0, -2.0, -0.5] + let a = CuArray(representable8), b = CUDA.zeros(Float32, length(representable8)) + @cuda backend=cuTile blocks=1 rt_e4m3(a, b) + @test Array(b) == representable8 + @cuda backend=cuTile blocks=1 rt_e5m2(a, b) + @test Array(b) == representable8 + end + + # FMA in FP8: load f32, convert to FP8, multiply-add in FP8, convert back. + # Inputs whose products and sums stay representable, so the result is exact. + let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, -2.0, 1.0, 0.5, 4.0], + bv = Float32[2.0, 1.0, 4.0, 0.5, 2.0, 3.0, 2.0, 4.0, 1.0, 2.0, 1.0, 0.5, 2.0, 1.0, 2.0, 1.0], + cv = Float32[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0] + a, b, c = CuArray(av), CuArray(bv), CuArray(cv) + d = CUDA.zeros(Float32, length(av)) + @cuda backend=cuTile blocks=1 fma_e4m3(a, b, c, d) + @test Array(d) == av .* bv .+ cv + end +end +end + +# Float4_E2M1FN requires Blackwell (sm_100+). +@testset "reinterpret" begin +if capability(device()) >= v"10" + # Pure pack/unpack round-trip: UInt8 โ†’ FP4 โ†’ UInt8 must be a no-op. + let av = UInt8[0x00, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde], + b = CUDA.zeros(UInt8, 8) + a = CuArray(av) + @cuda backend=cuTile blocks=1 rt_pack(a, b) + @test Array(b) == av + end + + # Value round-trip through FP4 stored as UInt8 (all inputs representable). + let representable4 = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, 0.5] + src = CuArray(representable4) + packed = CUDA.zeros(UInt8, 8) + out = CUDA.zeros(Float32, 16) + @cuda backend=cuTile blocks=1 pack_fp4(src, packed) + @cuda backend=cuTile blocks=1 unpack_fp4(packed, out) + @test Array(out) == representable4 + end + + # N-D reinterpret: whole-tile `reinterpret` flattens, so it works on any + # rank. (8,2) FP4 โ†” (4,2) UInt8 โ€” the leading dim absorbs the 2ร— / ยฝ. + let m = reshape(Float32[0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -1.0, + -2.0, -3.0, -4.0, -6.0, 0.5, 1.0, 2.0, 3.0], 8, 2) + src = CuArray(m) + out = CUDA.zeros(Float32, 8, 2) + @cuda backend=cuTile blocks=1 rt_fp4_2d(src, out) + @test Array(out) == m + end +end +end + +# Non-scaled FP8 matmul needs only Hopper (sm_90+); block-scaled mma (below) is +# the Blackwell-only variant. +@testset "mma" begin +if capability(device()) >= v"9" + M, K, N = 16, 16, 16 + # f8 operands with both allowed accumulator dtypes (f16 and f32). Inputs are + # in {0, 0.5, 1} and K = 16, so every partial sum is exactly representable + # in f16 too โ€” the f16-acc result matches the f32 reference bit-for-bit. + @testset "$T ร— $T โ†’ $Tacc acc" for (T, Tacc, kern) in ( + (Float8_E4M3FN, Float32, mma_e4m3_f32), + (Float8_E4M3FN, Float16, mma_e4m3_f16), + (Float8_E5M2, Float32, mma_e5m2_f32), + (Float8_E5M2, Float16, mma_e5m2_f16)) + ah = T.(Float32.(rand(0:2, M, K)) ./ 2) + bh = T.(Float32.(rand(0:2, K, N)) ./ 2) + ch = Tacc.(Float32.(rand(0:2, M, N))) + ref = Float32.(ah) * Float32.(bh) .+ Float32.(ch) + D = CUDA.zeros(Float32, M, N) + @cuda backend=cuTile blocks=1 kern(CuArray(ah), CuArray(bh), CuArray(ch), D) + @test Array(D) == ref + end + + @testset "* operator (auto acc โ†’ f8)" begin + ah = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) + bh = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) + # f8 ร— f8 accumulates in f16, then downcasts back to f8. + ref = Float32.(Float8_E4M3FN.(Float16.(Float32.(ah) * Float32.(bh)))) + D = CUDA.zeros(Float32, M, M) + @cuda backend=cuTile blocks=1 mma_e4m3_star(CuArray(ah), CuArray(bh), D) + @test Array(D) == ref + end +end +end + +# Block-scaled mma is Blackwell-only (sm_100+), for both FP8 and FP4 operands. +@testset "mma_scaled" begin +if capability(device()) >= v"10" + m, n, k, ks, B = 16, 16, 64, 2, 32 + + # 2D ร— 2D for each f8 operand type (scale = f8e8m0fnu, B = 32). + @testset "2D $T" for (T, kern) in ((Float8_E4M3FN, mma_scaled_2d_e4m3), + (Float8_E5M2, mma_scaled_2d_e5m2)) + xh = T.(Float32.(rand(0:2, m, k)) ./ 2) # values in {0, 0.5, 1} + yh = T.(Float32.(rand(0:2, k, n)) ./ 2) + xs = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, m, ks)) # scales in {1, 2} + ys = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, ks, n)) + ref = mma_scaled_ref(xh, xs, yh, ys, B) + Z = CUDA.zeros(Float32, m, n) + @cuda backend=cuTile blocks=1 kern(CuArray(xh), CuArray(xs), CuArray(yh), CuArray(ys), Z) + @test Array(Z) == ref + end + + @testset "mat-vec" begin + xh = Float8_E4M3FN.(Float32.(rand(0:2, m, k)) ./ 2) + yh = Float8_E4M3FN.(Float32.(rand(0:2, k)) ./ 2) + xs = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, m, ks)) + ys = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, ks)) + ref = (Float32.(xh) .* expand_k(xs, B, Val(:cols))) * + (Float32.(yh) .* repeat(Float32.(ys), inner=(B,))) + Z = CUDA.zeros(Float32, m) + @cuda backend=cuTile blocks=1 mma_scaled_matvec(CuArray(xh), CuArray(xs), CuArray(yh), CuArray(ys), Z) + @test Array(Z) == ref + end + + @testset "batched (trailing batch broadcast)" begin + bt = 2 + xh = Float8_E4M3FN.(Float32.(rand(0:2, m, k)) ./ 2) + xs = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, m, ks)) + yh = Float8_E4M3FN.(Float32.(rand(0:2, k, n, bt)) ./ 2) + ys = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, ks, n, bt)) + ref = stack(mma_scaled_ref(xh, xs, yh[:, :, bi], ys[:, :, bi], B) for bi in 1:bt) + Z = CUDA.zeros(Float32, m, n, bt) + @cuda backend=cuTile blocks=1 mma_scaled_batched(CuArray(xh), CuArray(xs), CuArray(yh), CuArray(ys), Z) + @test Array(Z) == ref + end + + # FP4 operands packed two-per-byte, unpacked + transposed in the kernel + # (see `mma_scaled_fp4`). X packs along K โ†’ bytes (K/2, M); Y packs along + # N โ†’ bytes (N/2, K). + @testset "FP4 (packed, unpack + transpose)" begin + xh = [rand(F4_VALUES) for _ in 1:m, _ in 1:k] # (M, K) + yh = [rand(F4_VALUES) for _ in 1:k, _ in 1:n] # (K, N) + x_bytes = UInt8[f4_nibble(xh[i, 2c-1]) | (f4_nibble(xh[i, 2c]) << 4) + for c in 1:(k รท 2), i in 1:m] # (K/2, M) + y_bytes = UInt8[f4_nibble(yh[kk, 2c-1]) | (f4_nibble(yh[kk, 2c]) << 4) + for c in 1:(n รท 2), kk in 1:k] # (N/2, K) + xs = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, m, ks)) + ys = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, ks, n)) + ref = mma_scaled_ref(xh, xs, yh, ys, B) + Z = CUDA.zeros(Float32, m, n) + @cuda backend=cuTile blocks=1 mma_scaled_fp4(CuArray(x_bytes), CuArray(xs), + CuArray(y_bytes), CuArray(ys), Z) + @test Array(Z) == ref + end +end +end + +end From 38abdadf48641b856d23fa664915afeada17edd2 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 29 May 2026 11:49:55 +0200 Subject: [PATCH 4/6] Mark `muladd_scaled` public --- src/language/operations.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/language/operations.jl b/src/language/operations.jl index e89ac245..8480cc71 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -1394,6 +1394,8 @@ end Block-scaled matrix multiply-accumulate =============================================================================# +public muladd_scaled + """ muladd_scaled(a, a_scale, b, b_scale, acc) -> Tile From 674c5bbfbe96b9ed5a4aca16dded97e1731549e4 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 29 May 2026 17:15:38 +0200 Subject: [PATCH 5/6] Add FP8 `muladd` with `fast_acc` argument; add i8/u8 mma tests --- README.md | 18 +- ext/DLFP8TypesExt.jl | 4 + ext/MicrofloatsExt.jl | 4 + src/compiler/intrinsics/core.jl | 33 ++- src/language/operations.jl | 32 ++- test/device/integration.jl | 37 +++ test/extensions/DLFP8Types.jl | 94 ++++--- test/extensions/Microfloats/codegen.jl | 45 +++- test/extensions/Microfloats/device.jl | 351 ++++++++++++++----------- 9 files changed, 412 insertions(+), 206 deletions(-) diff --git a/README.md b/README.md index 2fad7de2..def5848f 100644 --- a/README.md +++ b/README.md @@ -197,7 +197,23 @@ Tile IR operations. | Operation | Description | |-----------|-------------| | `a * b` | Matrix multiplication: `a @ b` | -| `muladd(a, b, acc)` | Matrix multiply-accumulate: `a * b + acc` | +| `muladd(a, b, acc; fast_acc=false)` | Matrix multiply-accumulate: `a * b + acc` | +| `ct.muladd_scaled(a, a_scale, b, b_scale, acc)` | Block-scaled multiply-accumulate | + +Each operation follows `Base.:*` / `Base.muladd`'s shape rules, with the addition of allowing trailing batch dimensions. + +`fast_acc=true` enables fast accumulation for FP8 inputs, and has an effect only on Hopper (sm_90; silently ignored on other +architectures), and requires Tile IR v13.3+. + +`ct.muladd_scaled` multiplies each operand by a low-precision block scale before the matmul: +each scale element covers a contiguous block of `B = K รท K_s` elements along the K dimension. +Requires Blackwell. The supported operand/scale/accumulator dtypes and block sizes are: + +| Input (`a`/`b`) | Scale | Acc/Output | B | +|-----------------|-------|------------|--------| +| `Float8_E4M3FN`, `Float8_E5M2` | `Float8_E8M0FNU` | `Float32` | 32 | +| `Float4_E2M1FN` | `Float8_E8M0FNU` | `Float32` | 16, 32 | +| `Float4_E2M1FN` | `Float8_E4M3FN` | `Float32` | 16 | ### Higher-Order Functions | Operation | Description | diff --git a/ext/DLFP8TypesExt.jl b/ext/DLFP8TypesExt.jl index 462636b5..841b44c3 100644 --- a/ext/DLFP8TypesExt.jl +++ b/ext/DLFP8TypesExt.jl @@ -18,6 +18,10 @@ end ct.mma_allowed_acc_dtypes(::Type{Float8_E4M3FN}) = (Float16, Float32) ct.mma_allowed_acc_dtypes(::Type{Float8_E5M2}) = (Float16, Float32) +# `fast_acc` (lower-precision MMA accumulation) is an FP8-only throughput hint. +ct.mma_supports_fast_acc(::Type{Float8_E4M3FN}) = true +ct.mma_supports_fast_acc(::Type{Float8_E5M2}) = true + # Float โ†” FP8 scalar constructor overlays (for map/convert dispatch) const FP8Types = (Float8_E4M3FN, Float8_E5M2) const StandardFloats = (Float16, ct.BFloat16, Float32, ct.TFloat32, Float64) diff --git a/ext/MicrofloatsExt.jl b/ext/MicrofloatsExt.jl index 25f131aa..09c7b812 100644 --- a/ext/MicrofloatsExt.jl +++ b/ext/MicrofloatsExt.jl @@ -32,6 +32,10 @@ ct.ftof_rounding_mode(::Type{Float8_E8M0FNU}) = ct.RoundingMode.Zero ct.mma_allowed_acc_dtypes(::Type{Float8_E4M3FN}) = (Float16, Float32) ct.mma_allowed_acc_dtypes(::Type{Float8_E5M2}) = (Float16, Float32) +# `fast_acc` (lower-precision MMA accumulation) is an FP8-only throughput hint. +ct.mma_supports_fast_acc(::Type{Float8_E4M3FN}) = true +ct.mma_supports_fast_acc(::Type{Float8_E5M2}) = true + # Float โ†” microfloat scalar constructor overlays (for map/convert dispatch). # Mirrors DLFP8TypesExt: route to `Intrinsics.ftof` so kernel-side conversions # lower to the FToFOp Tile IR intrinsic instead of Microfloats' Float32-fallback diff --git a/src/compiler/intrinsics/core.jl b/src/compiler/intrinsics/core.jl index 940c08f4..0aeddcbc 100644 --- a/src/compiler/intrinsics/core.jl +++ b/src/compiler/intrinsics/core.jl @@ -415,6 +415,17 @@ mma_allowed_acc_dtypes(::Type{TFloat32}) = (Float32,) mma_allowed_acc_dtypes(::Type{Float64}) = (Float64,) mma_allowed_acc_dtypes(@nospecialize(::Type)) = nothing +""" + mma_supports_fast_acc(input_T) -> Bool + +Whether `mma`'s `fast_acc` hint (lower-precision accumulation for throughput) +is valid for operand element type `input_T`. Only the FP8 dtypes qualify; +extensions (Microfloats/DLFP8Types) register them by overloading this, like +[`mma_allowed_acc_dtypes`](@ref). Mirrors cuTile Python's +`use_fast_acc is only supported for fp8 input dtypes` check. +""" +mma_supports_fast_acc(@nospecialize(::Type)) = false + # First (preferred) accumulator dtype for a given input dtype. Used by # matmul to pick `acc = zeros(first_allowed_acc(T), โ€ฆ)` so that the input # dtype constraint and tileiras's acc-dtype constraint stay consistent @@ -425,7 +436,7 @@ first_allowed_acc_dtype(::Type{T}) where {T} = end """ - Intrinsics.mma(a::Tile, b::Tile, acc::Tile) -> typeof(acc) + Intrinsics.mma(a::Tile, b::Tile, acc::Tile, fast_acc::Bool=false) -> typeof(acc) Matrix-multiply-accumulate computing `a*b + acc`. Dispatches at codegen based on element types: @@ -436,15 +447,23 @@ based on element types: - `i8` `a`/`b` with `i32` `acc` lower to `cuda_tile.mmai`. Per-input signedness is derived from the Julia type (`Int8` โ†’ signed, `UInt8` โ†’ unsigned); `acc` and the result are always signed `i32`. + +`fast_acc` enables fast accumulation (trading accumulator precision for +throughput). It is only valid for FP8 inputs (see +[`mma_supports_fast_acc`](@ref)) and requires Tile IR v13.3+. """ -@intrinsic mma(a::Tile, b::Tile, acc::Tile) -tfunc(๐•ƒ, ::typeof(Intrinsics.mma), @nospecialize(a), @nospecialize(b), @nospecialize(acc)) = CC.widenconst(acc) +@intrinsic mma(a::Tile, b::Tile, acc::Tile, fast_acc::Bool=false) +tfunc(๐•ƒ, ::typeof(Intrinsics.mma), @nospecialize(a), @nospecialize(b), @nospecialize(acc), + @nospecialize(rest...)) = CC.widenconst(acc) function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma), args) cb = ctx.cb lhs = emit_value!(ctx, args[1]) rhs = emit_value!(ctx, args[2]) acc = emit_value!(ctx, args[3]) + fast_acc = length(args) >= 4 ? + (@something get_constant(ctx, args[4]) throw(IRError("mma: fast_acc must be a compile-time constant"))) : + false (lhs === nothing || rhs === nothing || acc === nothing) && throw(IRError("Cannot resolve operands for mma()")) @@ -467,9 +486,15 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma), args) acc_elem in allowed_acc || throw(IRError("mma: acc dtype $acc_elem is not allowed for input dtype " * "$lhs_elem; tileiras requires acc โˆˆ $allowed_acc")) - encode_MmaFOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v) + # fast_acc is an FP8-only throughput hint; reject it elsewhere with a + # clear error rather than emitting IR tileiras would reject. + fast_acc && !Base.invokelatest(mma_supports_fast_acc, lhs_elem) && + throw(IRError("mma: fast_acc is only supported for fp8 input dtypes " * + "(f8e4m3fn, f8e5m2), got $lhs_elem")) + encode_MmaFOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v; fast_acc) elseif lhs_elem <: Union{Int8, UInt8} && rhs_elem <: Union{Int8, UInt8} && acc_elem === Int32 + fast_acc && throw(IRError("mma: fast_acc is not supported for integer (mmai) inputs")) s_lhs = lhs_elem <: Signed ? Signedness.Signed : Signedness.Unsigned s_rhs = rhs_elem <: Signed ? Signedness.Signed : Signedness.Unsigned encode_MmaIOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v; diff --git a/src/language/operations.jl b/src/language/operations.jl index 8480cc71..3a729b9c 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -1296,7 +1296,7 @@ end =============================================================================# """ - muladd(a::Tile, b::Tile, acc::Tile) -> Tile + muladd(a::Tile, b::Tile, acc::Tile; fast_acc::Bool=false) -> Tile Matrix multiply-accumulate `a * b + acc` over tiles, lowering to `cuda_tile.mmaf` (float) or `cuda_tile.mmai` (`i8 ร— i8 โ†’ i32`). @@ -1306,45 +1306,49 @@ promoted (vec-mat / mat-vec) and any trailing dimensions (โ‰ฅ3-D) are treated as broadcast batch dims, lifting `Base.muladd`'s shape rules to tiles. `acc` carries the result dtype, which must be one tileiras allows for the input dtype (f16/f32 for f16 and f8; f32 for bf16/tf32; f64 for f64; i32 for i8). + +`fast_acc` enables fast accumulation (lower accumulator precision for +throughput); it is valid only for FP8 inputs and requires Tile IR v13.3+. """ -@inline function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}) where {T1, T2, T3, SA, SB, SC} +@inline function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}; + fast_acc::Bool=false) where {T1, T2, T3, SA, SB, SC} # SA, SB, SC type parameters avoid ambiguity with the scalar `muladd` # methods during codegen. - _muladd(a, b, acc, Val(ndims(a)), Val(ndims(b))) + _muladd(a, b, acc, Val(ndims(a)), Val(ndims(b)), fast_acc) end # 2D ร— 2D: MmaFOp with swapped operands for row-major Tile IR # Julia (M,K)*(K,N) โ†’ TileIR (K,M)*(N,K) โ†’ mmaf(b,a,acc) โ†’ TileIR (N,M) โ†’ Julia (M,N) -@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{2}) - Intrinsics.mma(b, a, acc) +@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{2}, fast_acc::Bool) + Intrinsics.mma(b, a, acc, fast_acc) end # Vec-mat (1D ร— 2D): reshape (M,) โ†’ (M, 1), MmaFOp, acc is already (M, N) -@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{1}, ::Val{2}) +@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{1}, ::Val{2}, fast_acc::Bool) a2d = reshape(a, (size(a, 1), 1)) - _muladd(a2d, b, acc, Val(2), Val(2)) + _muladd(a2d, b, acc, Val(2), Val(2), fast_acc) end # Mat-vec (2D ร— 1D): reshape b (K,) โ†’ (K, 1), acc (M,) โ†’ (M, 1), MmaFOp, squeeze back -@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{1}) +@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{1}, fast_acc::Bool) M, K = size(a, 1), size(b, 1) b2d = reshape(b, (K, 1)) acc2d = reshape(acc, (M, 1)) - result = _muladd(a, b2d, acc2d, Val(2), Val(2)) + result = _muladd(a, b2d, acc2d, Val(2), Val(2), fast_acc) reshape(result, (M,)) end # Vec-vec (1D ร— 1D): not supported -@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1}) +@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1}, ::Bool) return :(throw(ArgumentError("Vector-vector multiply-accumulate is not supported."))) end # Batched mat-vec / vec-mat (โ‰ฅ3D ร— 1D or 1D ร— โ‰ฅ3D): not supported, unsqueeze manually -@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB} +@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}, ::Bool) where {NB} NB >= 3 || return :(throw(ArgumentError("unreachable"))) return :(throw(ArgumentError("Batched vec-mat is not supported. Reshape the 1D operand to 2D first."))) end -@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA} +@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}, ::Bool) where {NA} NA >= 3 || return :(throw(ArgumentError("unreachable"))) return :(throw(ArgumentError("Batched mat-vec is not supported. Reshape the 1D operand to 2D first."))) end @@ -1357,7 +1361,7 @@ end # 3. MmaFOp with swapped operands: mmaf(b, a, acc) # 4. Unflatten batch dims via reshape @generated function _muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}, - ::Val{NA}, ::Val{NB}) where {T1, T2, T3, SA, SB, SC, NA, NB} + ::Val{NA}, ::Val{NB}, fast_acc::Bool) where {T1, T2, T3, SA, SB, SC, NA, NB} sa = Tuple(SA.parameters) sb = Tuple(SB.parameters) @@ -1384,7 +1388,7 @@ end b_3d = reshape(b_bc, $((K, N, B_flat))) acc_3d = reshape(acc_bc, $((M, N, B_flat))) # MmaFOp with swapped operands for row-major convention - result_3d = Intrinsics.mma(b_3d, a_3d, acc_3d) + result_3d = Intrinsics.mma(b_3d, a_3d, acc_3d, fast_acc) # Unflatten batch dims reshape(result_3d, $((M, N, batch_shape...))) end diff --git a/test/device/integration.jl b/test/device/integration.jl index 01a7d85c..fcd50656 100644 --- a/test/device/integration.jl +++ b/test/device/integration.jl @@ -33,3 +33,40 @@ using CUDA @test c_cpu โ‰ˆ c_ref end + +@testset "i8/u8 matmul (mmai)" begin + # i8/u8 ร— i8/u8 โ†’ i32 lowers to `cuda_tile.mmai`. The accumulator must be + # Int32 (the `*` operator would pick the input dtype as acc and fail), so we + # use `muladd` with an explicit i32 acc. Per-operand signedness is derived + # from the Julia element type, so we feed values that differ under signed vs + # unsigned interpretation (negatives, and magnitudes > 127) to confirm each + # operand is interpreted correctly. K = 16, |product| โ‰ค 255ยฒ ยท 16 โ‰ˆ 1.0e6, + # well within Int32, so the result is exact. + function mmai(a::ct.TileArray{T1,2}, b::ct.TileArray{T2,2}, c::ct.TileArray{Int32,2}) where {T1,T2} + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + tc = muladd(ta, tb, zeros(Int32, (16, 16))) + ct.store(c, (1, 1), tc) + return + end + + M = K = N = 16 + @testset "signed ร— signed" begin + a = rand(Int8(-128):Int8(127), M, K); b = rand(Int8(-128):Int8(127), K, N) + c = CUDA.zeros(Int32, M, N) + @cuda backend=cuTile blocks=1 mmai(CuArray(a), CuArray(b), c) + @test Array(c) == Int32.(a) * Int32.(b) + end + @testset "unsigned ร— unsigned" begin + a = rand(UInt8(0):UInt8(255), M, K); b = rand(UInt8(0):UInt8(255), K, N) + c = CUDA.zeros(Int32, M, N) + @cuda backend=cuTile blocks=1 mmai(CuArray(a), CuArray(b), c) + @test Array(c) == Int32.(a) * Int32.(b) + end + @testset "unsigned ร— signed" begin + a = rand(UInt8(0):UInt8(255), M, K); b = rand(Int8(-128):Int8(127), K, N) + c = CUDA.zeros(Int32, M, N) + @cuda backend=cuTile blocks=1 mmai(CuArray(a), CuArray(b), c) + @test Array(c) == Int32.(a) * Int32.(b) + end +end diff --git a/test/extensions/DLFP8Types.jl b/test/extensions/DLFP8Types.jl index 2f0fb0a2..f18371e4 100644 --- a/test/extensions/DLFP8Types.jl +++ b/test/extensions/DLFP8Types.jl @@ -45,14 +45,27 @@ end end end +# `fast_acc=true` is an FP8-only hint; it still lowers to `mmaf` (13.3+). +@test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}; bytecode_version=v"13.3") do a, b, c + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + @check "mmaf" + ct.store(c, (1, 1), muladd(ta, tb, zeros(Float32, (16, 16)); fast_acc=true)) + return + end end -# FP8 (e4m3/e5m2) conversions and matmul need Hopper (sm_90+). -@testset "execution" begin -if capability(device()) >= v"9" +end -# Round-trip Float32 โ†’ FP8 โ†’ Float32 on values exactly representable in -# the target FP8 type โ€” result must match input bit-for-bit. +# Execution kernels are plain top-level functions, each defined next to the +# test that exercises it. Kernels parametric on accumulator dtype must stay at +# top level โ€” defining them inside a testset scope boxes them into closures. + +# Round-trip Float32 โ†’ FP8 โ†’ Float32 on values exactly representable in the +# target FP8 type โ€” result must match input bit-for-bit. function rt_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) pid = ct.bid(1) tile = ct.load(a, pid, (16,)) @@ -65,19 +78,8 @@ function rt_e5m2(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float8_E5M2}, tile))) return end - -representable = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 8.0, - 16.0, 32.0, 64.0, 128.0, 256.0, -1.0, -2.0, -0.5] -let a = CuArray(representable), b = CUDA.zeros(Float32, length(representable)) - @cuda backend=cuTile blocks=1 rt_e4m3(a, b) - @test Array(b) == representable - @cuda backend=cuTile blocks=1 rt_e5m2(a, b) - @test Array(b) == representable -end - # FMA in FP8: load Float32, convert to FP8, multiply-add in FP8, convert back. -# Uses inputs whose products and sums also stay representable, so the result -# is exact. +# Inputs whose products and sums also stay representable, so the result is exact. function fma_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, c::ct.TileArray{Float32,1}, d::ct.TileArray{Float32,1}) pid = ct.bid(1) @@ -87,6 +89,33 @@ function fma_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, ct.store(d, pid, convert(ct.Tile{Float32}, muladd.(ta, tb, tc))) return end +# Non-scaled FP8 matmul with both allowed accumulator dtypes (f16 and f32). +function mma_dl_fp8(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, + C::ct.TileArray{Tacc,2}, D::ct.TileArray{Float32,2}) where {Tacc<:Union{Float16,Float32}} + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) + return +end +function mma_dl_fast(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, + C::ct.TileArray{Float32,2}, D::ct.TileArray{Float32,2}) + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c; fast_acc=true))) + return +end + +# FP8 (e4m3/e5m2) conversions and matmul need Hopper (sm_90+). +@testset "execution" begin +if capability(device()) >= v"9" + +representable = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 8.0, + 16.0, 32.0, 64.0, 128.0, 256.0, -1.0, -2.0, -0.5] +let a = CuArray(representable), b = CUDA.zeros(Float32, length(representable)) + @cuda backend=cuTile blocks=1 rt_e4m3(a, b) + @test Array(b) == representable + @cuda backend=cuTile blocks=1 rt_e5m2(a, b) + @test Array(b) == representable +end + let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, -2.0, 1.0, 0.5, 4.0], bv = Float32[2.0, 1.0, 4.0, 0.5, 2.0, 3.0, 2.0, 4.0, 1.0, 2.0, 1.0, 0.5, 2.0, 1.0, 2.0, 1.0], cv = Float32[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0] @@ -96,29 +125,30 @@ let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, - @test Array(d) == av .* bv .+ cv end -# Non-scaled FP8 matmul with both allowed accumulator dtypes (f16 and f32). -function mma_dl_f32(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, - C::ct.TileArray{Float32,2}, D::ct.TileArray{Float32,2}) - a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) - ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) - return -end -function mma_dl_f16(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, - C::ct.TileArray{Float16,2}, D::ct.TileArray{Float32,2}) - a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) - ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) - return -end -@testset "mma โ†’ $Tacc acc" for (Tacc, kern) in ((Float32, mma_dl_f32), (Float16, mma_dl_f16)) +@testset "mma โ†’ $Tacc acc" for Tacc in (Float32, Float16) M = 16 ah = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) bh = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) ch = Tacc.(Float32.(rand(0:2, M, M))) ref = Float32.(ah) * Float32.(bh) .+ Float32.(ch) D = CUDA.zeros(Float32, M, M) - @cuda backend=cuTile blocks=1 kern(CuArray(ah), CuArray(bh), CuArray(ch), D) + @cuda backend=cuTile blocks=1 mma_dl_fp8(CuArray(ah), CuArray(bh), CuArray(ch), D) @test Array(D) == ref end +# fast_acc only has an effect on Hopper (sm_90); ignored elsewhere. So off +# Hopper we assert the exact result (the flag must ride through without +# perturbing the output); on Hopper we make no numeric claim. +@testset "mma fast_acc (exact off Hopper)" begin + M = 16 + ah = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) + bh = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) + ch = Float32.(rand(0:2, M, M)) + ref = Float32.(ah) * Float32.(bh) .+ ch + D = CUDA.zeros(Float32, M, M) + @cuda backend=cuTile blocks=1 mma_dl_fast(CuArray(ah), CuArray(bh), CuArray(ch), D) + @test (Array(D) == ref) || (v"9" <= capability(device()) < v"10") +end + end end diff --git a/test/extensions/Microfloats/codegen.jl b/test/extensions/Microfloats/codegen.jl index ce9793cf..efbdf698 100644 --- a/test/extensions/Microfloats/codegen.jl +++ b/test/extensions/Microfloats/codegen.jl @@ -145,8 +145,47 @@ end end end +@testset "fast_acc" begin + # `fast_acc=true` on f8 operands still lowers to `mmaf` (the hint rides on + # the op as a flag); requires bytecode 13.3. + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}; + bytecode_version=v"13.3") do a, b, c + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + @check "mmaf" + ct.store(c, (1, 1), muladd(ta, tb, zeros(Float32, (16, 16)); fast_acc=true)) + return + end + end + + # `fast_acc` is an FP8-only hint: requesting it for f16 inputs is rejected. + @test_throws "only supported for fp8" code_tiled( + Tuple{ct.TileArray{Float16,2,spec2d}, ct.TileArray{Float16,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}; bytecode_version=v"13.3") do a, b, c + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + ct.store(c, (1, 1), muladd(ta, tb, zeros(Float32, (16, 16)); fast_acc=true)) + return + end + + # `fast_acc` requires bytecode 13.3 โ€” rejected at 13.2 with a clear error. + let kernel = (a, b, c) -> begin + ta = ct.load(a, (1, 1), (16, 16)) + tb = ct.load(b, (1, 1), (16, 16)) + ct.store(c, (1, 1), muladd(ta, tb, zeros(Float32, (16, 16)); fast_acc=true)) + return + end + @test_throws "13.3" code_tiled(devnull, kernel, + Tuple{ct.TileArray{Float8_E4M3FN,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, + ct.TileArray{Float32,2,spec2d}}; bytecode_version=v"13.2") + end +end + @testset "mma_scaled" begin - # f8 operands (M,K)/(K,N) with f8e8m0fnu block scales (M,K_s)/(K_s,N) + # MXFP8: f8 operands (M,K)/(K,N) with f8e8m0fnu block scales (M,K_s)/(K_s,N) # accumulate into f32. Block size B = K รท K_s = 64 รท 2 = 32. @test @filecheck begin @check_label "entry" @@ -165,8 +204,8 @@ end end end - # f4e2m1fn operands are accepted by `mmaf_scaled` too (here with f8e4m3fn - # scales, B = 16). Operands enter as unpacked FP4 tiles. + # NVFP4: f4e2m1fn operands with f8e4m3fn scales, B = 16. Operands enter as + # unpacked FP4 tiles; `mmaf_scaled` accepts them too. @test @filecheck begin @check_label "entry" code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float8_E4M3FN,2,spec2d}, diff --git a/test/extensions/Microfloats/device.jl b/test/extensions/Microfloats/device.jl index a8327a9d..97ede9a5 100644 --- a/test/extensions/Microfloats/device.jl +++ b/test/extensions/Microfloats/device.jl @@ -1,7 +1,9 @@ using CUDA using Microfloats: Float8_E4M3FN, Float8_E5M2, Float8_E8M0FNU, Float4_E2M1FN -# Kernels are plain top-level functions (not closures). +# Kernels are plain top-level functions (not closures), each defined next to the +# testset that exercises it. Kernels parametric on operand dtype must stay at +# top level โ€” defining them inside a testset scope boxes them into closures. # ftof round-trips function rt_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) @@ -16,6 +18,18 @@ function rt_e5m2(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float8_E5M2}, tile))) return end +function rt_e8m0(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float8_E8M0FNU}, tile))) + return +end +function rt_f4(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}) + pid = ct.bid(1) + tile = ct.load(a, pid, (16,)) + ct.store(b, pid, convert(ct.Tile{Float32}, convert(ct.Tile{Float4_E2M1FN}, tile))) + return +end function fma_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, c::ct.TileArray{Float32,1}, d::ct.TileArray{Float32,1}) pid = ct.bid(1) @@ -26,6 +40,52 @@ function fma_e4m3(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, return end +# Standalone f32 โ†’ microfloat โ†’ f32 conversion round-trips exactly for every +# microfloat type on representable inputs. FP8 (e4m3/e5m2) needs Hopper (sm_90+); +# E8M0FNU and Float4_E2M1FN need Blackwell (sm_100+). E8M0 is exponent-only, so +# its representable values are exact powers of two. +@testset "ftof" begin +if capability(device()) >= v"9" + # Round-trip Float32 โ†’ microfloat โ†’ Float32 on exactly-representable values. + representable8 = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 8.0, + 16.0, 32.0, 64.0, 128.0, 256.0, -1.0, -2.0, -0.5] + let a = CuArray(representable8), b = CUDA.zeros(Float32, length(representable8)) + @cuda backend=cuTile blocks=1 rt_e4m3(a, b) + @test Array(b) == representable8 + @cuda backend=cuTile blocks=1 rt_e5m2(a, b) + @test Array(b) == representable8 + end + + # FMA in FP8: load f32, convert to FP8, multiply-add in FP8, convert back. + # Inputs whose products and sums stay representable, so the result is exact. + let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, -2.0, 1.0, 0.5, 4.0], + bv = Float32[2.0, 1.0, 4.0, 0.5, 2.0, 3.0, 2.0, 4.0, 1.0, 2.0, 1.0, 0.5, 2.0, 1.0, 2.0, 1.0], + cv = Float32[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0] + a, b, c = CuArray(av), CuArray(bv), CuArray(cv) + d = CUDA.zeros(Float32, length(av)) + @cuda backend=cuTile blocks=1 fma_e4m3(a, b, c, d) + @test Array(d) == av .* bv .+ cv + end +end +if capability(device()) >= v"10" + # E8M0 round-trip: exponent-only, so representable values are powers of two. + representable_e8m0 = Float32(2) .^ Float32[-8, -6, -4, -2, -1, 0, 1, 2, + 3, 4, 5, 6, 7, 8, 10, 16] + let a = CuArray(representable_e8m0), b = CUDA.zeros(Float32, length(representable_e8m0)) + @cuda backend=cuTile blocks=1 rt_e8m0(a, b) + @test Array(b) == representable_e8m0 + end + + # F4 round-trip via the standalone conversion path (no pack/reinterpret). + representable4 = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, 0.5] + let a = CuArray(representable4), b = CUDA.zeros(Float32, length(representable4)) + @cuda backend=cuTile blocks=1 rt_f4(a, b) + @test Array(b) == representable4 + end +end +end + # reinterpret (pack / unpack) function rt_pack(a::ct.TileArray{UInt8,1}, b::ct.TileArray{UInt8,1}) pid = ct.bid(1) @@ -57,27 +117,44 @@ function rt_fp4_2d(src::ct.TileArray{Float32,2}, dst::ct.TileArray{Float32,2}) return end -# non-scaled FP8 matmul (f8 ร— f8 with f16 or f32 accumulator) -function mma_e4m3_f32(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, - C::ct.TileArray{Float32,2}, D::ct.TileArray{Float32,2}) - a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) - ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) - return -end -function mma_e4m3_f16(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, - C::ct.TileArray{Float16,2}, D::ct.TileArray{Float32,2}) - a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) - ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) - return +# Float4_E2M1FN requires Blackwell (sm_100+). +@testset "reinterpret" begin +if capability(device()) >= v"10" + # Pure pack/unpack round-trip: UInt8 โ†’ FP4 โ†’ UInt8 must be a no-op. + let av = UInt8[0x00, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde], + b = CUDA.zeros(UInt8, 8) + a = CuArray(av) + @cuda backend=cuTile blocks=1 rt_pack(a, b) + @test Array(b) == av + end + + # Value round-trip through FP4 stored as UInt8 (all inputs representable). + let representable4 = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, 0.5] + src = CuArray(representable4) + packed = CUDA.zeros(UInt8, 8) + out = CUDA.zeros(Float32, 16) + @cuda backend=cuTile blocks=1 pack_fp4(src, packed) + @cuda backend=cuTile blocks=1 unpack_fp4(packed, out) + @test Array(out) == representable4 + end + + # N-D reinterpret: whole-tile `reinterpret` flattens, so it works on any + # rank. (8,2) FP4 โ†” (4,2) UInt8 โ€” the leading dim absorbs the 2ร— / ยฝ. + let m = reshape(Float32[0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -1.0, + -2.0, -3.0, -4.0, -6.0, 0.5, 1.0, 2.0, 3.0], 8, 2) + src = CuArray(m) + out = CUDA.zeros(Float32, 8, 2) + @cuda backend=cuTile blocks=1 rt_fp4_2d(src, out) + @test Array(out) == m + end end -function mma_e5m2_f32(A::ct.TileArray{Float8_E5M2,2}, B::ct.TileArray{Float8_E5M2,2}, - C::ct.TileArray{Float32,2}, D::ct.TileArray{Float32,2}) - a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) - ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) - return end -function mma_e5m2_f16(A::ct.TileArray{Float8_E5M2,2}, B::ct.TileArray{Float8_E5M2,2}, - C::ct.TileArray{Float16,2}, D::ct.TileArray{Float32,2}) + +# non-scaled FP8 matmul (f8 ร— f8 with f16 or f32 accumulator) +function mma_fp8(A::ct.TileArray{T,2}, B::ct.TileArray{T,2}, + C::ct.TileArray{Tacc,2}, D::ct.TileArray{Float32,2} + ) where {T<:Union{Float8_E4M3FN,Float8_E5M2},Tacc<:Union{Float16,Float32}} a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c))) return @@ -88,19 +165,65 @@ function mma_e4m3_star(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_ ct.store(D, (1, 1), convert(ct.Tile{Float32}, a * b)) return end - -# block-scaled mma (B = K รท K_s; here K = 64, K_s = 2 โ†’ B = 32) -function mma_scaled_2d_e4m3(X::ct.TileArray{Float8_E4M3FN,2}, XS::ct.TileArray{Float8_E8M0FNU,2}, - Y::ct.TileArray{Float8_E4M3FN,2}, YS::ct.TileArray{Float8_E8M0FNU,2}, - Z::ct.TileArray{Float32,2}) - x = ct.load(X, (1, 1), (16, 64)); xs = ct.load(XS, (1, 1), (16, 2)) - y = ct.load(Y, (1, 1), (64, 16)); ys = ct.load(YS, (1, 1), (2, 16)) - ct.store(Z, (1, 1), ct.muladd_scaled(x, xs, y, ys, zeros(Float32, (16, 16)))) +function mma_e4m3_fast(A::ct.TileArray{Float8_E4M3FN,2}, B::ct.TileArray{Float8_E4M3FN,2}, + C::ct.TileArray{Float32,2}, D::ct.TileArray{Float32,2}) + a = ct.load(A, (1, 1), (16, 16)); b = ct.load(B, (1, 1), (16, 16)); c = ct.load(C, (1, 1), (16, 16)) + ct.store(D, (1, 1), convert(ct.Tile{Float32}, muladd(a, b, c; fast_acc=true))) return end -function mma_scaled_2d_e5m2(X::ct.TileArray{Float8_E5M2,2}, XS::ct.TileArray{Float8_E8M0FNU,2}, - Y::ct.TileArray{Float8_E5M2,2}, YS::ct.TileArray{Float8_E8M0FNU,2}, - Z::ct.TileArray{Float32,2}) + +# Non-scaled FP8 matmul needs only Hopper (sm_90+); block-scaled mma (below) is +# the Blackwell-only variant. +@testset "mma" begin +if capability(device()) >= v"9" + M, K, N = 16, 16, 16 + # f8 operands with both allowed accumulator dtypes (f16 and f32). Inputs are + # in {0, 0.5, 1} and K = 16, so every partial sum is exactly representable + # in f16 too โ€” the f16-acc result matches the f32 reference bit-for-bit. + @testset "$T ร— $T โ†’ $Tacc acc" for (T, Tacc) in ( + (Float8_E4M3FN, Float32), (Float8_E4M3FN, Float16), + (Float8_E5M2, Float32), (Float8_E5M2, Float16)) + ah = T.(Float32.(rand(0:2, M, K)) ./ 2) + bh = T.(Float32.(rand(0:2, K, N)) ./ 2) + ch = Tacc.(Float32.(rand(0:2, M, N))) + ref = Float32.(ah) * Float32.(bh) .+ Float32.(ch) + D = CUDA.zeros(Float32, M, N) + @cuda backend=cuTile blocks=1 mma_fp8(CuArray(ah), CuArray(bh), CuArray(ch), D) + @test Array(D) == ref + end + + @testset "* operator (auto acc โ†’ f8)" begin + ah = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) + bh = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) + # f8 ร— f8 accumulates in f16, then downcasts back to f8. + ref = Float32.(Float8_E4M3FN.(Float16.(Float32.(ah) * Float32.(bh)))) + D = CUDA.zeros(Float32, M, M) + @cuda backend=cuTile blocks=1 mma_e4m3_star(CuArray(ah), CuArray(bh), D) + @test Array(D) == ref + end + + # fast_acc trades accumulator precision for throughput, but only has an + # effect on Hopper (sm_90); it is silently ignored on every other arch. So + # off Hopper we assert the exact result โ€” the flag must ride through tileiras + # and ptxas without perturbing the output. On Hopper we make no numeric + # claim: fast accumulation may legitimately diverge there, and we have no + # fast-accum reference to compare against. + @testset "fast_acc (exact off Hopper)" begin + ah = Float8_E4M3FN.(Float32.(rand(0:2, M, K)) ./ 2) + bh = Float8_E4M3FN.(Float32.(rand(0:2, K, N)) ./ 2) + ch = Float32.(rand(0:2, M, N)) + ref = Float32.(ah) * Float32.(bh) .+ ch + D = CUDA.zeros(Float32, M, N) + @cuda backend=cuTile blocks=1 mma_e4m3_fast(CuArray(ah), CuArray(bh), CuArray(ch), D) + @test (Array(D) == ref) || (v"9" <= capability(device()) < v"10") + end +end +end + +# block-scaled mma (B = K รท K_s; here K = 64, K_s = 2 โ†’ B = 32) +function mma_scaled_2d(X::ct.TileArray{T,2}, XS::ct.TileArray{Float8_E8M0FNU,2}, + Y::ct.TileArray{T,2}, YS::ct.TileArray{Float8_E8M0FNU,2}, + Z::ct.TileArray{Float32,2}) where {T<:Union{Float8_E4M3FN,Float8_E5M2}} x = ct.load(X, (1, 1), (16, 64)); xs = ct.load(XS, (1, 1), (16, 2)) y = ct.load(Y, (1, 1), (64, 16)); ys = ct.load(YS, (1, 1), (2, 16)) ct.store(Z, (1, 1), ct.muladd_scaled(x, xs, y, ys, zeros(Float32, (16, 16)))) @@ -129,10 +252,11 @@ end # into the (M,K) / (K,N) orientation `muladd_scaled` expects: # X: bytes (K/2, M) โ†’ reinterpret (K, M) โ†’ transpose (M, K) # Y: bytes (N/2, K) โ†’ reinterpret (N, K) โ†’ transpose (K, N) -# Scales are f8 (one per byte) and load directly. -function mma_scaled_fp4(X::ct.TileArray{UInt8,2}, XS::ct.TileArray{Float8_E8M0FNU,2}, - Y::ct.TileArray{UInt8,2}, YS::ct.TileArray{Float8_E8M0FNU,2}, - Z::ct.TileArray{Float32,2}) +# MXFP4 takes an f8e8m0fnu scale (B = 32 โ†’ K_s = 2); NVFP4 an f8e4m3fn scale +# (B = 16 โ†’ K_s = 4). Operand unpacking is identical โ€” only the scale differs. +function mma_scaled_mxfp4(X::ct.TileArray{UInt8,2}, XS::ct.TileArray{Float8_E8M0FNU,2}, + Y::ct.TileArray{UInt8,2}, YS::ct.TileArray{Float8_E8M0FNU,2}, + Z::ct.TileArray{Float32,2}) x = permutedims(reinterpret(Float4_E2M1FN, ct.load(X, (1, 1), (32, 16))), (2, 1)) # (M,K) y = permutedims(reinterpret(Float4_E2M1FN, ct.load(Y, (1, 1), (8, 64))), (2, 1)) # (K,N) xs = ct.load(XS, (1, 1), (16, 2)) @@ -140,6 +264,16 @@ function mma_scaled_fp4(X::ct.TileArray{UInt8,2}, XS::ct.TileArray{Float8_E8M0FN ct.store(Z, (1, 1), ct.muladd_scaled(x, xs, y, ys, zeros(Float32, (16, 16)))) return end +function mma_scaled_nvfp4(X::ct.TileArray{UInt8,2}, XS::ct.TileArray{Float8_E4M3FN,2}, + Y::ct.TileArray{UInt8,2}, YS::ct.TileArray{Float8_E4M3FN,2}, + Z::ct.TileArray{Float32,2}) + x = permutedims(reinterpret(Float4_E2M1FN, ct.load(X, (1, 1), (32, 16))), (2, 1)) # (M,K) + y = permutedims(reinterpret(Float4_E2M1FN, ct.load(Y, (1, 1), (8, 64))), (2, 1)) # (K,N) + xs = ct.load(XS, (1, 1), (16, 4)) + ys = ct.load(YS, (1, 1), (4, 16)) + ct.store(Z, (1, 1), ct.muladd_scaled(x, xs, y, ys, zeros(Float32, (16, 16)))) + return +end # Expand a block-scale tile to per-element scales along the K dimension: each # scale entry covers B contiguous K positions (matches Tile IR's broadcast). @@ -155,127 +289,33 @@ end # Low 4 bits of an FP4 value's byte storage (its E2M1 nibble). f4_nibble(v) = reinterpret(UInt8, Float4_E2M1FN(v)) & 0x0f +# Pack an (R, C) FP4 matrix two-per-byte along the contiguous (second) axis โ†’ +# (C/2, R) bytes: low nibble = even column, high nibble = odd column. Matches +# X (M,K)โ†’(K/2,M) and Y (K,N)โ†’(N/2,K) in the kernels above. +pack_f4_bytes(vals) = UInt8[f4_nibble(vals[r, 2c-1]) | (f4_nibble(vals[r, 2c]) << 4) + for c in 1:(size(vals, 2) รท 2), r in 1:size(vals, 1)] + # All FP4-representable magnitudes, for sampling exact inputs. const F4_VALUES = Float32[0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.5, -1, -1.5, -2, -3, -4, -6] -@testset "Microfloats device" begin - -# FP8 (e4m3/e5m2) conversions need Hopper (sm_90+). E8M0FNU/Float4_E2M1FN have -# no standalone f32 conversion path on any arch (codegen lowers fine, but -# tileiras rejects a standalone f32 โ†” E8M0/F4 conversion), so they are only -# meaningful as packed/scaled operands โ€” exercised by the testsets below. -@testset "ftof" begin -if capability(device()) >= v"9" - # Round-trip Float32 โ†’ microfloat โ†’ Float32 on exactly-representable values. - representable8 = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 8.0, - 16.0, 32.0, 64.0, 128.0, 256.0, -1.0, -2.0, -0.5] - let a = CuArray(representable8), b = CUDA.zeros(Float32, length(representable8)) - @cuda backend=cuTile blocks=1 rt_e4m3(a, b) - @test Array(b) == representable8 - @cuda backend=cuTile blocks=1 rt_e5m2(a, b) - @test Array(b) == representable8 - end - - # FMA in FP8: load f32, convert to FP8, multiply-add in FP8, convert back. - # Inputs whose products and sums stay representable, so the result is exact. - let av = Float32[1.0, 2.0, 0.5, 4.0, 1.5, 2.0, -1.0, -0.5, 3.0, 0.5, 1.0, 2.0, -2.0, 1.0, 0.5, 4.0], - bv = Float32[2.0, 1.0, 4.0, 0.5, 2.0, 3.0, 2.0, 4.0, 1.0, 2.0, 1.0, 0.5, 2.0, 1.0, 2.0, 1.0], - cv = Float32[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0] - a, b, c = CuArray(av), CuArray(bv), CuArray(cv) - d = CUDA.zeros(Float32, length(av)) - @cuda backend=cuTile blocks=1 fma_e4m3(a, b, c, d) - @test Array(d) == av .* bv .+ cv - end -end -end - -# Float4_E2M1FN requires Blackwell (sm_100+). -@testset "reinterpret" begin -if capability(device()) >= v"10" - # Pure pack/unpack round-trip: UInt8 โ†’ FP4 โ†’ UInt8 must be a no-op. - let av = UInt8[0x00, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde], - b = CUDA.zeros(UInt8, 8) - a = CuArray(av) - @cuda backend=cuTile blocks=1 rt_pack(a, b) - @test Array(b) == av - end - - # Value round-trip through FP4 stored as UInt8 (all inputs representable). - let representable4 = Float32[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, 0.5] - src = CuArray(representable4) - packed = CUDA.zeros(UInt8, 8) - out = CUDA.zeros(Float32, 16) - @cuda backend=cuTile blocks=1 pack_fp4(src, packed) - @cuda backend=cuTile blocks=1 unpack_fp4(packed, out) - @test Array(out) == representable4 - end - - # N-D reinterpret: whole-tile `reinterpret` flattens, so it works on any - # rank. (8,2) FP4 โ†” (4,2) UInt8 โ€” the leading dim absorbs the 2ร— / ยฝ. - let m = reshape(Float32[0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -1.0, - -2.0, -3.0, -4.0, -6.0, 0.5, 1.0, 2.0, 3.0], 8, 2) - src = CuArray(m) - out = CUDA.zeros(Float32, 8, 2) - @cuda backend=cuTile blocks=1 rt_fp4_2d(src, out) - @test Array(out) == m - end -end -end - -# Non-scaled FP8 matmul needs only Hopper (sm_90+); block-scaled mma (below) is -# the Blackwell-only variant. -@testset "mma" begin -if capability(device()) >= v"9" - M, K, N = 16, 16, 16 - # f8 operands with both allowed accumulator dtypes (f16 and f32). Inputs are - # in {0, 0.5, 1} and K = 16, so every partial sum is exactly representable - # in f16 too โ€” the f16-acc result matches the f32 reference bit-for-bit. - @testset "$T ร— $T โ†’ $Tacc acc" for (T, Tacc, kern) in ( - (Float8_E4M3FN, Float32, mma_e4m3_f32), - (Float8_E4M3FN, Float16, mma_e4m3_f16), - (Float8_E5M2, Float32, mma_e5m2_f32), - (Float8_E5M2, Float16, mma_e5m2_f16)) - ah = T.(Float32.(rand(0:2, M, K)) ./ 2) - bh = T.(Float32.(rand(0:2, K, N)) ./ 2) - ch = Tacc.(Float32.(rand(0:2, M, N))) - ref = Float32.(ah) * Float32.(bh) .+ Float32.(ch) - D = CUDA.zeros(Float32, M, N) - @cuda backend=cuTile blocks=1 kern(CuArray(ah), CuArray(bh), CuArray(ch), D) - @test Array(D) == ref - end - - @testset "* operator (auto acc โ†’ f8)" begin - ah = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) - bh = Float8_E4M3FN.(Float32.(rand(0:2, M, M)) ./ 2) - # f8 ร— f8 accumulates in f16, then downcasts back to f8. - ref = Float32.(Float8_E4M3FN.(Float16.(Float32.(ah) * Float32.(bh)))) - D = CUDA.zeros(Float32, M, M) - @cuda backend=cuTile blocks=1 mma_e4m3_star(CuArray(ah), CuArray(bh), D) - @test Array(D) == ref - end -end -end - # Block-scaled mma is Blackwell-only (sm_100+), for both FP8 and FP4 operands. @testset "mma_scaled" begin if capability(device()) >= v"10" m, n, k, ks, B = 16, 16, 64, 2, 32 - # 2D ร— 2D for each f8 operand type (scale = f8e8m0fnu, B = 32). - @testset "2D $T" for (T, kern) in ((Float8_E4M3FN, mma_scaled_2d_e4m3), - (Float8_E5M2, mma_scaled_2d_e5m2)) + # MXFP8: f8 operands with an f8e8m0fnu scale, B = 32. 2D ร— 2D, each f8 type. + @testset "MXFP8 2D ($T)" for T in (Float8_E4M3FN, Float8_E5M2) xh = T.(Float32.(rand(0:2, m, k)) ./ 2) # values in {0, 0.5, 1} yh = T.(Float32.(rand(0:2, k, n)) ./ 2) xs = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, m, ks)) # scales in {1, 2} ys = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, ks, n)) ref = mma_scaled_ref(xh, xs, yh, ys, B) Z = CUDA.zeros(Float32, m, n) - @cuda backend=cuTile blocks=1 kern(CuArray(xh), CuArray(xs), CuArray(yh), CuArray(ys), Z) + @cuda backend=cuTile blocks=1 mma_scaled_2d(CuArray(xh), CuArray(xs), CuArray(yh), CuArray(ys), Z) @test Array(Z) == ref end - @testset "mat-vec" begin + @testset "MXFP8 mat-vec" begin xh = Float8_E4M3FN.(Float32.(rand(0:2, m, k)) ./ 2) yh = Float8_E4M3FN.(Float32.(rand(0:2, k)) ./ 2) xs = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, m, ks)) @@ -287,7 +327,7 @@ if capability(device()) >= v"10" @test Array(Z) == ref end - @testset "batched (trailing batch broadcast)" begin + @testset "MXFP8 batched (trailing batch broadcast)" begin bt = 2 xh = Float8_E4M3FN.(Float32.(rand(0:2, m, k)) ./ 2) xs = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, m, ks)) @@ -299,25 +339,32 @@ if capability(device()) >= v"10" @test Array(Z) == ref end - # FP4 operands packed two-per-byte, unpacked + transposed in the kernel - # (see `mma_scaled_fp4`). X packs along K โ†’ bytes (K/2, M); Y packs along - # N โ†’ bytes (N/2, K). - @testset "FP4 (packed, unpack + transpose)" begin + # MXFP4: FP4 operands packed two-per-byte (unpacked + transposed in the + # kernel) with an f8e8m0fnu scale, B = 32 โ†’ K_s = 2. + @testset "MXFP4" begin xh = [rand(F4_VALUES) for _ in 1:m, _ in 1:k] # (M, K) yh = [rand(F4_VALUES) for _ in 1:k, _ in 1:n] # (K, N) - x_bytes = UInt8[f4_nibble(xh[i, 2c-1]) | (f4_nibble(xh[i, 2c]) << 4) - for c in 1:(k รท 2), i in 1:m] # (K/2, M) - y_bytes = UInt8[f4_nibble(yh[kk, 2c-1]) | (f4_nibble(yh[kk, 2c]) << 4) - for c in 1:(n รท 2), kk in 1:k] # (N/2, K) xs = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, m, ks)) ys = Float8_E8M0FNU.(Float32(2) .^ rand(0:1, ks, n)) ref = mma_scaled_ref(xh, xs, yh, ys, B) Z = CUDA.zeros(Float32, m, n) - @cuda backend=cuTile blocks=1 mma_scaled_fp4(CuArray(x_bytes), CuArray(xs), - CuArray(y_bytes), CuArray(ys), Z) + @cuda backend=cuTile blocks=1 mma_scaled_mxfp4(CuArray(pack_f4_bytes(xh)), CuArray(xs), + CuArray(pack_f4_bytes(yh)), CuArray(ys), Z) @test Array(Z) == ref end -end -end + # NVFP4: same FP4 operands but an f8e4m3fn scale at B = 16 โ†’ K_s = 4. + @testset "NVFP4" begin + Bn, ksn = 16, k รท 16 # B = 16 โ†’ K_s = 4 + xh = [rand(F4_VALUES) for _ in 1:m, _ in 1:k] # (M, K) + yh = [rand(F4_VALUES) for _ in 1:k, _ in 1:n] # (K, N) + xs = Float8_E4M3FN.(Float32(2) .^ rand(0:1, m, ksn)) # scales in {1, 2}, exact in e4m3 + ys = Float8_E4M3FN.(Float32(2) .^ rand(0:1, ksn, n)) + ref = mma_scaled_ref(xh, xs, yh, ys, Bn) + Z = CUDA.zeros(Float32, m, n) + @cuda backend=cuTile blocks=1 mma_scaled_nvfp4(CuArray(pack_f4_bytes(xh)), CuArray(xs), + CuArray(pack_f4_bytes(yh)), CuArray(ys), Z) + @test Array(Z) == ref + end +end end From 3e14bf8bfa859e775cb610ca47bcc4a20f7cdf6d Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Wed, 3 Jun 2026 16:40:38 +0200 Subject: [PATCH 6/6] Remove muladd `@inline` and `@generated` --- ext/MicrofloatsExt.jl | 5 +- src/language/operations.jl | 136 ++++++++++++++++++------------------- 2 files changed, 67 insertions(+), 74 deletions(-) diff --git a/ext/MicrofloatsExt.jl b/ext/MicrofloatsExt.jl index 09c7b812..e01ed539 100644 --- a/ext/MicrofloatsExt.jl +++ b/ext/MicrofloatsExt.jl @@ -25,10 +25,7 @@ ct.ftof_rounding_mode(::Type{Float8_E8M0FNU}) = ct.RoundingMode.Zero # Non-scaled `mma`/`matmul` (`cuda_tile.mmaf`) accepts f8e4m3fn and f8e5m2 # operands with an f16 or f32 accumulator โ€” f16 first/preferred, mirroring -# cuda-tile's mmaf type table and cutile-python's `_mma_supported_dtypes`. The -# other Microfloats formats (E8M0FNU, Float4_E2M1FN) are only valid as -# scaled-mma operands/scales, so they stay absent and fall through to `mma`'s -# unsupported-dtype error. +# cuda-tile's mmaf type table and cutile-python's `_mma_supported_dtypes`. ct.mma_allowed_acc_dtypes(::Type{Float8_E4M3FN}) = (Float16, Float32) ct.mma_allowed_acc_dtypes(::Type{Float8_E5M2}) = (Float16, Float32) diff --git a/src/language/operations.jl b/src/language/operations.jl index 3a729b9c..ef6addf8 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -1310,7 +1310,7 @@ carries the result dtype, which must be one tileiras allows for the input dtype `fast_acc` enables fast accumulation (lower accumulator precision for throughput); it is valid only for FP8 inputs and requires Tile IR v13.3+. """ -@inline function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}; +function Base.muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}; fast_acc::Bool=false) where {T1, T2, T3, SA, SB, SC} # SA, SB, SC type parameters avoid ambiguity with the scalar `muladd` # methods during codegen. @@ -1319,18 +1319,18 @@ end # 2D ร— 2D: MmaFOp with swapped operands for row-major Tile IR # Julia (M,K)*(K,N) โ†’ TileIR (K,M)*(N,K) โ†’ mmaf(b,a,acc) โ†’ TileIR (N,M) โ†’ Julia (M,N) -@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{2}, fast_acc::Bool) +function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{2}, fast_acc::Bool) Intrinsics.mma(b, a, acc, fast_acc) end # Vec-mat (1D ร— 2D): reshape (M,) โ†’ (M, 1), MmaFOp, acc is already (M, N) -@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{1}, ::Val{2}, fast_acc::Bool) +function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{1}, ::Val{2}, fast_acc::Bool) a2d = reshape(a, (size(a, 1), 1)) _muladd(a2d, b, acc, Val(2), Val(2), fast_acc) end # Mat-vec (2D ร— 1D): reshape b (K,) โ†’ (K, 1), acc (M,) โ†’ (M, 1), MmaFOp, squeeze back -@inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{1}, fast_acc::Bool) +function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{1}, fast_acc::Bool) M, K = size(a, 1), size(b, 1) b2d = reshape(b, (K, 1)) acc2d = reshape(acc, (M, 1)) @@ -1339,18 +1339,18 @@ end end # Vec-vec (1D ร— 1D): not supported -@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1}, ::Bool) - return :(throw(ArgumentError("Vector-vector multiply-accumulate is not supported."))) +function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1}, ::Bool) + throw(ArgumentError("Vector-vector multiply-accumulate is not supported.")) end # Batched mat-vec / vec-mat (โ‰ฅ3D ร— 1D or 1D ร— โ‰ฅ3D): not supported, unsqueeze manually -@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}, ::Bool) where {NB} - NB >= 3 || return :(throw(ArgumentError("unreachable"))) - return :(throw(ArgumentError("Batched vec-mat is not supported. Reshape the 1D operand to 2D first."))) +function _muladd(::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}, ::Bool) where {NB} + NB >= 3 || throw(ArgumentError("unreachable")) + throw(ArgumentError("Batched vec-mat is not supported. Reshape the 1D operand to 2D first.")) end -@generated function _muladd(::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}, ::Bool) where {NA} - NA >= 3 || return :(throw(ArgumentError("unreachable"))) - return :(throw(ArgumentError("Batched mat-vec is not supported. Reshape the 1D operand to 2D first."))) +function _muladd(::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}, ::Bool) where {NA} + NA >= 3 || throw(ArgumentError("unreachable")) + throw(ArgumentError("Batched mat-vec is not supported. Reshape the 1D operand to 2D first.")) end # Batched matmul (โ‰ฅ3D ร— โ‰ฅ3D): trailing batch dims with broadcast @@ -1360,10 +1360,10 @@ end # 2. Flatten batch dims into one via reshape (no permute needed!) # 3. MmaFOp with swapped operands: mmaf(b, a, acc) # 4. Unflatten batch dims via reshape -@generated function _muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}, - ::Val{NA}, ::Val{NB}, fast_acc::Bool) where {T1, T2, T3, SA, SB, SC, NA, NB} - sa = Tuple(SA.parameters) - sb = Tuple(SB.parameters) +function _muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}, + ::Val{NA}, ::Val{NB}, fast_acc::Bool) where {T1, T2, T3, SA, SB, SC, NA, NB} + sa = size(a) + sb = size(b) # Matrix dims are first two; batch dims are trailing M = sa[1]; K = sa[2]; N = sb[2] @@ -1374,24 +1374,22 @@ end n_batch = max(length(a_batch), length(b_batch)) a_batch_padded = (a_batch..., ntuple(Returns(1), n_batch - length(a_batch))...) b_batch_padded = (b_batch..., ntuple(Returns(1), n_batch - length(b_batch))...) - batch_shape = map(max, a_batch_padded, b_batch_padded) + batch_shape = max.(a_batch_padded, b_batch_padded) B_flat = prod(batch_shape) - quote - # Reshape + broadcast to align batch dims (still trailing) - a_bc = broadcast_to(reshape(a, $((M, K, a_batch_padded...))), $((M, K, batch_shape...))) - b_bc = broadcast_to(reshape(b, $((K, N, b_batch_padded...))), $((K, N, batch_shape...))) - acc_bc = broadcast_to(acc, $((M, N, batch_shape...))) - # Flatten batch dims to one โ€” no permute needed since row-major Tile IR - # already has batch as the leading (slowest) dimension - a_3d = reshape(a_bc, $((M, K, B_flat))) - b_3d = reshape(b_bc, $((K, N, B_flat))) - acc_3d = reshape(acc_bc, $((M, N, B_flat))) - # MmaFOp with swapped operands for row-major convention - result_3d = Intrinsics.mma(b_3d, a_3d, acc_3d, fast_acc) - # Unflatten batch dims - reshape(result_3d, $((M, N, batch_shape...))) - end + # Reshape + broadcast to align batch dims (still trailing) + a_bc = broadcast_to(reshape(a, (M, K, a_batch_padded...)), (M, K, batch_shape...)) + b_bc = broadcast_to(reshape(b, (K, N, b_batch_padded...)), (K, N, batch_shape...)) + acc_bc = broadcast_to(acc, (M, N, batch_shape...)) + # Flatten batch dims to one โ€” no permute needed since row-major Tile IR + # already has batch as the leading (slowest) dimension + a_3d = reshape(a_bc, (M, K, B_flat)) + b_3d = reshape(b_bc, (K, N, B_flat)) + acc_3d = reshape(acc_bc, (M, N, B_flat)) + # MmaFOp with swapped operands for row-major convention + result_3d = Intrinsics.mma(b_3d, a_3d, acc_3d, fast_acc) + # Unflatten batch dims + reshape(result_3d, (M, N, batch_shape...)) end #============================================================================= @@ -1414,14 +1412,14 @@ dimension except K, where they have `K_s โ‰ค K` entries. follow [`muladd`](@ref): 2-D `(M, K)` ร— `(K, N)`, mat-vec, and trailing batch dims; vec-mat is unsupported (it would collapse K, leaving nothing to scale). """ -@inline function muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile, b::Tile{Tb, SB}, b_scale::Tile, +function muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile, b::Tile{Tb, SB}, b_scale::Tile, acc::Tile) where {Ta, Tb, SA, SB} _muladd_scaled(a, a_scale, b, b_scale, acc, Val(ndims(a)), Val(ndims(b))) end # 2D ร— 2D: swap operands (and their scales) for row-major Tile IR, exactly as # `_muladd` swaps for `mma`. -@inline function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile, +function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile, ::Val{2}, ::Val{2}) Intrinsics.mma_scaled(b, b_scale, a, a_scale, acc) end @@ -1429,7 +1427,7 @@ end # Mat-vec (2D ร— 1D): the K-vector `b` (and its scale) gain a trailing N=1 dim; # `acc` becomes (M, 1); then squeeze back to (M,). K โ€” the scaled dimension โ€” # is preserved, so block scaling is well-defined. -@inline function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile, +function _muladd_scaled(a::Tile, a_scale::Tile, b::Tile, b_scale::Tile, acc::Tile, ::Val{2}, ::Val{1}) M, K, Ks = size(a, 1), size(b, 1), size(b_scale, 1) b2d = reshape(b, (K, 1)) @@ -1442,33 +1440,33 @@ end # Vec-mat (1D ร— 2D): promoting `a` to (M, 1) collapses K to 1, leaving no K # dimension to block-scale. Unsupported โ€” reshape to 2D and supply a matching # K_s scale instead. -@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{2}) - return :(throw(ArgumentError("Scaled vec-mat is not supported (the K dimension collapses to 1, which cannot be block-scaled)."))) +function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{2}) + throw(ArgumentError("Scaled vec-mat is not supported (the K dimension collapses to 1, which cannot be block-scaled).")) end # Vec-vec (1D ร— 1D): not supported. -@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1}) - return :(throw(ArgumentError("Scaled vector-vector multiply-accumulate is not supported."))) +function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{1}) + throw(ArgumentError("Scaled vector-vector multiply-accumulate is not supported.")) end # Batched mat-vec / vec-mat (โ‰ฅ3D ร— 1D or 1D ร— โ‰ฅ3D): not supported. -@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB} - return :(throw(ArgumentError("Batched scaled vec-mat is not supported."))) +function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{1}, ::Val{NB}) where {NB} + throw(ArgumentError("Batched scaled vec-mat is not supported.")) end -@generated function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA} - return :(throw(ArgumentError("Batched scaled mat-vec is not supported."))) +function _muladd_scaled(::Tile, ::Tile, ::Tile, ::Tile, ::Tile, ::Val{NA}, ::Val{1}) where {NA} + throw(ArgumentError("Batched scaled mat-vec is not supported.")) end # Batched (โ‰ฅ3D ร— โ‰ฅ3D): trailing batch dims with broadcast, mirroring `_muladd`. # Scales carry the same batch dims as their operands; a_scale's batch must match # a's batch (likewise b_scale/b), then both broadcast to the common batch shape. -@generated function _muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile{Tas, SAS}, - b::Tile{Tb, SB}, b_scale::Tile{Tbs, SBS}, - acc::Tile{Tc, SC}, - ::Val{NA}, ::Val{NB}) where {Ta, Tas, Tb, Tbs, Tc, - SA, SAS, SB, SBS, SC, NA, NB} - sa = Tuple(SA.parameters); sas = Tuple(SAS.parameters) - sb = Tuple(SB.parameters); sbs = Tuple(SBS.parameters) +function _muladd_scaled(a::Tile{Ta, SA}, a_scale::Tile{Tas, SAS}, + b::Tile{Tb, SB}, b_scale::Tile{Tbs, SBS}, + acc::Tile{Tc, SC}, + ::Val{NA}, ::Val{NB}) where {Ta, Tas, Tb, Tbs, Tc, + SA, SAS, SB, SBS, SC, NA, NB} + sa = size(a); sas = size(a_scale) + sb = size(b); sbs = size(b_scale) # Matrix dims are first two; batch dims are trailing. M = sa[1]; K = sa[2]; N = sb[2] @@ -1483,28 +1481,26 @@ end b_batch_padded = (b_batch..., ntuple(Returns(1), n_batch - length(b_batch))...) as_batch_padded = (as_batch..., ntuple(Returns(1), n_batch - length(as_batch))...) bs_batch_padded = (bs_batch..., ntuple(Returns(1), n_batch - length(bs_batch))...) - batch_shape = map(max, a_batch_padded, b_batch_padded) + batch_shape = max.(a_batch_padded, b_batch_padded) B_flat = prod(batch_shape) - quote - # Reshape + broadcast to align batch dims (still trailing). - a_bc = broadcast_to(reshape(a, $((M, K, a_batch_padded...))), $((M, K, batch_shape...))) - b_bc = broadcast_to(reshape(b, $((K, N, b_batch_padded...))), $((K, N, batch_shape...))) - as_bc = broadcast_to(reshape(a_scale, $((M, Ksa, as_batch_padded...))), $((M, Ksa, batch_shape...))) - bs_bc = broadcast_to(reshape(b_scale, $((Ksb, N, bs_batch_padded...))), $((Ksb, N, batch_shape...))) - acc_bc = broadcast_to(acc, $((M, N, batch_shape...))) - # Flatten batch dims to one โ€” no permute needed since row-major Tile IR - # already has batch as the leading (slowest) dimension. - a_3d = reshape(a_bc, $((M, K, B_flat))) - b_3d = reshape(b_bc, $((K, N, B_flat))) - as_3d = reshape(as_bc, $((M, Ksa, B_flat))) - bs_3d = reshape(bs_bc, $((Ksb, N, B_flat))) - acc_3d = reshape(acc_bc, $((M, N, B_flat))) - # mmaf_scaled with swapped operands for row-major convention. - result_3d = Intrinsics.mma_scaled(b_3d, bs_3d, a_3d, as_3d, acc_3d) - # Unflatten batch dims. - reshape(result_3d, $((M, N, batch_shape...))) - end + # Reshape + broadcast to align batch dims (still trailing). + a_bc = broadcast_to(reshape(a, (M, K, a_batch_padded...)), (M, K, batch_shape...)) + b_bc = broadcast_to(reshape(b, (K, N, b_batch_padded...)), (K, N, batch_shape...)) + as_bc = broadcast_to(reshape(a_scale, (M, Ksa, as_batch_padded...)), (M, Ksa, batch_shape...)) + bs_bc = broadcast_to(reshape(b_scale, (Ksb, N, bs_batch_padded...)), (Ksb, N, batch_shape...)) + acc_bc = broadcast_to(acc, (M, N, batch_shape...)) + # Flatten batch dims to one โ€” no permute needed since row-major Tile IR + # already has batch as the leading (slowest) dimension. + a_3d = reshape(a_bc, (M, K, B_flat)) + b_3d = reshape(b_bc, (K, N, B_flat)) + as_3d = reshape(as_bc, (M, Ksa, B_flat)) + bs_3d = reshape(bs_bc, (Ksb, N, B_flat)) + acc_3d = reshape(acc_bc, (M, N, B_flat)) + # mmaf_scaled with swapped operands for row-major convention. + result_3d = Intrinsics.mma_scaled(b_3d, bs_3d, a_3d, as_3d, acc_3d) + # Unflatten batch dims. + reshape(result_3d, (M, N, batch_shape...)) end # Matrix multiplication: A * B = muladd(A, B, zeros)