Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ julia> Pkg.add("cuTile")
```

Execution of cuTile kernels requires CUDA.jl to be installed and imported.
cuTile generates kernels based on [Tile IR](https://docs.nvidia.com/cuda/tile-ir/), which requires an NVIDIA Driver that supports CUDA 13 (580 or later).
cuTile generates kernels based on [Tile IR](https://docs.nvidia.com/cuda/tile-ir/), which requires an NVIDIA Driver that supports CUDA 13 (580 or later),
and runs on GPUs with a [Compute Capability (CC)](https://developer.nvidia.com/cuda/gpus) of at least 8.0 (Ampere).
CUDA.jl automatically downloads the appropriate CUDA toolkit artifacts, so no manual CUDA installation is needed.
Only Ampere, Ada, and Blackwell GPUs are supported at this time, with Hopper support expected
in a future release of CUDA.

## Quick Start

Expand Down Expand Up @@ -120,13 +119,14 @@ uses standard Julia syntax and is overlaid on `Base`.

### Supported Types

**Integers:** `Int8`, `UInt8`, `Int16`, `UInt16`, `Int32`, `UInt32`, `Int64`, `UInt64`
**Floats:** `Float16`, `BFloat16`, `Float32`, `Float64`, `TFloat32`
**FP8:** `Float8_E4M3FN`, `Float8_E5M2` (requires [DLFP8Types.jl](https://github.com/JuliaGPU/DLFP8Types.jl))
**Boolean:** `Bool`
- **Integers:** `Int8`, `UInt8`, `Int16`, `UInt16`, `Int32`, `UInt32`, `Int64`, `UInt64`
- **Boolean:** `Bool`
- **Arithmetic Floats:** `Float16`, `BFloat16`, `Float32`, `Float64`
- **Numeric Floats:** `TFloat32`\*, `Float8_E4M3FN`\*\*, `Float8_E5M2`\*\*, `Float8_E8M0FNU`\*\*, `Float4_E2M1FN`\*\*

`TFloat32` is a 32-bit floating-point type with reduced mantissa precision (10 bits),
optimized for tensor core operations.
\* `cuTile.TFloat32` is a public 32-bit floating-point numeric type with truncated mantissa (10 bits), made for tensor core operations.

\*\* [Microscaling (MX)](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) numeric types, exported by [Microfloats.jl](https://github.com/MurrellGroup/Microfloats.jl). `Float8_E4M3FN` and `Float8_E5M2` (FP8) are also exported by [DLFP8Types.jl](https://github.com/JuliaGPU/DLFP8Types.jl).

### Memory
| Operation | Description |
Expand Down
43 changes: 28 additions & 15 deletions src/launch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,13 +290,33 @@ function probe_max_bytecode_version()
"($(join(reverse(SUPPORTED_BYTECODE_VERSIONS), ", "))); last log:\n$last_log")
end

"""
tile_ir_requirement(cap::VersionNumber) -> Union{Tuple{String,VersionNumber}, Nothing}

The architecture-family name and the minimum bytecode version Tile IR requires
on a device of compute capability `cap`, or `nothing` if Tile IR is not
supported on that capability at all. Pure (no device access) so the gate logic
in [`check_tile_ir_support`] can be unit-tested without a GPU.
"""
function tile_ir_requirement(cap::VersionNumber)
if cap >= v"10.0" # Blackwell
return ("Blackwell", v"13.1")
elseif cap >= v"9.0" # Hopper
return ("Hopper", v"13.3")
elseif cap >= v"8.0" # Ampere / Ada
return ("Ampere/Ada", v"13.2")
else
return nothing
end
end

"""
check_tile_ir_support()

Validate that the current `tileiras` toolkit supports Tile IR on the active
device. Returns the bytecode version cuTile should emit for this device
(per [`bytecode_version`]), provided it meets the device's minimum
requirement (Blackwell ≥ v13.1, Ampere/Ada ≥ v13.2).
requirement (Blackwell ≥ v13.1, Hopper ≥ v13.3, Ampere/Ada ≥ v13.2).
"""
function check_tile_ir_support()
if tileiras_override === nothing && !CUDA_Compiler_jll.is_available()
Expand All @@ -310,23 +330,16 @@ function check_tile_ir_support()

cap = capability(dev)
sm_str = format_sm_arch(cap)
if cap >= v"10.0" # Blackwell
if ver < v"13.1"
@error "Tile IR on Blackwell ($sm_str) requires bytecode ≥ v13.1, detected v$ver"
return nothing
end
elseif cap >= v"9.0" # Hopper — not supported
@error "Tile IR is not supported on Hopper ($sm_str)"
return nothing
elseif cap >= v"8.0" # Ampere / Ada
if ver < v"13.2"
@error "Tile IR on Ampere/Ada ($sm_str) requires bytecode ≥ v13.2, detected v$ver"
return nothing
end
else
req = tile_ir_requirement(cap)
if req === nothing
@error "Tile IR is not supported on compute capability $cap ($sm_str)"
return nothing
end
arch, min_ver = req
if ver < min_ver
@error "Tile IR on $arch ($sm_str) requires bytecode ≥ v$min_ver, detected v$ver"
return nothing
end

return ver
end
Expand Down
14 changes: 14 additions & 0 deletions test/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,20 @@ end
@test_throws ArgumentError cuTile.format_sm_arch(v"10.0.1")
end

@testset "tile_ir_requirement" begin
# Blackwell (sm_100+) requires bytecode ≥ v13.1
@test cuTile.tile_ir_requirement(v"10.0") == ("Blackwell", v"13.1")
@test cuTile.tile_ir_requirement(v"12.1") == ("Blackwell", v"13.1")
# Hopper (sm_90) requires bytecode ≥ v13.3
@test cuTile.tile_ir_requirement(v"9.0") == ("Hopper", v"13.3")
# Ampere / Ada (sm_80..sm_89) requires bytecode ≥ v13.2
@test cuTile.tile_ir_requirement(v"8.0") == ("Ampere/Ada", v"13.2")
@test cuTile.tile_ir_requirement(v"8.9") == ("Ampere/Ada", v"13.2")
# Pre-Ampere is unsupported
@test cuTile.tile_ir_requirement(v"7.5") === nothing
@test cuTile.tile_ir_requirement(v"7.0") === nothing
end

@testset "@compiler_options validation" begin
# Invalid num_ctas (not power of 2) should throw at definition time
@test_throws "num_ctas must be" @eval function _test_bad_ctas(a::ct.TileArray{Float32,1})
Expand Down