Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv
Original file line number Diff line number Diff line change
Expand Up @@ -549,3 +549,43 @@ gfx950,256,4096,57344,8192,torch.float8_e4m3fn,flydsl,979,0,1486.3101,flydsl_bpr
gfx950,256,8192,57344,8192,torch.float8_e4m3fn,flydsl,825,0,2916.8532,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2638.66,506.16,0.0
gfx950,256,16384,57344,8192,torch.float8_e4m3fn,flydsl,825,0,5899.997,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2x0_default,2609.01,420.85,0.0
gfx950,256,32768,57344,8192,torch.float8_e4m3fn,ck,33,0,12218.2137,a8w8_bpreshuffle_256x256x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8x8x1_2x1_intrawave_v3,2519.71,368.0,0.0
gfx1250,256,16,2112,7168,torch.float8_e4m3fn,flydsl,25,0,30.8522,flydsl_bpreshuffle_wmma_t32x32x256_mw2_nw2_nb3_sk1_cm1_cn1,15.7,496.6,0.011
gfx1250,256,32,2112,7168,torch.float8_e4m3fn,flydsl,286,0,28.9229,flydsl_bpreshuffle_wmma_t32x64x256_mw2_nw2_nb4_sk1_cm1_cn1,33.5,536.02,0.0105
gfx1250,256,64,2112,7168,torch.float8_e4m3fn,flydsl,269,0,31.7564,flydsl_bpreshuffle_wmma_t32x32x256_mw1_nw2_nb4_sk1_cm1_cn1,61.02,499.68,0.0098
gfx1250,256,128,2112,7168,torch.float8_e4m3fn,flydsl,9,0,8.326,flydsl_bpreshuffle_wmma_t32x64x512_mw2_nw2_nb4_sk1_cm1_cn1,465.47,1993.39,0.0096
gfx1250,256,256,2112,7168,torch.float8_e4m3fn,flydsl,19,0,25.7514,flydsl_bpreshuffle_wmma_t128x192x128_mw2_nw2_nb4_sk1_cm1_cn1,301.0,701.13,0.0098
gfx1250,256,512,2112,7168,torch.float8_e4m3fn,flydsl,398,0,40.0855,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,386.73,523.17,0.0098
gfx1250,256,1024,2112,7168,torch.float8_e4m3fn,flydsl,398,0,68.7538,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,450.95,389.86,0.0096
gfx1250,256,2048,2112,7168,torch.float8_e4m3fn,flydsl,398,0,126.0903,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,491.78,305.1,0.0098
gfx1250,256,4096,2112,7168,torch.float8_e4m3fn,flydsl,398,0,240.7634,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,515.1,256.69,0.0098
gfx1250,256,8192,2112,7168,torch.float8_e4m3fn,flydsl,398,0,470.1095,flydsl_bpreshuffle_wmma_t256x192x128_mw2_nw2_nb4_sk1_cm1_cn1,527.61,230.72,0.0098
gfx1250,256,16,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.8317,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,118.06,3704.86,0.0237
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tileM256 for M16?

gfx1250,256,32,7168,16384,torch.float8_e4m3fn,flydsl,500,0,31.9644,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,235.14,3704.86,0.0237
gfx1250,256,64,7168,16384,torch.float8_e4m3fn,flydsl,500,0,32.2297,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,466.41,3704.86,0.0237
gfx1250,256,128,7168,16384,torch.float8_e4m3fn,flydsl,500,0,32.7604,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,917.72,3704.86,0.0237
gfx1250,256,256,7168,16384,torch.float8_e4m3fn,flydsl,500,0,33.8217,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,1777.84,3704.86,0.0237
gfx1250,256,512,7168,16384,torch.float8_e4m3fn,flydsl,500,0,51.8551,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,2319.14,2568.1,0.0237
gfx1250,256,1024,7168,16384,torch.float8_e4m3fn,flydsl,500,0,92.0588,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm2_cn2,2612.66,1617.42,0.0237
gfx1250,256,2048,7168,16384,torch.float8_e4m3fn,flydsl,302,0,172.4662,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2789.16,1045.74,0.0237
gfx1250,256,4096,7168,16384,torch.float8_e4m3fn,flydsl,302,0,327.0773,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2941.42,743.77,0.0236
gfx1250,256,8192,7168,16384,torch.float8_e4m3fn,flydsl,302,0,640.4273,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,3004.47,576.33,0.0237
gfx1250,256,16,24576,1536,torch.float8_e4m3fn,flydsl,302,0,38.8569,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,31.09,992.35,0.0
gfx1250,256,32,24576,1536,torch.float8_e4m3fn,flydsl,302,0,39.2634,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,61.53,1002.73,0.0
gfx1250,256,64,24576,1536,torch.float8_e4m3fn,flydsl,302,0,40.0765,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,120.57,1022.86,0.0
gfx1250,256,128,24576,1536,torch.float8_e4m3fn,flydsl,302,0,41.7027,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,231.73,1060.76,0.0
gfx1250,256,256,24576,1536,torch.float8_e4m3fn,flydsl,302,0,44.9552,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,429.92,1128.34,0.0
gfx1250,256,512,24576,1536,torch.float8_e4m3fn,flydsl,302,0,51.4601,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,751.16,1237.87,0.0
gfx1250,256,1024,24576,1536,torch.float8_e4m3fn,flydsl,302,0,64.4699,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1199.16,1390.62,0.0
gfx1250,256,2048,24576,1536,torch.float8_e4m3fn,flydsl,302,0,90.4895,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1708.69,1564.36,0.0
gfx1250,256,4096,24576,1536,torch.float8_e4m3fn,flydsl,302,0,146.1362,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2116.09,1679.03,0.0
gfx1250,256,8192,24576,1536,torch.float8_e4m3fn,flydsl,302,0,257.4906,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,2401.93,1759.23,0.0
gfx1250,256,16,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.5724,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,11.53,382.93,0.0
gfx1250,256,32,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.638,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,23.02,405.05,0.0
gfx1250,256,64,32768,512,torch.float8_e4m3fn,flydsl,302,0,46.7691,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,45.92,449.11,0.0
gfx1250,256,128,32768,512,torch.float8_e4m3fn,flydsl,302,0,47.0315,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,91.32,536.48,0.0
gfx1250,256,256,32768,512,torch.float8_e4m3fn,flydsl,302,0,47.5561,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,180.63,708.33,0.0
gfx1250,256,512,32768,512,torch.float8_e4m3fn,flydsl,302,0,48.6054,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,353.46,1040.91,0.0
gfx1250,256,1024,32768,512,torch.float8_e4m3fn,flydsl,302,0,50.7041,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,677.65,1664.76,0.0
gfx1250,256,2048,32768,512,torch.float8_e4m3fn,flydsl,302,0,54.9014,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1251.69,2769.39,0.0
gfx1250,256,4096,32768,512,torch.float8_e4m3fn,flydsl,302,0,88.9453,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1545.21,3230.19,0.0
gfx1250,256,8192,32768,512,torch.float8_e4m3fn,flydsl,302,0,150.5706,flydsl_bpreshuffle_wmma_t256x256x128_mw2_nw2_nb4_sk1_cm1_cn1,1825.57,3704.86,0.0
105 changes: 55 additions & 50 deletions aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@

"""gfx1250 (WMMA) backend for the FlyDSL a8w8 bpreshuffle GEMM.

aiter.gemm_a8w8_bpreshuffle routes here when its tuned kernelName starts with
``flydsl_bpreshuffle_wmma_`` (gfx1250 has no MFMA preshuffle kernel). Runs the
vendored gemm_fp8fp4_gfx1250 WMMA kernel in ptpc scale mode: C = (A*sa) @ (B*sb)^T
with fp32 per-token sa[M] / per-channel sb[N] applied in the epilogue. N/K are not
padded (must divide the tile); M is padded to tile_m when ragged, since the kernel
reads a full tile_m rows per workgroup (M=1 would otherwise read past A/sa -> NaN).
aiter.gemm_a8w8_bpreshuffle's FlyDSL path runs here on gfx1250 (no MFMA preshuffle
kernel); the tuned kernelName (prefix ``flydsl_bpreshuffle_wmma_``) encodes the tile
config. Computes C = (A*sa) @ (B*sb)^T via the vendored gemm_fp8fp4_gfx1250 WMMA
kernel, fp32 per-token sa[M] / per-channel sb[N] applied in the epilogue.

N/K must divide the tile; M may be ragged (no host padding) — the kernel clips
loads/stores to the runtime M via hardware OOB, predicating split-k's atomic add
per-lane on row < M. A/C may be strided (lda/ldc passed at runtime, no copy) when
the inner dim is unit-stride; B is preshuffled into its own contiguous buffer.
"""

from __future__ import annotations
Expand All @@ -26,6 +29,7 @@
_WMMA_K = 128
_SUPPORTED_NUM_BUFFERS = (2, 3, 4)
_OUT_DTYPE_NAME = {torch.bfloat16: "bf16", torch.float16: "f16"}
_MAX_SPLIT_K = 4


def _lazy_import():
Expand Down Expand Up @@ -54,10 +58,6 @@ def _as_1d_fp32(scale: Tensor, length: int, name: str) -> Tensor:
return flat.contiguous()


def _to_uint8(t: Tensor) -> Tensor:
return t.contiguous().view(torch.uint8).view(-1)


def run_preshuffle_gemm_a8_gfx1250(
XQ: Tensor,
WQ: Tensor,
Expand Down Expand Up @@ -113,33 +113,36 @@ def run_preshuffle_gemm_a8_gfx1250(
cluster_m = max(1, int(cluster_m))
cluster_n = max(1, int(cluster_n))

accumulate_fp32 = split_k > 1
kernel_out_dtype = "f32" if accumulate_fp32 else out_dtype
if split_k > _MAX_SPLIT_K:
raise RuntimeError(
f"[FlyDSL gfx1250] split_k={split_k} exceeds the bf16/f16 atomic-add "
f"precision cap of {_MAX_SPLIT_K}"
)

# Pipeline depth needs >= 1 K tile per buffer (per split-k chunk).
num_k_tiles = (K // split_k) // tile_k
nb = max(2, min(int(num_buffers), num_k_tiles))
# Validate (tuned names always pass); fail loudly rather than silently clamp.
nb = int(num_buffers)
if nb not in _SUPPORTED_NUM_BUFFERS:
nb = max(b for b in _SUPPORTED_NUM_BUFFERS if b <= nb)
raise RuntimeError(
f"[FlyDSL gfx1250] num_buffers must be one of {_SUPPORTED_NUM_BUFFERS}, "
f"got {nb}"
)
if K % (split_k * tile_k) != 0:
raise RuntimeError(
f"[FlyDSL gfx1250] K={K} must be divisible by split_k*tile_k="
f"{split_k}*{tile_k}={split_k * tile_k}"
)
# Each split-k chunk must hold >= num_buffers K-tiles to fill the pipeline.
num_k_tiles = (K // split_k) // tile_k
if num_k_tiles < nb:
raise RuntimeError(
f"[FlyDSL gfx1250] {nb}-buffer pipeline needs >= {nb} K-tiles per "
f"split-k chunk, got {num_k_tiles} (K={K}, split_k={split_k}, tile_k={tile_k})"
)

sa = _as_1d_fp32(x_scale, M, "x_scale")
sb = _as_1d_fp32(w_scale, N, "w_scale")

# M padded to tile_m when ragged (kernel reads a full tile_m rows/wg).
padded_m = ((M + tile_m - 1) // tile_m) * tile_m
if padded_m == M:
a_dev = XQ.contiguous()
sa_dev = sa
else:
a_dev = torch.zeros((padded_m, K), dtype=XQ.dtype, device=XQ.device)
a_dev[:M] = XQ
sa_dev = torch.ones((padded_m,), dtype=torch.float32, device=sa.device)
sa_dev[:M] = sa

b_dev = WQ.contiguous()

exe = _compile_ptpc_gemm(
M=padded_m,
N=N,
K=K,
data_format="fp8",
Expand All @@ -152,40 +155,37 @@ def run_preshuffle_gemm_a8_gfx1250(
waves_per_eu=(None if waves_per_eu <= 0 else waves_per_eu),
cluster_m=cluster_m,
cluster_n=cluster_n,
out_dtype=kernel_out_dtype,
out_dtype=out_dtype,
split_k=split_k,
)

if accumulate_fp32:
# fp32 atomic-accumulation scratch (zeroed; narrowed into Out below).
out_buf = torch.zeros((padded_m, N), dtype=torch.float32, device=Out.device)
elif padded_m == M:
out_buf = Out.contiguous()
else:
out_buf = torch.empty((padded_m, N), dtype=Out.dtype, device=Out.device)
lda = XQ.stride(0)
ldc = Out.stride(0)
if split_k > 1:
Out.zero_() # split-k atomic-accumulates into Out

stream = _fx.Stream(torch.cuda.current_stream(device=a_dev.device))
stream = _fx.Stream(torch.cuda.current_stream(device=XQ.device))
_run_compiled(
exe,
out_buf.view(-1),
_to_uint8(a_dev),
_to_uint8(b_dev),
sa_dev.contiguous().view(-1),
sb.contiguous().view(-1),
padded_m,
Out,
XQ.view(torch.uint8),
WQ.view(torch.uint8),
sa.view(-1),
sb.view(-1),
M,
N,
lda,
ldc,
stream,
)

if out_buf.data_ptr() != Out.data_ptr():
Out.copy_(out_buf[:M])
return Out


# flydsl_bpreshuffle_wmma_t{tm}x{tn}x{tk}_nb{nb}_sk{sk}_cm{cm}_cn{cn}
# flydsl_bpreshuffle_wmma_t{tm}x{tn}x{tk}_mw{mw}_nw{nw}_nb{nb}_sk{sk}_cm{cm}_cn{cn}
_KERNEL_NAME_RE = re.compile(
r"^flydsl_bpreshuffle_wmma_"
r"t(?P<tile_m>\d+)x(?P<tile_n>\d+)x(?P<tile_k>\d+)_"
r"mw(?P<m_warp>\d+)_nw(?P<n_warp>\d+)_"
r"nb(?P<num_buffers>\d+)_sk(?P<split_k>\d+)_"
r"cm(?P<cluster_m>\d+)_cn(?P<cluster_n>\d+)$"
)
Expand All @@ -200,10 +200,13 @@ def wmma_kernel_name(
split_k: int,
cluster_m: int,
cluster_n: int,
m_warp: int,
n_warp: int,
) -> str:
return (
f"flydsl_bpreshuffle_wmma_t{tile_m}x{tile_n}x{tile_k}_"
f"nb{num_buffers}_sk{split_k}_cm{cluster_m}_cn{cluster_n}"
f"mw{m_warp}_nw{n_warp}_nb{num_buffers}_sk{split_k}_"
f"cm{cluster_m}_cn{cluster_n}"
)


Expand Down Expand Up @@ -238,4 +241,6 @@ def run_gemm_a8w8_bpreshuffle_gfx1250(
split_k=cfg["split_k"],
cluster_m=cfg["cluster_m"],
cluster_n=cfg["cluster_n"],
m_warp=cfg["m_warp"],
n_warp=cfg["n_warp"],
)
Loading
Loading