diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index a8e8facebb..25651f0b4e 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -366,10 +366,17 @@ def weight_per_128x128_quant(weight, quant_dtype): num_iters=5, num_warmup=2, ) + # Compare only the real (un-padded) model_dim region. The padded tail + # (last hidden_pad cols) carries no meaningful output and some kernels + # leave it uninitialized/NaN, which would poison checkAllclose/logits_diff. + real_model_dim = model_dim - hidden_pad + out2_ref = out2_ref[:, :real_model_dim] + out2_ck = out2_ck[:, :real_model_dim] + err = checkAllclose( out2_ref, out2_ck, - msg=f"ck_moe_2stages:{us2:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", + msg=f"ck_moe_2stages:{us2:>8.2f} us, {token * model_dim * inter_dim * 3 * topk * 2 / us2 / 1000 / 1000:>8.2f} tflops......(quant:{AQDType})", ) def calc_diff(x: torch.Tensor, y: torch.Tensor): @@ -431,8 +438,9 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): "-dim", type=dtypes.str2tuple, nargs="*", - default=[(7168, 256)], - help="""Model dimension. + default=None, + help="""Model dimension. If unset, each quant uses its per-quant default + from QUANT_DEFAULTS; if set, the value applies to ALL selected quants. e.g.: -dim 6144,4096""", ) @@ -502,8 +510,8 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): "-e", "--expert", type=int, - default=257, - help="""Number of experts. + default=None, + help="""Number of experts. If unset, each quant uses its per-quant default. e.g.: -e 8""", ) @@ -511,8 +519,8 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): "-k", "--topk", type=int, - default=9, - help="""Number of top experts. + default=None, + help="""Number of top experts. If unset, each quant uses its per-quant default. e.g.: -k 2""", ) @@ -521,8 +529,10 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): "--preshuffle", type=dtypes.str2bool, nargs="*", - default=[True], - help="""Whether to use pre-shuffle weight mode. Default is [False, True]. + default=None, + help="""Whether to use pre-shuffle weight mode. If unset, each quant uses + its per-quant default (only a4w4 varies preshuffle; others require shuffled + weights for correctness). -p f # False. -p t # True.""", ) @@ -531,8 +541,8 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): "--hidden_intermediate_pad", type=dtypes.str2tuple, nargs="*", - default=[(192, 128)], - help="""Hidden intermediate pad. + default=None, + help="""Hidden intermediate pad. If unset, each quant uses its per-quant default. e.g.: -hip 0,0""", ) parser.add_argument( @@ -665,10 +675,13 @@ def _iter_csv_cases(): continue kwargs["strict_accuracy"] = True kwargs["check_aot_cache"] = True - yield kwargs, { - "kernelName1": kernel_name1, - "kernelName2": kernel_name2, - } + yield ( + kwargs, + { + "kernelName1": kernel_name1, + "kernelName2": kernel_name2, + }, + ) _PER1X32_BF16_FP4 = (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2) @@ -677,6 +690,118 @@ def _iter_csv_cases(): _PER1X32_BF16_I4 = (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.i4x2) +# Per-quant defaults for the legacy CLI-driven sweep. Each entry pins a +# representative shape taken from a tuned config so the smoke test exercises +# the kernel path that actually runs in production. CLI flags +# (-dim/-e/-k/-hip/-p) override these defaults globally when supplied. +# +# Fields: +# dims: list of (model_dim, inter_dim); CLI -dim overrides +# E: expert count; CLI -e overrides +# topk: topk; CLI -k overrides +# pads: list of (hidden_pad, intermediate_pad); CLI -hip overrides +# preshuffles: list of bool; CLI -p overrides +# acts: None = use CLI -a (defaults to [Silu]); explicit list = kernel +# forces this activation and -a cannot override +# strict_accuracy: if True, accuracy mismatch asserts; otherwise warn-only. +# Enabled for the two fp4-weight a4w4 / a8w4-mxfp paths whose +# accuracy we actively gate; other quants only warn. +QUANT_DEFAULTS = { + # a16w16, dsv3 expert MLP + (aiter.QuantType.No, None, None): { + "dims": [(2048, 192)], + "E": 128, + "topk": 8, + "pads": [(0, 0)], + "preshuffles": [True], + "acts": None, + "strict_accuracy": False, + }, + # a8w8 perTensor, dsv3 + (aiter.QuantType.per_Tensor, dtypes.fp8, dtypes.fp8): { + "dims": [(7168, 256)], + "E": 256, + "topk": 8, + "pads": [(0, 0)], + "preshuffles": [True], + "acts": None, + "strict_accuracy": False, + }, + # a8w8 perToken, dsv3 + (aiter.QuantType.per_Token, dtypes.fp8, dtypes.fp8): { + "dims": [(5120, 1024)], + "E": 128, + "topk": 1, + "pads": [(0, 0)], + "preshuffles": [True], + "acts": None, + "strict_accuracy": False, + }, + # a8w4 (int4 weight quant) — no tuned CSV row; keep legacy main default + (aiter.QuantType.per_Token, dtypes.fp8, torch.int4): { + "dims": [(3072, 512)], + "E": 128, + "topk": 4, + "pads": [(0, 0)], + "preshuffles": [True], + "acts": None, + "strict_accuracy": False, + }, + # a4w4, gptoss + _PER1X32_FP4_FP4: { + "dims": [(3072, 512)], + "E": 128, + "topk": 4, + "pads": [(192, 152)], + "preshuffles": [True], + "acts": None, + "strict_accuracy": True, + }, + # a8w8 blockscale (per_128x128 == CSV per_1x128), minimax + (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8): { + "dims": [(3072, 1536)], + "E": 256, + "topk": 8, + "pads": [(0, 0)], + "preshuffles": [True], + "acts": None, + "strict_accuracy": False, + }, + # a16w4, gpt-oss fp8fp4 shape (a16w4 falls into the same dispatch + # for small-batch Swiglu); kernel forces Swiglu. + _PER1X32_BF16_FP4: { + "dims": [(3072, 3072)], + "E": 128, + "topk": 4, + "pads": [(192, 192)], + "preshuffles": [True], + "acts": [aiter.ActivationType.Swiglu], + "strict_accuracy": False, + }, + # a8w4 mxfp (mxfp8 act + mxfp4 weight), dsv4 + _PER1X32_FP8_FP4: { + "dims": [(7168, 512)], + "E": 385, + "topk": 7, + "pads": [(128, 128)], + "preshuffles": [True], + "acts": None, + "strict_accuracy": True, + }, + # a16wi4 (packed int4 weight), kimik2_i4; kernel only supports Silu and + # the FlyDSL a16wi4 path does not wire pad parameters. + _PER1X32_BF16_I4: { + "dims": [(7168, 256)], + "E": 384, + "topk": 8, + "pads": [(0, 0)], + "preshuffles": [True], + "acts": [aiter.ActivationType.Silu], + "strict_accuracy": False, + }, +} + + def _effective_gate_mode(aq_dtype, wq_dtype): if aq_dtype in [dtypes.fp8, dtypes.bf16] and wq_dtype == dtypes.fp4x2: return GateMode.INTERLEAVE.value @@ -710,126 +835,65 @@ def _runtime_swiglu_mxfp4_q_dtype_a( def _iter_legacy_cases(): - """Yield (kwargs, extras) for the original CLI-driven sweep.""" + """Yield (kwargs, extras) for the per-quant smoke sweep. + + Each selected quant uses its `QUANT_DEFAULTS` entry for shape / E / topk / + pad / preshuffle / act unless overridden by CLI flags. Quants whose + `acts` entry is an explicit list have a kernel-imposed activation + restriction (a16w4 → Swiglu, a16wi4 → Silu) and ignore `-a`. + """ extras = {"model": "legacy"} - def _kw( - dtype, - m, - model_dim, - inter_dim, - quant_type, - aq_dtype, - wq_dtype, - doweight_stage1, - act_type, - **over, + for dtype, triple, doweight_stage1 in itertools.product( + args.dtype, l_quant, args.doweight_stage1 ): - return dict( - dtype=dtype, - token=m, - model_dim=model_dim, - inter_dim=inter_dim, - E=args.expert, - topk=args.topk, - actType=act_type, - gateMode=_effective_gate_mode(aq_dtype, wq_dtype), - qType=quant_type, - AQDType=aq_dtype, - WQDType=wq_dtype, - use_g1u1=True, - doweight_stage1=doweight_stage1, - strict_accuracy=False, - check_aot_cache=False, - **over, + cfg = QUANT_DEFAULTS[triple] + quant_type, aq_dtype, wq_dtype = triple + + dims = args.dim if args.dim is not None else cfg["dims"] + pads = ( + args.hidden_intermediate_pad + if args.hidden_intermediate_pad is not None + else cfg["pads"] ) - - for ( - dtype, - (quant_type, aq_dtype, wq_dtype), - (model_dim, inter_dim), - doweight_stage1, - ) in itertools.product(args.dtype, l_quant, args.dim, args.doweight_stage1): - triple = (quant_type, aq_dtype, wq_dtype) - - if triple == _PER1X32_BF16_FP4: - for hidden_pad, intermediate_pad in args.hidden_intermediate_pad: - for m in args.tokenNum: - yield _kw( - dtype, - m, - model_dim, - inter_dim, - quant_type, - aq_dtype, - wq_dtype, - doweight_stage1, - aiter.ActivationType.Swiglu, - hidden_pad=hidden_pad, - intermediate_pad=intermediate_pad, - ), extras - elif triple == _PER1X32_FP8_FP4: - for hidden_pad, intermediate_pad in args.hidden_intermediate_pad: - for act_type in args.act: - for m in args.tokenNum: - yield _kw( - dtype, - m, - model_dim, - inter_dim, - quant_type, - aq_dtype, - wq_dtype, - doweight_stage1, - act_type, - hidden_pad=hidden_pad, - intermediate_pad=intermediate_pad, - ), extras - elif triple == _PER1X32_FP4_FP4: - for preshuffle in args.preshuffle: - for act_type in args.act: - for m in args.tokenNum: - yield _kw( - dtype, - m, - model_dim, - inter_dim, - quant_type, - aq_dtype, - wq_dtype, - doweight_stage1, - act_type, - preshuffle=preshuffle, - hidden_pad=0, - intermediate_pad=0, - ), extras - elif triple == _PER1X32_BF16_I4: - for m in args.tokenNum: - yield _kw( - dtype, - m, - model_dim, - inter_dim, - quant_type, - aq_dtype, - wq_dtype, - doweight_stage1, - aiter.ActivationType.Silu, - ), extras - else: - for act_type in args.act: - for m in args.tokenNum: - yield _kw( - dtype, - m, - model_dim, - inter_dim, - quant_type, - aq_dtype, - wq_dtype, - doweight_stage1, - act_type, - ), extras + preshuffles = ( + args.preshuffle if args.preshuffle is not None else cfg["preshuffles"] + ) + E = args.expert if args.expert is not None else cfg["E"] + topk = args.topk if args.topk is not None else cfg["topk"] + # Kernel-forced acts always win over CLI -a. + acts = cfg["acts"] if cfg["acts"] is not None else args.act + + for ( + (model_dim, inter_dim), + (hidden_pad, intermediate_pad), + preshuffle, + act_type, + m, + ) in itertools.product(dims, pads, preshuffles, acts, args.tokenNum): + yield ( + dict( + dtype=dtype, + token=m, + model_dim=model_dim, + inter_dim=inter_dim, + E=E, + topk=topk, + actType=act_type, + gateMode=_effective_gate_mode(aq_dtype, wq_dtype), + qType=quant_type, + AQDType=aq_dtype, + WQDType=wq_dtype, + use_g1u1=True, + doweight_stage1=doweight_stage1, + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, + preshuffle=preshuffle, + strict_accuracy=cfg["strict_accuracy"], + check_aot_cache=False, + ), + extras, + ) # ---------------------------------------------------------------------------