diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e17c4607c79..5748179bf337 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1230,6 +1230,32 @@ if(VLLM_GPU_LANG STREQUAL "HIP") "csrc/rocm/skinny_gemms_w8a8/instantiate_n5.cu" "csrc/rocm/attention.cu") + # AIESW-32176: include CK W4A16 b_scale GEMM source if both: + # -DVLLM_CK_INCLUDE_DIR=/path/to/composablekernel/include (source headers, has ck/ck.hpp) + # -DVLLM_CK_BUILD_INCLUDE_DIR=/path/to/ck-build/include (generated headers, has ck/config.h) + # are set. CK's config.h is generated by CK's own CMake configure (per-build flags), + # so we need both the source tree and the build tree on the include path. + # The dispatcher in hybrid_w4a16.py routes only the gfx1151 Qwen3-4B prefill shapes listed + # in its _CK_W4A16_TARGET_SHAPES table (and only for M at or above each entry's min_M) to + # this op; everything else stays on the existing skinny-GEMM / Triton paths. + if(VLLM_CK_INCLUDE_DIR AND VLLM_CK_BUILD_INCLUDE_DIR) + if(NOT EXISTS "${VLLM_CK_INCLUDE_DIR}/ck/ck.hpp") + message(FATAL_ERROR + "VLLM_CK_INCLUDE_DIR=${VLLM_CK_INCLUDE_DIR} does not contain ck/ck.hpp") + endif() + if(NOT EXISTS "${VLLM_CK_BUILD_INCLUDE_DIR}/ck/config.h") + message(FATAL_ERROR + "VLLM_CK_BUILD_INCLUDE_DIR=${VLLM_CK_BUILD_INCLUDE_DIR} does not contain ck/config.h " + "(this is generated by CK's CMake configure step; point at //include)") + endif() + list(APPEND VLLM_ROCM_EXT_SRC "csrc/rocm/ck_w4a16.cu") + message(STATUS "AIESW-32176: building csrc/rocm/ck_w4a16.cu against CK source ${VLLM_CK_INCLUDE_DIR} + build ${VLLM_CK_BUILD_INCLUDE_DIR}") + elseif(VLLM_CK_INCLUDE_DIR OR VLLM_CK_BUILD_INCLUDE_DIR) + message(FATAL_ERROR "AIESW-32176: must set BOTH VLLM_CK_INCLUDE_DIR and VLLM_CK_BUILD_INCLUDE_DIR (or neither)") + else() + message(STATUS "AIESW-32176: VLLM_CK_*_INCLUDE_DIR not set; csrc/rocm/ck_w4a16.cu skipped (CK W4A16 dispatch will fall through to Triton)") + endif() + set(VLLM_ROCM_FLAGS ${VLLM_GPU_FLAGS}) if(VLLM_SKINNY_GEMM_SWEEP OR (DEFINED ENV{VLLM_SKINNY_GEMM_SWEEP} AND 
NOT $ENV{VLLM_SKINNY_GEMM_SWEEP} STREQUAL "0")) @@ -1253,6 +1279,11 @@ if(VLLM_GPU_LANG STREQUAL "HIP") if(_ROCM_SWEEP_ENABLED) target_compile_definitions(_rocm_C PRIVATE VLLM_SKINNY_GEMM_SWEEP=1) endif() + + if(VLLM_CK_INCLUDE_DIR AND VLLM_CK_BUILD_INCLUDE_DIR) + target_include_directories(_rocm_C PRIVATE ${VLLM_CK_INCLUDE_DIR} ${VLLM_CK_BUILD_INCLUDE_DIR}) + target_compile_definitions(_rocm_C PRIVATE VLLM_HAVE_CK_W4A16=1) + endif() endif() # For CUDA and HIP builds also build the triton_kernels external package. diff --git a/benchmarks/kernels/benchmark_hybrid_w4a16_gemm.py b/benchmarks/kernels/benchmark_hybrid_w4a16_gemm.py index 8336c1e7c94d..9969fd086874 100644 --- a/benchmarks/kernels/benchmark_hybrid_w4a16_gemm.py +++ b/benchmarks/kernels/benchmark_hybrid_w4a16_gemm.py @@ -43,7 +43,8 @@ def prepare_hybrid_weights(K, N, group_size, device="cuda"): """Create random weights for benchmarking. - Returns (w_q_skinny, w_s_skinny, w_fp16, w_q_skinny_i32). + Returns (w_q_skinny, w_s_skinny, w_s_skinny_bf16, w_fp16, w_bf16, + w_q_skinny_i32, w_zp). 
""" num_groups = K // group_size @@ -53,52 +54,68 @@ def prepare_hybrid_weights(K, N, group_size, device="cuda"): ) w_q_skinny = w_q_skinny_i32.view(torch.int8).contiguous() w_s_skinny = torch.randn(N, num_groups, dtype=torch.float16, device=device) * 0.01 + w_s_skinny_bf16 = w_s_skinny.to(torch.bfloat16) # Raw per-group zero-points for asymmetric benchmarks w_zp = torch.randint(0, 16, (N, num_groups), dtype=torch.int32, device=device).to( torch.float16 ) - # FP16 baseline for F.linear + # FP16 / BF16 baselines for F.linear w_fp16 = torch.randn(N, K, dtype=torch.float16, device=device) * 0.01 - - return w_q_skinny, w_s_skinny, w_fp16, w_q_skinny_i32, w_zp + w_bf16 = w_fp16.to(torch.bfloat16) + + return ( + w_q_skinny, + w_s_skinny, + w_s_skinny_bf16, + w_fp16, + w_bf16, + w_q_skinny_i32, + w_zp, + ) # --------------------------------------------------------------------------- # Benchmark # --------------------------------------------------------------------------- -PROVIDERS = ["torch-fp16", "hybrid-w4a16", "hybrid-w4a16-zp"] +PROVIDERS = [ + "torch-fp16", + "torch-bf16", + "hybrid-w4a16", + "hybrid-w4a16-bf16", + "hybrid-w4a16-zp", +] @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size"], - x_vals=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], + x_vals=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 3968, 4096], x_log=False, line_arg="provider", line_vals=PROVIDERS, line_names=PROVIDERS, ylabel="TFLOP/s (larger is better)", - plot_name="FP16 vs Hybrid W4A16", + plot_name="fp16_bf16_vs_hybrid_w4a16", args={}, ) ) def benchmark(batch_size, provider, N, K, group_size, weights): M = batch_size device = "cuda" - dtype = torch.float16 + dtype = torch.bfloat16 if provider.endswith("bf16") else torch.float16 a = torch.randn((M, K), device=device, dtype=dtype) quantiles = [0.5, 0.2, 0.8] - if provider == "torch-fp16": - w_fp16 = weights["w_fp16"] + if provider in ("torch-fp16", "torch-bf16"): + w = weights["w_fp16" if dtype == 
torch.float16 else "w_bf16"] ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: torch.nn.functional.linear(a, w_fp16), + lambda: torch.nn.functional.linear(a, w), quantiles=quantiles, ) - elif provider in ("hybrid-w4a16", "hybrid-w4a16-zp"): + elif provider in ("hybrid-w4a16", "hybrid-w4a16-bf16", "hybrid-w4a16-zp"): from vllm.model_executor.kernels.linear.mixed_precision.hybrid_w4a16 import ( _hybrid_w4a16_apply_impl, ) @@ -107,12 +124,13 @@ def benchmark(batch_size, provider, N, K, group_size, weights): w = weights cu_count = num_compute_units() use_zp = provider == "hybrid-w4a16-zp" + scales_key = "w_s_skinny_bf16" if dtype == torch.bfloat16 else "w_s_skinny" def run(): return _hybrid_w4a16_apply_impl( a, w["w_q_skinny"], - w["w_s_skinny"], + w[scales_key], w["w_q_skinny_i32"], w["w_zp"] if use_zp else None, None, # bias @@ -161,14 +179,22 @@ def prepare_shapes(args): print(f"{model}, N={N} K={K}, group_size={group_size}") print(f"{'=' * 70}") - w_q_skinny, w_s_skinny, w_fp16, w_q_skinny_i32, w_zp = prepare_hybrid_weights( - K, N, group_size - ) + ( + w_q_skinny, + w_s_skinny, + w_s_skinny_bf16, + w_fp16, + w_bf16, + w_q_skinny_i32, + w_zp, + ) = prepare_hybrid_weights(K, N, group_size) weights = { "w_q_skinny": w_q_skinny, "w_s_skinny": w_s_skinny, + "w_s_skinny_bf16": w_s_skinny_bf16, "w_fp16": w_fp16, + "w_bf16": w_bf16, "w_q_skinny_i32": w_q_skinny_i32, "w_zp": w_zp, } diff --git a/csrc/rocm/ck_w4a16.cu b/csrc/rocm/ck_w4a16.cu new file mode 100644 index 000000000000..be4715d4bca0 --- /dev/null +++ b/csrc/rocm/ck_w4a16.cu @@ -0,0 +1,230 @@ +// AIESW-32176: CK WMMA W4A16 b_scale GEMM wrapper. +// Tuned for gfx1151 (Strix Halo) at the Qwen3-4B gate_up_proj prefill shape +// (M=3968, N=19456, K=2560, group=128). Out of scope for any other shape; +// the Python dispatcher is responsible for restricting calls to the target. 
+ +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp" + +namespace { + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::half_t; +using BDataType = ck::pk_i4_t; +using BScaleDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +static constexpr auto GemmDefault = + ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; +static constexpr ck::index_t Scale_Block_N = 1; +static constexpr ck::index_t Scale_Block_K = 128; + +// EXP1_FINAL config from Phase 1 sweep (30.0 TFLOPS verified at the target +// shape). See AIInfo memory project_aiesw_32176_phase1_2_results for the full +// sweep table. 
+static constexpr ck::index_t KPerBlock = 32; + +// clang-format off +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGemm_BScale_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, Scale_Block_N, Scale_Block_K, + 128, 128, + KPerBlock, 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, + CDataType, CDataType, PermuteA, PermuteB>; +// clang-format on + +} // namespace + +// in_a: [M, K] fp16, row-major; StrideA is passed as K, so in_a must be +// contiguous (NOTE(review): no is_contiguous() check is done on in_a). +// in_b: [K0, N, K1/2] int8 in CK pk_i4_v3 b_scale layout +// (K0 = K/KPerBlock, K1 = KPerBlock). +// in_s: [N, K/G] fp16, contiguous row-major (vLLM HybridW4A16 native +// layout). CK calls this `b1_k_n`, shape [K/G, N] with stride (K/G, 1) — +// a view over the [N, K/G] row-major storage, so no transpose is needed. +// Returns out [M, N] fp16, freshly allocated. 
+torch::Tensor ck_w4a16_b_scale_gemm(const at::Tensor& in_a, + const at::Tensor& in_b, + const at::Tensor& in_s, + int64_t group_size) { + TORCH_CHECK(in_a.is_cuda() && in_b.is_cuda() && in_s.is_cuda(), + "All inputs must be on GPU"); + TORCH_CHECK(in_a.dtype() == at::kHalf, "in_a must be fp16"); + TORCH_CHECK(in_s.dtype() == at::kHalf, "in_s must be fp16"); + TORCH_CHECK(in_a.dim() == 2, "in_a must be 2-D [M, K]"); + TORCH_CHECK(in_b.dim() == 3, + "in_b must be 3-D [K0, N, K1/2] (CK pk_i4 layout)"); + TORCH_CHECK(in_s.dim() == 2, + "in_s must be 2-D [N, K/G] row-major (vLLM HybridW4A16 native " + "scale layout)"); + TORCH_CHECK(group_size == Scale_Block_K, + "group_size must equal CK Scale_Block_K (", Scale_Block_K, ")"); + + const int64_t M = in_a.size(0); + const int64_t K = in_a.size(1); + // CK packs 2 nibbles per int8 in the inner K1/2 dim, plus K0 = K/KPerBlock. + const int64_t K0 = in_b.size(0); + const int64_t N = in_b.size(1); + const int64_t K1_half = in_b.size(2); + TORCH_CHECK(K0 * KPerBlock == K, "K0 * KPerBlock != K (", K0, "*", KPerBlock, + "!=", K, ")"); + TORCH_CHECK(K1_half * 2 == KPerBlock, "in_b last dim must be KPerBlock/2 (", + K1_half, "*2 !=", KPerBlock, ")"); + TORCH_CHECK(in_s.size(0) == N && in_s.size(1) * group_size == K, + "in_s shape must be [N, K/G]; got [", in_s.size(0), ",", + in_s.size(1), "] for K=", K, " N=", N, " G=", group_size); + TORCH_CHECK(in_s.is_contiguous(), + "in_s must be contiguous row-major [N, K/G]"); + + auto out = torch::empty({M, N}, in_a.options()); + + const at::cuda::OptionalCUDAGuard guard(device_of(in_a)); + + // Logical strides per the CK device-op signature: + // ALayout=Row -> StrideA = K, BLayout=Col -> StrideB = K, CLayout=Row -> + // StrideC = N. + // The b_scale device-op consumes the permuted in_b buffer directly + // (PermuteB=true bakes the K-block tiling into how the threads read; the + // logical stride is unchanged). 
+ const ck::index_t StrideA = static_cast(K); + const ck::index_t StrideB = static_cast(K); + const ck::index_t StrideC = static_cast(N); + const ck::index_t Scale_Stride_BN = static_cast(K / group_size); + const ck::index_t KBatch = 1; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument( + reinterpret_cast(in_a.data_ptr()), + reinterpret_cast(in_b.data_ptr()), + reinterpret_cast(out.data_ptr()), static_cast(M), + static_cast(N), static_cast(K), StrideA, + StrideB, StrideC, Scale_Stride_BN, + reinterpret_cast(in_s.data_ptr()), KBatch, + PassThrough{}, PassThrough{}, PassThrough{}); + + TORCH_CHECK(gemm.IsSupportedArgument(argument), + "CK W4A16 b_scale device op rejected the argument; ", + "shape (M=", M, ", N=", N, ", K=", K, ", G=", group_size, + ") not supported by this build"); + + StreamConfig stream; + stream.stream_id_ = at::cuda::getCurrentCUDAStream(); + invoker.Run(argument, stream); + + return out; +} + +// AIESW-32176: asymmetric (AWQ) variant. +// Same kernel as ck_w4a16_b_scale_gemm except the dequant uses +// (nibble - zp) * scale = (nibble - 8) * scale - (zp - 8) * scale +// The caller precomputes scaled_zp[N, K/G] = (zp - 8) * scale once at weight +// load and passes it as in_scaled_zp. The CK kernel subtracts scaled_zp from +// each dequanted half2 inline (one extra fp16 vector subtract per pack). +// in_scaled_zp shape and stride match in_s exactly. 
+torch::Tensor ck_w4a16_b_scale_zp_gemm(const at::Tensor& in_a, + const at::Tensor& in_b, + const at::Tensor& in_s, + const at::Tensor& in_scaled_zp, + int64_t group_size) { + TORCH_CHECK(in_a.is_cuda() && in_b.is_cuda() && in_s.is_cuda() && + in_scaled_zp.is_cuda(), + "All inputs must be on GPU"); + TORCH_CHECK(in_a.dtype() == at::kHalf, "in_a must be fp16"); + TORCH_CHECK(in_s.dtype() == at::kHalf, "in_s must be fp16"); + TORCH_CHECK(in_scaled_zp.dtype() == at::kHalf, "in_scaled_zp must be fp16"); + TORCH_CHECK(in_a.dim() == 2, "in_a must be 2-D [M, K]"); + TORCH_CHECK(in_b.dim() == 3, + "in_b must be 3-D [K0, N, K1/2] (CK pk_i4 layout)"); + TORCH_CHECK(in_s.dim() == 2, + "in_s must be 2-D [N, K/G] row-major (vLLM HybridW4A16 native " + "scale layout)"); + TORCH_CHECK(in_scaled_zp.sizes() == in_s.sizes(), + "in_scaled_zp must have the same shape as in_s [N, K/G]"); + TORCH_CHECK(in_scaled_zp.is_contiguous(), + "in_scaled_zp must be contiguous row-major [N, K/G]"); + TORCH_CHECK(group_size == Scale_Block_K, + "group_size must equal CK Scale_Block_K (", Scale_Block_K, ")"); + + const int64_t M = in_a.size(0); + const int64_t K = in_a.size(1); + const int64_t K0 = in_b.size(0); + const int64_t N = in_b.size(1); + const int64_t K1_half = in_b.size(2); + TORCH_CHECK(K0 * KPerBlock == K, "K0 * KPerBlock != K (", K0, "*", KPerBlock, + "!=", K, ")"); + TORCH_CHECK(K1_half * 2 == KPerBlock, "in_b last dim must be KPerBlock/2 (", + K1_half, "*2 !=", KPerBlock, ")"); + TORCH_CHECK(in_s.size(0) == N && in_s.size(1) * group_size == K, + "in_s shape must be [N, K/G]; got [", in_s.size(0), ",", + in_s.size(1), "] for K=", K, " N=", N, " G=", group_size); + TORCH_CHECK(in_s.is_contiguous(), + "in_s must be contiguous row-major [N, K/G]"); + + auto out = torch::empty({M, N}, in_a.options()); + + const at::cuda::OptionalCUDAGuard guard(device_of(in_a)); + + const ck::index_t StrideA = static_cast(K); + const ck::index_t StrideB = static_cast(K); + const ck::index_t StrideC = 
static_cast(N); + const ck::index_t Scale_Stride_BN = static_cast(K / group_size); + const ck::index_t KBatch = 1; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument( + reinterpret_cast(in_a.data_ptr()), + reinterpret_cast(in_b.data_ptr()), + reinterpret_cast(out.data_ptr()), static_cast(M), + static_cast(N), static_cast(K), StrideA, + StrideB, StrideC, Scale_Stride_BN, + reinterpret_cast(in_s.data_ptr()), KBatch, + PassThrough{}, PassThrough{}, PassThrough{}, + reinterpret_cast(in_scaled_zp.data_ptr())); + + TORCH_CHECK(gemm.IsSupportedArgument(argument), + "CK W4A16 b_scale_zp device op rejected the argument; ", + "shape (M=", M, ", N=", N, ", K=", K, ", G=", group_size, + ") not supported by this build"); + + StreamConfig stream; + stream.stream_id_ = at::cuda::getCurrentCUDAStream(); + invoker.Run(argument, stream); + + return out; +} diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 73783180a607..da2ce1d4ff45 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -26,6 +26,21 @@ torch::Tensor wvSplitK_int4_g(const at::Tensor& in_a, const at::Tensor& in_b, const std::optional& in_bias, const int64_t CuCount, const int64_t group_size); +// AIESW-32176: CK WMMA W4A16 b_scale GEMM. gfx1151-tuned for the Qwen3-4B +// gate_up_proj prefill shape (M=3968, N=19456, K=2560, group=128). +torch::Tensor ck_w4a16_b_scale_gemm(const at::Tensor& in_a, + const at::Tensor& in_b, + const at::Tensor& in_s, int64_t group_size); + +// AIESW-32176: asymmetric (AWQ-style) variant with per-group zero points. +// in_scaled_zp = (zp - 8) * scale precomputed at weight load (same shape and +// stride as in_s). 
+torch::Tensor ck_w4a16_b_scale_zp_gemm(const at::Tensor& in_a, + const at::Tensor& in_b, + const at::Tensor& in_s, + const at::Tensor& in_scaled_zp, + int64_t group_size); + void fused_moe_wvSplitK_int4_gemm(torch::Tensor a, torch::Tensor w, torch::Tensor scales, torch::Tensor c, torch::Tensor expert_ids, diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 6b6aebd262a3..7a68f290bdbb 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -56,6 +56,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { "int group_size) -> Tensor"); rocm_ops.impl("wvSplitK_int4_g", torch::kCUDA, &wvSplitK_int4_g); +#ifdef VLLM_HAVE_CK_W4A16 + // AIESW-32176: CK WMMA W4A16 b_scale GEMM. + // in_b is in CK pk_i4 layout [K0, N, K1/2] int8; in_s is [N, K/G] fp16. + rocm_ops.def( + "ck_w4a16_b_scale_gemm(Tensor in_a, Tensor in_b, Tensor in_s, " + "int group_size) -> Tensor"); + rocm_ops.impl("ck_w4a16_b_scale_gemm", torch::kCUDA, &ck_w4a16_b_scale_gemm); + + // AIESW-32176: asymmetric (AWQ) variant. in_scaled_zp = (zp-8)*scale [N,K/G]. + rocm_ops.def( + "ck_w4a16_b_scale_zp_gemm(Tensor in_a, Tensor in_b, Tensor in_s, " + "Tensor in_scaled_zp, int group_size) -> Tensor"); + rocm_ops.impl("ck_w4a16_b_scale_zp_gemm", torch::kCUDA, + &ck_w4a16_b_scale_zp_gemm); +#endif + // Fused MoE wrapper around wvSplitK_int4_g: iterates expert runs in C++ rocm_ops.def( "fused_moe_wvSplitK_int4_gemm(Tensor a, Tensor w, Tensor scales, " diff --git a/vllm/model_executor/kernels/linear/mixed_precision/hybrid_w4a16.py b/vllm/model_executor/kernels/linear/mixed_precision/hybrid_w4a16.py index 66434f66e02e..74274ae0cf63 100644 --- a/vllm/model_executor/kernels/linear/mixed_precision/hybrid_w4a16.py +++ b/vllm/model_executor/kernels/linear/mixed_precision/hybrid_w4a16.py @@ -38,6 +38,96 @@ MAX_SKINNY_BATCH_SIZE = 5 LDS_CAPACITY_ELEMENTS = 64 * 1024 // 2 # 32768 fp16 elements +# AIESW-32176: shapes routed to the CK WMMA b_scale GEMM op (gfx1151 only). 
+# Each entry is keyed by (N, K, group_size, dtype) and maps to (min_M, KPerBlock). +# Dispatch fires when M >= min_M for this layer — the kernel handles any M >= 1, +# but min_M sets a lower bound below which fixed launch overhead (~0.4 ms) dominates +# and Triton is comparable. Above the threshold, CK holds 22-31 TFLOPS uniformly +# across the M dimension (measured M=256-16384 — see AIInfo memory +# project_aiesw_32176_phase5c_shapes). This handles arbitrary chunked-prefill +# chunk sizes including the M=1920 second-chunk case for prompt=3968+chunk=2048. +# All four Qwen3-4B prefill linear columns are wired; the same kernel binary +# handles all shapes (M/N/K are runtime args; only KPerBlock is templated). +# Each wired layer costs an extra weight copy (~0.92 GB total for the four +# Qwen3-4B columns on a 36-layer model). +_CK_W4A16_TARGET_SHAPES: dict[tuple, tuple[int, int]] = { + (19456, 2560, 128, torch.float16): (256, 32), # gate_up_proj + (6144, 2560, 128, torch.float16): (256, 32), # qkv_proj (q=4096+k=1024+v=1024) + (2560, 4096, 128, torch.float16): (256, 32), # o_proj + (2560, 9728, 128, torch.float16): (256, 32), # down_proj +} + + +def _is_gfx1151() -> bool: + """True iff current device is gfx1151 (Strix Halo, compute_cap (11, 5)).""" + if not on_gfx1x(): + return False + try: + return torch.cuda.get_device_capability(0) == (11, 5) + except Exception: + return False + + +def _lookup_ck_target( + N: int, K: int, group_size: int, dtype: torch.dtype +) -> tuple[int, int] | None: + """Find a registered CK target for this layer's (N, K, group, dtype). + Returns (min_M, KPerBlock) if any, else None. Called once per layer at + load time (Python ints, not SymInts) — so dict lookup is safe.""" + if not _is_gfx1151(): + return None + return _CK_W4A16_TARGET_SHAPES.get((N, K, group_size, dtype)) + + +def _has_ck_w4a16_op() -> bool: + """True iff vllm was built with VLLM_CK_INCLUDE_DIR (i.e. the CK op is + registered in _rocm_C). 
Imported lazily so non-ROCm builds don't pay.""" + try: + return hasattr(torch.ops._rocm_C, "ck_w4a16_b_scale_gemm") + except (AttributeError, RuntimeError): + return False + + +def _has_ck_w4a16_zp_op() -> bool: + """True iff the asymmetric (with-zero-points) CK op is registered. + AIESW-32176 Phase 5b — separate from _has_ck_w4a16_op so older builds + of _rocm_C that only have the symmetric op still work.""" + try: + return hasattr(torch.ops._rocm_C, "ck_w4a16_b_scale_zp_gemm") + except (AttributeError, RuntimeError): + return False + + +def _ck_disabled() -> bool: + """Set VLLM_DISABLE_CK_W4A16=1 to bypass the CK dispatch and stay on Triton. + Used for A/B benchmarking the CK kernel against the Triton baseline without + rebuilding.""" + import os + + return os.environ.get("VLLM_DISABLE_CK_W4A16", "0").strip().lower() in ( + "1", + "true", + "yes", + ) + + +def _repack_vllm_to_ck_b_scale( + w_q_skinny_i32: torch.Tensor, # [N, K//8] int32 + KPerBlock: int, +) -> torch.Tensor: + """vLLM ExLlama [N, K//8] int32 -> CK pk_i4 [K0, N, K1//2] int8. Pure + reshape + axis swap (nibble shuffle is byte-identical). Scales pass through + unchanged — CK's b1_k_n is a stride-quirk view over [N, K/G] row-major bytes.""" + N, K_div_8 = w_q_skinny_i32.shape + K = K_div_8 * 8 + K0 = K // KPerBlock + return ( + w_q_skinny_i32.reshape(N, K0, KPerBlock // 8) + .permute(1, 0, 2) + .contiguous() + .view(torch.int8) + ) + # --------------------------------------------------------------------------- # Triton kernel for the prefill path (reads skinny-format weights [N, K//8]) @@ -332,10 +422,13 @@ def _hybrid_w4a16_apply_impl( bias: torch.Tensor | None, cu_count: int, group_size: int, + w_q_ck: torch.Tensor | None = None, + ck_min_m: int = 0, + w_scaled_zp_ck: torch.Tensor | None = None, ) -> torch.Tensor: - """Dispatch between skinny GEMM and Triton based on batch size M. + """Dispatch between skinny GEMM, CK W4A16 b_scale (sym/asym), and Triton. 
- Both paths read from the same skinny-format weights: + Both skinny and Triton paths read from the same vLLM skinny-format weights: w_q: [N, K//8] int8 (ExLlama shuffle, for skinny kernel) w_q_i32: [N, K//8] int32 (same data viewed as int32, for triton) w_s: [N, K//G] fp16/bf16 (skinny-layout scales) @@ -343,6 +436,13 @@ def _hybrid_w4a16_apply_impl( or None for symmetric. Both HIP skinny and Triton use this single format: dequant = (nibble - zp_raw) * scale. + AIESW-32176: w_q_ck is the same weights repacked into CK pk_i4 layout + [K0, N, K1//2] int8. When non-None and M >= ck_min_m the CK + GEMM kernel is used instead of Triton: + - symmetric (w_zp is None): ck_w4a16_b_scale_gemm + - asymmetric (w_zp set, w_scaled_zp_ck = (zp-8)*scale precomputed at + load time): ck_w4a16_b_scale_zp_gemm + Registered as a custom op so torch.compile treats it as opaque. """ import vllm._custom_ops as ops @@ -362,6 +462,41 @@ def _hybrid_w4a16_apply_impl( with ctx: return ops.wvSplitK_int4_g(w_q, x_2d, w_s, cu_count, group_size, w_zp, bias) + # AIESW-32176: CK W4A16 b_scale path (sym or asym). Conditional is inside + # the custom op so it's opaque to dynamo and the runtime M check is a plain + # Python int compare against the per-layer min-M threshold. + if w_q_ck is not None and ck_min_m > 0 and ck_min_m <= M: + ctx = ( + nullcontext() + if torch.compiler.is_compiling() + else torch.profiler.record_function(f"ck_w4a16 {M}x{N}x{K}") + ) + with ctx: + if w_zp is None: + output = torch.ops._rocm_C.ck_w4a16_b_scale_gemm( + x_2d, + w_q_ck, + w_s, + group_size, + ) + elif w_scaled_zp_ck is not None: + output = torch.ops._rocm_C.ck_w4a16_b_scale_zp_gemm( + x_2d, + w_q_ck, + w_s, + w_scaled_zp_ck, + group_size, + ) + else: + # Asymmetric layer with zp present but scaled_zp not precomputed + # — fall through to Triton (shouldn't happen if load-time path + # is wired, but defensive). 
+ output = None + if output is not None: + if bias is not None: + output.add_(bias) + return output + ctx = ( nullcontext() if torch.compiler.is_compiling() @@ -389,6 +524,9 @@ def _hybrid_w4a16_apply_fake( bias: torch.Tensor | None, cu_count: int, group_size: int, + w_q_ck: torch.Tensor | None = None, + ck_min_m: int = 0, + w_scaled_zp_ck: torch.Tensor | None = None, ) -> torch.Tensor: M = x_2d.size(0) N = w_q.size(0) @@ -519,6 +657,47 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: torch.nn.Parameter(w_q_skinny_i32, requires_grad=False), ) + # AIESW-32176: precompute CK b_scale layout if this layer's (N, K, group, + # dtype) matches a registered CK target shape on gfx1151. Done once at + # load time with regular Python ints (not SymInts), so the lookup is safe + # outside the dynamo trace. Asymmetric (zp) layers also need the zp CK op. + if _has_ck_w4a16_op() and not _ck_disabled(): + N = w_q_skinny_i32.shape[0] + K = w_q_skinny_i32.shape[1] * 8 + target = _lookup_ck_target(N, K, c.group_size, c.act_type) + # Symmetric: just need _hybrid_w_q_ck. Asymmetric: also need + # _hybrid_w_scaled_zp_ck = (zp - 8) * scale [N, K/G] precomputed + # once, AND the asymmetric CK op must be present in this build. + asym_ok = (not c.zero_points) or _has_ck_w4a16_zp_op() + if target is not None and asym_ok: + min_M, kperblock = target + w_q_ck = _repack_vllm_to_ck_b_scale(w_q_skinny_i32, kperblock) + layer.register_parameter( + "_hybrid_w_q_ck", + torch.nn.Parameter(w_q_ck, requires_grad=False), + ) + # Plain Python int — safe to compare against SymInt M at apply. + layer._hybrid_ck_min_M = int(min_M) + + if c.zero_points: + # AIESW-32176: precompute scaled_zp = (zp - 8) * scale. + # zp is stored on the layer post-process as raw fp16 in + # act dtype (see the c.zero_points block above). scale + # here is w_s_skinny [N, K/G]. Result shape matches. 
+ w_zp_raw = getattr(layer, self.w_zp_name).data + scaled_zp = ( + ( + (w_zp_raw.to(torch.float32) - 8.0) + * w_s_skinny.to(torch.float32) + ) + .to(c.act_type) + .contiguous() + ) + layer.register_parameter( + "_hybrid_w_scaled_zp_ck", + torch.nn.Parameter(scaled_zp, requires_grad=False), + ) + def apply_weights( self, layer: torch.nn.Module, @@ -535,6 +714,13 @@ def apply_weights( N = w_q.shape[0] out_shape = x.shape[:-1] + (N,) + # AIESW-32176: pass CK-format weights + min M (and scaled_zp for + # asymmetric) to the custom op if registered for this layer. Dispatch + # decision happens INSIDE the custom op (opaque to dynamo). + w_q_ck = getattr(layer, "_hybrid_w_q_ck", None) + ck_min_m = getattr(layer, "_hybrid_ck_min_M", 0) + w_scaled_zp_ck = getattr(layer, "_hybrid_w_scaled_zp_ck", None) + cu_count = num_compute_units() output = torch.ops.vllm.hybrid_w4a16_apply( x_2d, @@ -545,5 +731,8 @@ def apply_weights( bias, cu_count, c.group_size, + w_q_ck, + ck_min_m, + w_scaled_zp_ck, ) return output.reshape(out_shape)