Skip to content

perf: add preshuffle logic for HIP fused_moe#1057

Open
ftyghome wants to merge 1 commit into
ROCm:mainfrom
RadeonFlow:rf-mxfp4-moe
Open

perf: add preshuffle logic for HIP fused_moe#1057
ftyghome wants to merge 1 commit into
ROCm:mainfrom
RadeonFlow:rf-mxfp4-moe

Conversation

@ftyghome
Copy link
Copy Markdown

@ftyghome ftyghome commented Jun 3, 2026

Motivation

ROCm/aiter#3470 adds mxfp4_moe, a native HIP MXFP4 (a4w4) MoE backend for gfx950 that beats the FlyDSL path. To use it, weights must be shuffled into the layout its kernels expect; fused_moe dispatches by shuffle_kind, so old-layout weights never reach the HIP backend.

Technical Details

Add a preshuffle pass for MXFP4 MoE weights (packed fp4 + e8m0 group scales) matching aiter's mxfp4_moe layout, applied at model load and gated on gfx950 + an aiter build with #3470. Tagging weights with the matching shuffle_kind is enough for fused_moe to auto-dispatch to the HIP backend.

Test Plan

See ROCm/aiter#3470.

Test Result

See ROCm/aiter#3470.

Copilot AI review requested due to automatic review settings June 3, 2026 13:15
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Updates MoE post-load processing to shuffle weights/scales using an explicit layout and records shuffle metadata for downstream logic.

Changes:

  • Shuffles w13/w2 weights via shuffle_weight(..., layout=(16,16), gate_up=...).
  • Shuffles w13/w2 scales via shuffle_scale(...) with inferred expert count from the scale tensor shape.
  • Adds is_shuffled and shuffle_kind = "mxfp4_moe" metadata on weights and scales.
Comments suppressed due to low confidence (1)

atom/model_ops/moe.py:1

  • This block is a full copy of the previous implementation left commented out, which adds noise and risks divergence from the active logic. Remove it (or move it to commit history / a doc comment explaining the behavioral change if it’s needed for future reference).
# SPDX-License-Identifier: MIT

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread atom/model_ops/moe.py
Comment on lines +923 to 926
E13, N13, K13 = layer.w13_weight_scale.data.shape
layer.w13_weight.data = shuffle_weight(
layer.w13_weight,
is_guinterleave=self.is_guinterleave,
gate_up=True,
layer.w13_weight.data, layout=(16, 16), is_guinterleave=True, gate_up=True
)
Comment thread atom/model_ops/moe.py
Comment on lines +934 to 937
E2, N2, K2 = layer.w2_weight_scale.data.shape
layer.w2_weight.data = shuffle_weight(
layer.w2_weight.data, layout=(16, 16), is_guinterleave=True, gate_up=False
)
Comment thread atom/model_ops/moe.py
Comment on lines +923 to 943
E13, N13, K13 = layer.w13_weight_scale.data.shape
layer.w13_weight.data = shuffle_weight(
layer.w13_weight,
is_guinterleave=self.is_guinterleave,
gate_up=True,
layer.w13_weight.data, layout=(16, 16), is_guinterleave=True, gate_up=True
)
layer.w2_weight.data = shuffle_weight(
layer.w2_weight,
is_guinterleave=self.is_guinterleave,
gate_up=False,
)
layer.w13_weight.is_shuffled = True
layer.w2_weight.is_shuffled = True

# shuffle scale
w13_scale_2d = layer.w13_weight_scale.reshape(
-1, layer.w13_weight_scale.shape[-1]
layer.w13_weight_scale = atom_parameter(
shuffle_scale(
layer.w13_weight_scale.data.reshape(E13 * N13, K13),
experts_cnt=E13, is_guinterleave=True, gate_up=True,
).reshape(E13, N13, K13)
)
w2_scale_2d = layer.w2_weight_scale.reshape(-1, layer.w2_weight_scale.shape[-1])

shuffled_w13_scale = shuffle_scale(
w13_scale_2d, self.num_experts, self.is_guinterleave, True
E2, N2, K2 = layer.w2_weight_scale.data.shape
layer.w2_weight.data = shuffle_weight(
layer.w2_weight.data, layout=(16, 16), is_guinterleave=True, gate_up=False
)
shuffled_w2_scale = shuffle_scale(
w2_scale_2d, self.num_experts, self.is_guinterleave, False
layer.w2_weight_scale = atom_parameter(
shuffle_scale(
layer.w2_weight_scale.data.reshape(E2 * N2, K2),
experts_cnt=E2, is_guinterleave=True, gate_up=False,
).reshape(E2, N2, K2)
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants