Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
326 changes: 195 additions & 131 deletions op_tests/test_moe_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,17 @@ def weight_per_128x128_quant(weight, quant_dtype):
num_iters=5,
num_warmup=2,
)
# Compare only the real (un-padded) model_dim region. The padded tail
# (last hidden_pad cols) carries no meaningful output and some kernels
# leave it uninitialized/NaN, which would poison checkAllclose/logits_diff.
real_model_dim = model_dim - hidden_pad
out2_ref = out2_ref[:, :real_model_dim]
out2_ck = out2_ck[:, :real_model_dim]
Comment on lines +372 to +374

err = checkAllclose(
out2_ref,
out2_ck,
msg=f"ck_moe_2stages:{us2:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})",
msg=f"ck_moe_2stages:{us2:>8.2f} us, {token * model_dim * inter_dim * 3 * topk * 2 / us2 / 1000 / 1000:>8.2f} tflops......(quant:{AQDType})",
)

def calc_diff(x: torch.Tensor, y: torch.Tensor):
Expand Down Expand Up @@ -431,8 +438,9 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
"-dim",
type=dtypes.str2tuple,
nargs="*",
default=[(7168, 256)],
help="""Model dimension.
default=None,
help="""Model dimension. If unset, each quant uses its per-quant default
from QUANT_DEFAULTS; if set, the value applies to ALL selected quants.
e.g.: -dim 6144,4096""",
)

Expand Down Expand Up @@ -502,17 +510,17 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
"-e",
"--expert",
type=int,
default=257,
help="""Number of experts.
default=None,
help="""Number of experts. If unset, each quant uses its per-quant default.
e.g.: -e 8""",
)

parser.add_argument(
"-k",
"--topk",
type=int,
default=9,
help="""Number of top experts.
default=None,
help="""Number of top experts. If unset, each quant uses its per-quant default.
e.g.: -k 2""",
)

Expand All @@ -521,8 +529,10 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
"--preshuffle",
type=dtypes.str2bool,
nargs="*",
default=[True],
help="""Whether to use pre-shuffle weight mode. Default is [False, True].
default=None,
help="""Whether to use pre-shuffle weight mode. If unset, each quant uses
its per-quant default (only a4w4 varies preshuffle; others require shuffled
weights for correctness).
Comment on lines +533 to +535
-p f # False.
-p t # True.""",
)
Expand All @@ -531,8 +541,8 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
"--hidden_intermediate_pad",
type=dtypes.str2tuple,
nargs="*",
default=[(192, 128)],
help="""Hidden intermediate pad.
default=None,
help="""Hidden intermediate pad. If unset, each quant uses its per-quant default.
e.g.: -hip 0,0""",
)
parser.add_argument(
Expand Down Expand Up @@ -665,10 +675,13 @@ def _iter_csv_cases():
continue
kwargs["strict_accuracy"] = True
kwargs["check_aot_cache"] = True
yield kwargs, {
"kernelName1": kernel_name1,
"kernelName2": kernel_name2,
}
yield (
kwargs,
{
"kernelName1": kernel_name1,
"kernelName2": kernel_name2,
},
)


_PER1X32_BF16_FP4 = (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2)
Expand All @@ -677,6 +690,118 @@ def _iter_csv_cases():
_PER1X32_BF16_I4 = (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.i4x2)


# Per-quant defaults for the legacy CLI-driven sweep. Each entry pins a
# representative shape taken from a tuned config so the smoke test exercises
# the kernel path that actually runs in production. CLI flags
# (-dim/-e/-k/-hip/-p) override these defaults globally when supplied.
#
# Fields:
# dims: list of (model_dim, inter_dim); CLI -dim overrides
# E: expert count; CLI -e overrides
# topk: topk; CLI -k overrides
# pads: list of (hidden_pad, intermediate_pad); CLI -hip overrides
# preshuffles: list of bool; CLI -p overrides
# acts: None = use CLI -a (defaults to [Silu]); explicit list = kernel
# forces this activation and -a cannot override
# strict_accuracy: if True, accuracy mismatch asserts; otherwise warn-only.
# Enabled for the two fp4-weight a4w4 / a8w4-mxfp paths whose
# accuracy we actively gate; other quants only warn.
QUANT_DEFAULTS = {
# a16w16, dsv3 expert MLP
(aiter.QuantType.No, None, None): {
"dims": [(2048, 192)],
"E": 128,
"topk": 8,
"pads": [(0, 0)],
"preshuffles": [True],
"acts": None,
"strict_accuracy": False,
},
# a8w8 perTensor, dsv3
(aiter.QuantType.per_Tensor, dtypes.fp8, dtypes.fp8): {
"dims": [(7168, 256)],
"E": 256,
"topk": 8,
"pads": [(0, 0)],
"preshuffles": [True],
"acts": None,
"strict_accuracy": False,
},
# a8w8 perToken, dsv3
(aiter.QuantType.per_Token, dtypes.fp8, dtypes.fp8): {
"dims": [(5120, 1024)],
"E": 128,
"topk": 1,
"pads": [(0, 0)],
"preshuffles": [True],
"acts": None,
"strict_accuracy": False,
},
# a8w4 (int4 weight quant) — no tuned CSV row; keep legacy main default
(aiter.QuantType.per_Token, dtypes.fp8, torch.int4): {
"dims": [(3072, 512)],
"E": 128,
"topk": 4,
"pads": [(0, 0)],
"preshuffles": [True],
"acts": None,
"strict_accuracy": False,
},
# a4w4, gptoss
_PER1X32_FP4_FP4: {
"dims": [(3072, 512)],
"E": 128,
"topk": 4,
"pads": [(192, 152)],
"preshuffles": [True],
"acts": None,
"strict_accuracy": True,
},
# a8w8 blockscale (per_128x128 == CSV per_1x128), minimax
(aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8): {
"dims": [(3072, 1536)],
"E": 256,
"topk": 8,
"pads": [(0, 0)],
"preshuffles": [True],
"acts": None,
"strict_accuracy": False,
},
# a16w4, gpt-oss fp8fp4 shape (a16w4 falls into the same dispatch
# for small-batch Swiglu); kernel forces Swiglu.
_PER1X32_BF16_FP4: {
"dims": [(3072, 3072)],
"E": 128,
"topk": 4,
"pads": [(192, 192)],
"preshuffles": [True],
"acts": [aiter.ActivationType.Swiglu],
"strict_accuracy": False,
},
# a8w4 mxfp (mxfp8 act + mxfp4 weight), dsv4
_PER1X32_FP8_FP4: {
"dims": [(7168, 512)],
"E": 385,
"topk": 7,
"pads": [(128, 128)],
"preshuffles": [True],
"acts": None,
"strict_accuracy": True,
},
# a16wi4 (packed int4 weight), kimik2_i4; kernel only supports Silu and
# the FlyDSL a16wi4 path does not wire pad parameters.
_PER1X32_BF16_I4: {
"dims": [(7168, 256)],
"E": 384,
"topk": 8,
"pads": [(0, 0)],
"preshuffles": [True],
"acts": [aiter.ActivationType.Silu],
"strict_accuracy": False,
},
}


def _effective_gate_mode(aq_dtype, wq_dtype):
if aq_dtype in [dtypes.fp8, dtypes.bf16] and wq_dtype == dtypes.fp4x2:
return GateMode.INTERLEAVE.value
Expand Down Expand Up @@ -710,126 +835,65 @@ def _runtime_swiglu_mxfp4_q_dtype_a(


def _iter_legacy_cases():
"""Yield (kwargs, extras) for the original CLI-driven sweep."""
"""Yield (kwargs, extras) for the per-quant smoke sweep.

Each selected quant uses its `QUANT_DEFAULTS` entry for shape / E / topk /
pad / preshuffle / act unless overridden by CLI flags. Quants whose
`acts` entry is an explicit list have a kernel-imposed activation
restriction (a16w4 → Swiglu, a16wi4 → Silu) and ignore `-a`.
"""
extras = {"model": "legacy"}

def _kw(
dtype,
m,
model_dim,
inter_dim,
quant_type,
aq_dtype,
wq_dtype,
doweight_stage1,
act_type,
**over,
for dtype, triple, doweight_stage1 in itertools.product(
args.dtype, l_quant, args.doweight_stage1
):
return dict(
dtype=dtype,
token=m,
model_dim=model_dim,
inter_dim=inter_dim,
E=args.expert,
topk=args.topk,
actType=act_type,
gateMode=_effective_gate_mode(aq_dtype, wq_dtype),
qType=quant_type,
AQDType=aq_dtype,
WQDType=wq_dtype,
use_g1u1=True,
doweight_stage1=doweight_stage1,
strict_accuracy=False,
check_aot_cache=False,
**over,
cfg = QUANT_DEFAULTS[triple]
quant_type, aq_dtype, wq_dtype = triple

dims = args.dim if args.dim is not None else cfg["dims"]
pads = (
args.hidden_intermediate_pad
if args.hidden_intermediate_pad is not None
else cfg["pads"]
)

for (
dtype,
(quant_type, aq_dtype, wq_dtype),
(model_dim, inter_dim),
doweight_stage1,
) in itertools.product(args.dtype, l_quant, args.dim, args.doweight_stage1):
triple = (quant_type, aq_dtype, wq_dtype)

if triple == _PER1X32_BF16_FP4:
for hidden_pad, intermediate_pad in args.hidden_intermediate_pad:
for m in args.tokenNum:
yield _kw(
dtype,
m,
model_dim,
inter_dim,
quant_type,
aq_dtype,
wq_dtype,
doweight_stage1,
aiter.ActivationType.Swiglu,
hidden_pad=hidden_pad,
intermediate_pad=intermediate_pad,
), extras
elif triple == _PER1X32_FP8_FP4:
for hidden_pad, intermediate_pad in args.hidden_intermediate_pad:
for act_type in args.act:
for m in args.tokenNum:
yield _kw(
dtype,
m,
model_dim,
inter_dim,
quant_type,
aq_dtype,
wq_dtype,
doweight_stage1,
act_type,
hidden_pad=hidden_pad,
intermediate_pad=intermediate_pad,
), extras
elif triple == _PER1X32_FP4_FP4:
for preshuffle in args.preshuffle:
for act_type in args.act:
for m in args.tokenNum:
yield _kw(
dtype,
m,
model_dim,
inter_dim,
quant_type,
aq_dtype,
wq_dtype,
doweight_stage1,
act_type,
preshuffle=preshuffle,
hidden_pad=0,
intermediate_pad=0,
), extras
elif triple == _PER1X32_BF16_I4:
for m in args.tokenNum:
yield _kw(
dtype,
m,
model_dim,
inter_dim,
quant_type,
aq_dtype,
wq_dtype,
doweight_stage1,
aiter.ActivationType.Silu,
), extras
else:
for act_type in args.act:
for m in args.tokenNum:
yield _kw(
dtype,
m,
model_dim,
inter_dim,
quant_type,
aq_dtype,
wq_dtype,
doweight_stage1,
act_type,
), extras
preshuffles = (
args.preshuffle if args.preshuffle is not None else cfg["preshuffles"]
)
E = args.expert if args.expert is not None else cfg["E"]
topk = args.topk if args.topk is not None else cfg["topk"]
# Kernel-forced acts always win over CLI -a.
acts = cfg["acts"] if cfg["acts"] is not None else args.act

for (
(model_dim, inter_dim),
(hidden_pad, intermediate_pad),
preshuffle,
act_type,
m,
) in itertools.product(dims, pads, preshuffles, acts, args.tokenNum):
yield (
dict(
dtype=dtype,
token=m,
model_dim=model_dim,
inter_dim=inter_dim,
E=E,
topk=topk,
actType=act_type,
gateMode=_effective_gate_mode(aq_dtype, wq_dtype),
qType=quant_type,
AQDType=aq_dtype,
WQDType=wq_dtype,
use_g1u1=True,
doweight_stage1=doweight_stage1,
hidden_pad=hidden_pad,
intermediate_pad=intermediate_pad,
preshuffle=preshuffle,
strict_accuracy=cfg["strict_accuracy"],
check_aot_cache=False,
),
extras,
)


# ---------------------------------------------------------------------------
Expand Down
Loading