Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,32 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
"csrc/rocm/skinny_gemms_w8a8/instantiate_n5.cu"
"csrc/rocm/attention.cu")

# AIESW-32176: optionally build the CK W4A16 b_scale GEMM source.
#
# The source is added only when BOTH of these are set:
#   -DVLLM_CK_INCLUDE_DIR=/path/to/composablekernel/include (source headers, has ck/ck.hpp)
#   -DVLLM_CK_BUILD_INCLUDE_DIR=/path/to/ck-build/include (generated headers, has ck/config.h)
# CK's config.h is generated by CK's own CMake configure (per-build flags),
# so we need both the source tree and the build tree on the include path.
# The dispatcher in hybrid_w4a16.py routes only the tuned gfx1151 Qwen3-4B gate_up_proj
# prefill shape (M=3968, N=19456, K=2560, group=128) to this op; all other shapes
# stay on the existing Triton path.
#
# Setting exactly one of the two variables is treated as a configuration error.
if(VLLM_CK_INCLUDE_DIR AND VLLM_CK_BUILD_INCLUDE_DIR)
  # if(EXISTS) is only well-defined for full paths, whereas
  # target_include_directories() interprets a relative path against
  # CMAKE_CURRENT_SOURCE_DIR. Resolve both directories the same way before
  # probing them, so the check inspects the directory the compiler will search.
  get_filename_component(_VLLM_CK_SRC_INC_ABS "${VLLM_CK_INCLUDE_DIR}"
    ABSOLUTE BASE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
  get_filename_component(_VLLM_CK_BUILD_INC_ABS "${VLLM_CK_BUILD_INCLUDE_DIR}"
    ABSOLUTE BASE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")

  if(NOT EXISTS "${_VLLM_CK_SRC_INC_ABS}/ck/ck.hpp")
    message(FATAL_ERROR
      "VLLM_CK_INCLUDE_DIR=${VLLM_CK_INCLUDE_DIR} does not contain ck/ck.hpp "
      "(resolved to ${_VLLM_CK_SRC_INC_ABS})")
  endif()
  if(NOT EXISTS "${_VLLM_CK_BUILD_INC_ABS}/ck/config.h")
    message(FATAL_ERROR
      "VLLM_CK_BUILD_INCLUDE_DIR=${VLLM_CK_BUILD_INCLUDE_DIR} does not contain ck/config.h "
      "(this is generated by CK's CMake configure step; point at /<ck-build>/include) "
      "(resolved to ${_VLLM_CK_BUILD_INC_ABS})")
  endif()

  list(APPEND VLLM_ROCM_EXT_SRC "csrc/rocm/ck_w4a16.cu")
  message(STATUS "AIESW-32176: building csrc/rocm/ck_w4a16.cu against CK source ${VLLM_CK_INCLUDE_DIR} + build ${VLLM_CK_BUILD_INCLUDE_DIR}")

  # The resolved paths are only needed for the checks above.
  unset(_VLLM_CK_SRC_INC_ABS)
  unset(_VLLM_CK_BUILD_INC_ABS)
elseif(VLLM_CK_INCLUDE_DIR OR VLLM_CK_BUILD_INCLUDE_DIR)
  message(FATAL_ERROR "AIESW-32176: must set BOTH VLLM_CK_INCLUDE_DIR and VLLM_CK_BUILD_INCLUDE_DIR (or neither)")
else()
  message(STATUS "AIESW-32176: VLLM_CK_*_INCLUDE_DIR not set; csrc/rocm/ck_w4a16.cu skipped (CK W4A16 dispatch will fall through to Triton)")
endif()

set(VLLM_ROCM_FLAGS ${VLLM_GPU_FLAGS})

if(VLLM_SKINNY_GEMM_SWEEP OR (DEFINED ENV{VLLM_SKINNY_GEMM_SWEEP} AND NOT $ENV{VLLM_SKINNY_GEMM_SWEEP} STREQUAL "0"))
Expand All @@ -1253,6 +1279,11 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
# Define VLLM_SKINNY_GEMM_SWEEP=1 for _rocm_C only when _ROCM_SWEEP_ENABLED is
# truthy; the generator expression expands to nothing otherwise.
target_compile_definitions(_rocm_C PRIVATE
  $<$<BOOL:${_ROCM_SWEEP_ENABLED}>:VLLM_SKINNY_GEMM_SWEEP=1>)

# When both CK include directories are configured, define VLLM_HAVE_CK_W4A16=1
# for _rocm_C and put the CK source-tree and build-tree headers on its private
# include path.
if(VLLM_CK_INCLUDE_DIR AND VLLM_CK_BUILD_INCLUDE_DIR)
  target_compile_definitions(_rocm_C PRIVATE VLLM_HAVE_CK_W4A16=1)
  target_include_directories(_rocm_C
    PRIVATE
      "${VLLM_CK_INCLUDE_DIR}"
      "${VLLM_CK_BUILD_INCLUDE_DIR}")
endif()
endif()

# For CUDA and HIP builds also build the triton_kernels external package.
Expand Down
58 changes: 42 additions & 16 deletions benchmarks/kernels/benchmark_hybrid_w4a16_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
def prepare_hybrid_weights(K, N, group_size, device="cuda"):
"""Create random weights for benchmarking.

Returns (w_q_skinny, w_s_skinny, w_fp16, w_q_skinny_i32).
Returns (w_q_skinny, w_s_skinny, w_s_skinny_bf16, w_fp16, w_bf16,
w_q_skinny_i32, w_zp).
"""
num_groups = K // group_size

Expand All @@ -53,52 +54,68 @@ def prepare_hybrid_weights(K, N, group_size, device="cuda"):
)
w_q_skinny = w_q_skinny_i32.view(torch.int8).contiguous()
w_s_skinny = torch.randn(N, num_groups, dtype=torch.float16, device=device) * 0.01
w_s_skinny_bf16 = w_s_skinny.to(torch.bfloat16)

# Raw per-group zero-points for asymmetric benchmarks
w_zp = torch.randint(0, 16, (N, num_groups), dtype=torch.int32, device=device).to(
torch.float16
)

# FP16 baseline for F.linear
# FP16 / BF16 baselines for F.linear
w_fp16 = torch.randn(N, K, dtype=torch.float16, device=device) * 0.01

return w_q_skinny, w_s_skinny, w_fp16, w_q_skinny_i32, w_zp
w_bf16 = w_fp16.to(torch.bfloat16)

return (
w_q_skinny,
w_s_skinny,
w_s_skinny_bf16,
w_fp16,
w_bf16,
w_q_skinny_i32,
w_zp,
)


# ---------------------------------------------------------------------------
# Benchmark
# ---------------------------------------------------------------------------
PROVIDERS = ["torch-fp16", "hybrid-w4a16", "hybrid-w4a16-zp"]
PROVIDERS = [
"torch-fp16",
"torch-bf16",
"hybrid-w4a16",
"hybrid-w4a16-bf16",
"hybrid-w4a16-zp",
]


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
x_vals=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 3968, 4096],
x_log=False,
line_arg="provider",
line_vals=PROVIDERS,
line_names=PROVIDERS,
ylabel="TFLOP/s (larger is better)",
plot_name="FP16 vs Hybrid W4A16",
plot_name="fp16_bf16_vs_hybrid_w4a16",
args={},
)
)
def benchmark(batch_size, provider, N, K, group_size, weights):
M = batch_size
device = "cuda"
dtype = torch.float16
dtype = torch.bfloat16 if provider.endswith("bf16") else torch.float16
a = torch.randn((M, K), device=device, dtype=dtype)

quantiles = [0.5, 0.2, 0.8]

if provider == "torch-fp16":
w_fp16 = weights["w_fp16"]
if provider in ("torch-fp16", "torch-bf16"):
w = weights["w_fp16" if dtype == torch.float16 else "w_bf16"]
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: torch.nn.functional.linear(a, w_fp16),
lambda: torch.nn.functional.linear(a, w),
quantiles=quantiles,
)
elif provider in ("hybrid-w4a16", "hybrid-w4a16-zp"):
elif provider in ("hybrid-w4a16", "hybrid-w4a16-bf16", "hybrid-w4a16-zp"):
from vllm.model_executor.kernels.linear.mixed_precision.hybrid_w4a16 import (
_hybrid_w4a16_apply_impl,
)
Expand All @@ -107,12 +124,13 @@ def benchmark(batch_size, provider, N, K, group_size, weights):
w = weights
cu_count = num_compute_units()
use_zp = provider == "hybrid-w4a16-zp"
scales_key = "w_s_skinny_bf16" if dtype == torch.bfloat16 else "w_s_skinny"

def run():
return _hybrid_w4a16_apply_impl(
a,
w["w_q_skinny"],
w["w_s_skinny"],
w[scales_key],
w["w_q_skinny_i32"],
w["w_zp"] if use_zp else None,
None, # bias
Expand Down Expand Up @@ -161,14 +179,22 @@ def prepare_shapes(args):
print(f"{model}, N={N} K={K}, group_size={group_size}")
print(f"{'=' * 70}")

w_q_skinny, w_s_skinny, w_fp16, w_q_skinny_i32, w_zp = prepare_hybrid_weights(
K, N, group_size
)
(
w_q_skinny,
w_s_skinny,
w_s_skinny_bf16,
w_fp16,
w_bf16,
w_q_skinny_i32,
w_zp,
) = prepare_hybrid_weights(K, N, group_size)

weights = {
"w_q_skinny": w_q_skinny,
"w_s_skinny": w_s_skinny,
"w_s_skinny_bf16": w_s_skinny_bf16,
"w_fp16": w_fp16,
"w_bf16": w_bf16,
"w_q_skinny_i32": w_q_skinny_i32,
"w_zp": w_zp,
}
Expand Down
Loading
Loading