Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,23 @@ Tile IR operations.
| Operation | Description |
|-----------|-------------|
| `a * b` | Matrix multiplication: `a @ b` |
| `muladd(a, b, acc)` | Matrix multiply-accumulate: `a * b + acc` |
| `muladd(a, b, acc; fast_acc=false)` | Matrix multiply-accumulate: `a * b + acc` |
| `ct.muladd_scaled(a, a_scale, b, b_scale, acc)` | Block-scaled multiply-accumulate |

Each operation follows the shape rules of `Base.:*` / `Base.muladd`, and additionally allows trailing batch dimensions.

`fast_acc=true` enables fast accumulation. It is only valid for FP8 inputs and requires Tile IR v13.3+.
It only takes effect on Hopper (sm_90); on other architectures it is silently ignored.

`ct.muladd_scaled` multiplies each operand by a low-precision block scale before the matmul:
each scale element covers a contiguous block of `B = K ÷ K_s` elements along the K dimension.
Requires Blackwell. The supported operand/scale/accumulator dtypes and block sizes are:

| Input (`a`/`b`) | Scale | Acc/Output | B |
|-----------------|-------|------------|--------|
| `Float8_E4M3FN`, `Float8_E5M2` | `Float8_E8M0FNU` | `Float32` | 32 |
| `Float4_E2M1FN` | `Float8_E8M0FNU` | `Float32` | 16, 32 |
| `Float4_E2M1FN` | `Float8_E4M3FN` | `Float32` | 16 |

### Higher-Order Functions
| Operation | Description |
Expand Down
10 changes: 10 additions & 0 deletions ext/DLFP8TypesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ function ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E5M2})
return ct.F8E5M2(table)
end

# Register both FP8 operand types with the non-scaled `mma`/`matmul`
# (`cuda_tile.mmaf`) dtype tables. One method per type is generated for each
# table, so dispatch is the same as with hand-written one-liners.
for F8 in (Float8_E4M3FN, Float8_E5M2)
    # FP8 operands take an f16 or f32 accumulator (f16 first/preferred),
    # mirroring cuda-tile's mmaf type table and cutile-python's
    # `_mma_supported_dtypes`.
    @eval ct.mma_allowed_acc_dtypes(::Type{$F8}) = (Float16, Float32)

    # `fast_acc` (lower-precision MMA accumulation) is an FP8-only throughput
    # hint, so both FP8 types opt in.
    @eval ct.mma_supports_fast_acc(::Type{$F8}) = true
end

# Float ↔ FP8 scalar constructor overlays (for map/convert dispatch)
const FP8Types = (Float8_E4M3FN, Float8_E5M2)
const StandardFloats = (Float16, ct.BFloat16, Float32, ct.TFloat32, Float64)
Expand Down
13 changes: 13 additions & 0 deletions ext/MicrofloatsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ ct.bitwidth(::Type{T}) where {T<:Microfloats.Microfloat} = Microfloats.bitwidth(
# nearest-even on f32→E8M0FNU (only `zero` and `positive_inf` are valid).
ct.ftof_rounding_mode(::Type{Float8_E8M0FNU}) = ct.RoundingMode.Zero

# Register the two FP8 operand formats with the non-scaled `mma`/`matmul`
# (`cuda_tile.mmaf`) dtype tables. The other Microfloats formats (E8M0FNU,
# Float4_E2M1FN) are only valid as scaled-mma operands/scales, so they are
# deliberately left out and fall through to `mma`'s unsupported-dtype error.
for F8 in (Float8_E4M3FN, Float8_E5M2)
    # f16 or f32 accumulator — f16 first/preferred, mirroring cuda-tile's mmaf
    # type table and cutile-python's `_mma_supported_dtypes`.
    @eval ct.mma_allowed_acc_dtypes(::Type{$F8}) = (Float16, Float32)

    # `fast_acc` (lower-precision MMA accumulation) is an FP8-only throughput
    # hint, so both FP8 formats opt in.
    @eval ct.mma_supports_fast_acc(::Type{$F8}) = true
end

# Float ↔ microfloat scalar constructor overlays (for map/convert dispatch).
# Mirrors DLFP8TypesExt: route to `Intrinsics.ftof` so kernel-side conversions
# lower to the FToFOp Tile IR intrinsic instead of Microfloats' Float32-fallback
Expand Down
26 changes: 26 additions & 0 deletions src/bytecode/encodings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ module Opcode
const Atan2Op = 110 # since 13.2
const PackOp = 111 # since 13.3
const UnpackOp = 112 # since 13.3
# 113 (AllocaOp) not implemented
const MmaFScaledOp = 114 # since 13.3
end

# Enums for operation attributes
Expand Down Expand Up @@ -1310,6 +1312,30 @@ function encode_MmaIOp!(cb::CodeBuilder, result_type::TypeId,
return new_op!(cb)
end

"""
encode_MmaFScaledOp!(cb, result_type, lhs, rhs, acc, lhs_scale, rhs_scale) -> Value

Block-scaled float matrix multiply-accumulate (`acc + (lhs ⊙ lhs_scale) @
(rhs ⊙ rhs_scale)`) for low-precision (f8/f4) inputs. Each scale element scales
a contiguous block of K-dimension elements in its operand. Requires Tile IR
v13.3+.
Opcode: 114
"""
function encode_MmaFScaledOp!(cb::CodeBuilder, result_type::TypeId,
                              lhs::Value, rhs::Value, acc::Value,
                              lhs_scale::Value, rhs_scale::Value)
    # The op only exists from Tile IR v13.3 onwards; fail before writing
    # anything into the buffer.
    if cb.version < v"13.3"
        throw(IRError("MmaFScaledOp requires Tile IR v13.3+, got v$(cb.version)"))
    end

    buf = cb.buf
    encode_varint!(buf, Opcode.MmaFScaledOp)
    encode_typeid!(buf, result_type)
    # Operands are written in this fixed order: lhs, rhs, acc, then the scales.
    for operand in (lhs, rhs, acc, lhs_scale, rhs_scale)
        encode_operand!(buf, operand)
    end
    return new_op!(cb)
end

#=============================================================================
Integer arithmetic operations
=============================================================================#
Expand Down
89 changes: 84 additions & 5 deletions src/compiler/intrinsics/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,17 @@ mma_allowed_acc_dtypes(::Type{TFloat32}) = (Float32,)
mma_allowed_acc_dtypes(::Type{Float64}) = (Float64,)
mma_allowed_acc_dtypes(@nospecialize(::Type)) = nothing

"""
mma_supports_fast_acc(input_T) -> Bool

Whether `mma`'s `fast_acc` hint (lower-precision accumulation for throughput)
is valid for operand element type `input_T`. Only the FP8 dtypes qualify;
extensions (Microfloats/DLFP8Types) register them by overloading this, like
[`mma_allowed_acc_dtypes`](@ref). Mirrors cuTile Python's
`use_fast_acc is only supported for fp8 input dtypes` check.
"""
function mma_supports_fast_acc(@nospecialize(::Type))
    # Generic fallback: no dtype supports `fast_acc` unless a more specific
    # method (e.g. one registered by an extension) says otherwise.
    return false
end

# First (preferred) accumulator dtype for a given input dtype. Used by
# matmul to pick `acc = zeros(first_allowed_acc(T), …)` so that the input
# dtype constraint and tileiras's acc-dtype constraint stay consistent
Expand All @@ -425,7 +436,7 @@ first_allowed_acc_dtype(::Type{T}) where {T} =
end

"""
Intrinsics.mma(a::Tile, b::Tile, acc::Tile) -> typeof(acc)
Intrinsics.mma(a::Tile, b::Tile, acc::Tile, fast_acc::Bool=false) -> typeof(acc)

Matrix-multiply-accumulate computing `a*b + acc`. Dispatches at codegen
based on element types:
Expand All @@ -436,15 +447,23 @@ based on element types:
- `i8` `a`/`b` with `i32` `acc` lower to `cuda_tile.mmai`. Per-input
signedness is derived from the Julia type (`Int8` → signed,
`UInt8` → unsigned); `acc` and the result are always signed `i32`.

`fast_acc` enables fast accumulation (trading accumulator precision for
throughput). It is only valid for FP8 inputs (see
[`mma_supports_fast_acc`](@ref)) and requires Tile IR v13.3+.
"""
@intrinsic mma(a::Tile, b::Tile, acc::Tile)
tfunc(𝕃, ::typeof(Intrinsics.mma), @nospecialize(a), @nospecialize(b), @nospecialize(acc)) = CC.widenconst(acc)
@intrinsic mma(a::Tile, b::Tile, acc::Tile, fast_acc::Bool=false)
tfunc(𝕃, ::typeof(Intrinsics.mma), @nospecialize(a), @nospecialize(b), @nospecialize(acc),
@nospecialize(rest...)) = CC.widenconst(acc)
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma), args)
cb = ctx.cb

lhs = emit_value!(ctx, args[1])
rhs = emit_value!(ctx, args[2])
acc = emit_value!(ctx, args[3])
fast_acc = length(args) >= 4 ?
(@something get_constant(ctx, args[4]) throw(IRError("mma: fast_acc must be a compile-time constant"))) :
false

(lhs === nothing || rhs === nothing || acc === nothing) && throw(IRError("Cannot resolve operands for mma()"))

Expand All @@ -458,15 +477,24 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma), args)
# the table in cuTile Python's mma implementation.
lhs_elem === rhs_elem ||
throw(IRError("mma: float lhs and rhs must share dtype, got lhs=$lhs_elem, rhs=$rhs_elem"))
allowed_acc = mma_allowed_acc_dtypes(lhs_elem)
# `invokelatest` so extension-defined acc-dtype tables (e.g. FP8 via the
# Microfloats/DLFP8Types exts) are visible from codegen's world age,
# mirroring `lookup_bitwidth`.
allowed_acc = Base.invokelatest(mma_allowed_acc_dtypes, lhs_elem)
allowed_acc === nothing &&
throw(IRError("mma: unsupported float input dtype $lhs_elem"))
acc_elem in allowed_acc ||
throw(IRError("mma: acc dtype $acc_elem is not allowed for input dtype " *
"$lhs_elem; tileiras requires acc ∈ $allowed_acc"))
encode_MmaFOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v)
# fast_acc is an FP8-only throughput hint; reject it elsewhere with a
# clear error rather than emitting IR tileiras would reject.
fast_acc && !Base.invokelatest(mma_supports_fast_acc, lhs_elem) &&
throw(IRError("mma: fast_acc is only supported for fp8 input dtypes " *
"(f8e4m3fn, f8e5m2), got $lhs_elem"))
encode_MmaFOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v; fast_acc)
elseif lhs_elem <: Union{Int8, UInt8} && rhs_elem <: Union{Int8, UInt8} &&
acc_elem === Int32
fast_acc && throw(IRError("mma: fast_acc is not supported for integer (mmai) inputs"))
s_lhs = lhs_elem <: Signed ? Signedness.Signed : Signedness.Unsigned
s_rhs = rhs_elem <: Signed ? Signedness.Signed : Signedness.Unsigned
encode_MmaIOp!(cb, acc.type_id, lhs.v, rhs.v, acc.v;
Expand All @@ -480,6 +508,57 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma), args)
CGVal(result, acc.type_id, acc.jltype, acc.shape)
end

"""
Intrinsics.mma_scaled(lhs, lhs_scale, rhs, rhs_scale, acc) -> typeof(acc)

Block-scaled matrix-multiply-accumulate computing `(lhs ⊙ lhs_scale) * (rhs ⊙
rhs_scale) + acc`, where each scale element multiplies a contiguous block of
`lhs`/`rhs` elements along the K dimension. Lowers to `cuda_tile.mmaf_scaled`
(Tile IR v13.3+).

`lhs`/`rhs` are low-precision floats (`f8e4m3fn`, `f8e5m2`, or `f4e2m1fn`),
`lhs_scale`/`rhs_scale` are `f8e8m0fnu` or `f8e4m3fn`, and `acc`/result are
`f32`. The block size `K ÷ K_s` and the exact (operand, scale) dtype pairing are
validated by tileiras (see its `mmaf_scaled` verifier for the supported table).
"""
@intrinsic mma_scaled(lhs, lhs_scale, rhs, rhs_scale, acc)
# Return-type inference for `Intrinsics.mma_scaled`: the result has the
# widened type of the accumulator argument.
function tfunc(𝕃, ::typeof(Intrinsics.mma_scaled),
               @nospecialize(lhs), @nospecialize(lhs_scale),
               @nospecialize(rhs), @nospecialize(rhs_scale),
               @nospecialize(acc))
    return CC.widenconst(acc)
end
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.mma_scaled), args)
    builder = ctx.cb

    # Lower the five operands in signature order:
    # (lhs, lhs_scale, rhs, rhs_scale, acc).
    operands = [emit_value!(ctx, args[i]) for i in 1:5]
    if any(isnothing, operands)
        throw(IRError("Cannot resolve operands for mma_scaled()"))
    end
    lhs, lhs_scale, rhs, rhs_scale, acc = operands

    # Element type of a lowered value's (widened) Julia type.
    elem_of(val) = eltype(CC.widenconst(val.jltype))
    lhs_elem = elem_of(lhs)
    rhs_elem = elem_of(rhs)
    lhs_scale_elem = elem_of(lhs_scale)
    rhs_scale_elem = elem_of(rhs_scale)
    acc_elem = elem_of(acc)

    # Only cheap structural invariants are checked here, with clear errors.
    # The operand/scale dtype pairing and block-size constraints are left to
    # the tileiras verifier, whose messages are already precise.
    if lhs_elem !== rhs_elem
        throw(IRError("mma_scaled: lhs and rhs must share dtype, got lhs=$lhs_elem, rhs=$rhs_elem"))
    end
    if lhs_scale_elem !== rhs_scale_elem
        throw(IRError("mma_scaled: lhs_scale and rhs_scale must share dtype, " *
                      "got $lhs_scale_elem and $rhs_scale_elem"))
    end
    if acc_elem !== Float32
        throw(IRError("mma_scaled: acc must be Float32, got $acc_elem"))
    end

    # The result reuses the accumulator's type id, Julia type and shape.
    op = encode_MmaFScaledOp!(builder, acc.type_id, lhs.v, rhs.v, acc.v,
                              lhs_scale.v, rhs_scale.v)
    return CGVal(op, acc.type_id, acc.jltype, acc.shape)
end

# TODO: cuda_tile.module

"""
Expand Down
Loading