diff --git a/README.md b/README.md index 40a5924c..a753db39 100644 --- a/README.md +++ b/README.md @@ -17,10 +17,9 @@ julia> Pkg.add("cuTile") ``` Execution of cuTile kernels requires CUDA.jl to be installed and imported. -cuTile generates kernels based on [Tile IR](https://docs.nvidia.com/cuda/tile-ir/), which requires an NVIDIA Driver that supports CUDA 13 (580 or later). +cuTile generates kernels based on [Tile IR](https://docs.nvidia.com/cuda/tile-ir/), which requires an NVIDIA Driver that supports CUDA 13 (580 or later), +and runs on GPUs with a [Compute Capability (CC)](https://developer.nvidia.com/cuda/gpus) of at least 8.0 (Ampere). CUDA.jl automatically downloads the appropriate CUDA toolkit artifacts, so no manual CUDA installation is needed. -Only Ampere, Ada, and Blackwell GPUs are supported at this time, with Hopper support expected -in a future release of CUDA. ## Quick Start @@ -120,13 +119,14 @@ uses standard Julia syntax and is overlaid on `Base`. ### Supported Types -**Integers:** `Int8`, `UInt8`, `Int16`, `UInt16`, `Int32`, `UInt32`, `Int64`, `UInt64` -**Floats:** `Float16`, `BFloat16`, `Float32`, `Float64`, `TFloat32` -**FP8:** `Float8_E4M3FN`, `Float8_E5M2` (requires [DLFP8Types.jl](https://github.com/JuliaGPU/DLFP8Types.jl)) -**Boolean:** `Bool` +- **Integers:** `Int8`, `UInt8`, `Int16`, `UInt16`, `Int32`, `UInt32`, `Int64`, `UInt64` +- **Boolean:** `Bool` +- **Arithmetic Floats:** `Float16`, `BFloat16`, `Float32`, `Float64` +- **Numeric Floats:** `TFloat32`\*, `Float8_E4M3FN`\*\*, `Float8_E5M2`\*\*, `Float8_E8M0FNU`\*\*, `Float4_E2M1FN`\*\* -`TFloat32` is a 32-bit floating-point type with reduced mantissa precision (10 bits), -optimized for tensor core operations. +\* `cuTile.TFloat32` is a public 32-bit floating-point numeric type with truncated mantissa (10 bits), made for tensor core operations. + +\*\* [Microscaling (MX)](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) numeric types, exported by [Microfloats.jl](https://github.com/MurrellGroup/Microfloats.jl). `Float8_E4M3FN` and `Float8_E5M2` (FP8) are also exported by [DLFP8Types.jl](https://github.com/JuliaGPU/DLFP8Types.jl). ### Memory | Operation | Description | diff --git a/src/launch.jl b/src/launch.jl index 3b69dcee..36a53b7a 100644 --- a/src/launch.jl +++ b/src/launch.jl @@ -290,13 +290,33 @@ function probe_max_bytecode_version() "($(join(reverse(SUPPORTED_BYTECODE_VERSIONS), ", "))); last log:\n$last_log") end +""" + tile_ir_requirement(cap::VersionNumber) -> Union{Tuple{String,VersionNumber}, Nothing} + +The architecture-family name and the minimum bytecode version Tile IR requires +on a device of compute capability `cap`, or `nothing` if Tile IR is not +supported on that capability at all. Pure (no device access) so the gate logic +in [`check_tile_ir_support`] can be unit-tested without a GPU. +""" +function tile_ir_requirement(cap::VersionNumber) + if cap >= v"10.0" # Blackwell + return ("Blackwell", v"13.1") + elseif cap >= v"9.0" # Hopper + return ("Hopper", v"13.3") + elseif cap >= v"8.0" # Ampere / Ada + return ("Ampere/Ada", v"13.2") + else + return nothing + end +end + """ check_tile_ir_support() Validate that the current `tileiras` toolkit supports Tile IR on the active device. Returns the bytecode version cuTile should emit for this device (per [`bytecode_version`]), provided it meets the device's minimum -requirement (Blackwell ≥ v13.1, Ampere/Ada ≥ v13.2). +requirement (Blackwell ≥ v13.1, Hopper ≥ v13.3, Ampere/Ada ≥ v13.2). """ function check_tile_ir_support() if tileiras_override === nothing && !CUDA_Compiler_jll.is_available() @@ -310,23 +330,16 @@ function check_tile_ir_support() cap = capability(dev) sm_str = format_sm_arch(cap) - if cap >= v"10.0" # Blackwell - if ver < v"13.1" - @error "Tile IR on Blackwell ($sm_str) requires bytecode ≥ v13.1, detected v$ver" - return nothing - end - elseif cap >= v"9.0" # Hopper — not supported - @error "Tile IR is not supported on Hopper ($sm_str)" - return nothing - elseif cap >= v"8.0" # Ampere / Ada - if ver < v"13.2" - @error "Tile IR on Ampere/Ada ($sm_str) requires bytecode ≥ v13.2, detected v$ver" - return nothing - end - else + req = tile_ir_requirement(cap) + if req === nothing @error "Tile IR is not supported on compute capability $cap ($sm_str)" return nothing end + arch, min_ver = req + if ver < min_ver + @error "Tile IR on $arch ($sm_str) requires bytecode ≥ v$min_ver, detected v$ver" + return nothing + end return ver end diff --git a/test/types.jl b/test/types.jl index 56b5e2e0..43d8670d 100644 --- a/test/types.jl +++ b/test/types.jl @@ -182,6 +182,20 @@ end @test_throws ArgumentError cuTile.format_sm_arch(v"10.0.1") end +@testset "tile_ir_requirement" begin + # Blackwell (sm_100+) requires bytecode ≥ v13.1 + @test cuTile.tile_ir_requirement(v"10.0") == ("Blackwell", v"13.1") + @test cuTile.tile_ir_requirement(v"12.1") == ("Blackwell", v"13.1") + # Hopper (sm_90) requires bytecode ≥ v13.3 + @test cuTile.tile_ir_requirement(v"9.0") == ("Hopper", v"13.3") + # Ampere / Ada (sm_80..sm_89) requires bytecode ≥ v13.2 + @test cuTile.tile_ir_requirement(v"8.0") == ("Ampere/Ada", v"13.2") + @test cuTile.tile_ir_requirement(v"8.9") == ("Ampere/Ada", v"13.2") + # Pre-Ampere is unsupported + @test cuTile.tile_ir_requirement(v"7.5") === nothing + @test cuTile.tile_ir_requirement(v"7.0") === nothing +end + @testset "@compiler_options validation" begin # Invalid num_ctas (not power of 2) should throw at definition time @test_throws "num_ctas must be" @eval function _test_bad_ctas(a::ct.TileArray{Float32,1})