diff --git a/aiter/tuned_gemm.py b/aiter/tuned_gemm.py index ffb81dff63..9c04e17508 100644 --- a/aiter/tuned_gemm.py +++ b/aiter/tuned_gemm.py @@ -64,6 +64,7 @@ def get_GEMM_A16W16_config_(): gemm_dict = pd.read_csv(f"{tuned_file}").drop_duplicates() gemm_dict = gemm_dict.set_index( [ + "gfx", "cu_num", "M", "N", @@ -117,11 +118,12 @@ def get_GEMM_A16W16_config( cu_num = get_cu_num() padded_M = M config = None - + gfx = get_gfx() for gl in [None, 0, 1]: padded_M = M if gl is None else get_padded_m(M, N, K, gl) config = cfg.get( ( + gfx, cu_num, padded_M, N, @@ -161,7 +163,6 @@ def get_GEMM_A16W16_config( if config is None: default_config = {} - gfx = get_gfx() # gfx12: no ASM/skinny/hipblaslt kernels, use torch if gfx.startswith("gfx12"): default_config["libtype"] = "torch"