perf: add preshuffle logic for HIP fused_moe#1057
Open
ftyghome wants to merge 1 commit into
Open
Conversation
Contributor
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Updates MoE post-load processing to shuffle weights/scales using an explicit layout and records shuffle metadata for downstream logic.
Changes:
- Shuffles
w13/w2weights viashuffle_weight(..., layout=(16,16), gate_up=...). - Shuffles
w13/w2scales viashuffle_scale(...)with inferred expert count from the scale tensor shape. - Adds
is_shuffledandshuffle_kind = "mxfp4_moe"metadata on weights and scales.
Comments suppressed due to low confidence (1)
atom/model_ops/moe.py:1
- This block is a full copy of the previous implementation left commented out, which adds noise and risks divergence from the active logic. Remove it (or move it to commit history / a doc comment explaining the behavioral change if it’s needed for future reference).
# SPDX-License-Identifier: MIT
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+923
to
926
| E13, N13, K13 = layer.w13_weight_scale.data.shape | ||
| layer.w13_weight.data = shuffle_weight( | ||
| layer.w13_weight, | ||
| is_guinterleave=self.is_guinterleave, | ||
| gate_up=True, | ||
| layer.w13_weight.data, layout=(16, 16), is_guinterleave=True, gate_up=True | ||
| ) |
Comment on lines
+934
to
937
| E2, N2, K2 = layer.w2_weight_scale.data.shape | ||
| layer.w2_weight.data = shuffle_weight( | ||
| layer.w2_weight.data, layout=(16, 16), is_guinterleave=True, gate_up=False | ||
| ) |
Comment on lines
+923
to
943
| E13, N13, K13 = layer.w13_weight_scale.data.shape | ||
| layer.w13_weight.data = shuffle_weight( | ||
| layer.w13_weight, | ||
| is_guinterleave=self.is_guinterleave, | ||
| gate_up=True, | ||
| layer.w13_weight.data, layout=(16, 16), is_guinterleave=True, gate_up=True | ||
| ) | ||
| layer.w2_weight.data = shuffle_weight( | ||
| layer.w2_weight, | ||
| is_guinterleave=self.is_guinterleave, | ||
| gate_up=False, | ||
| ) | ||
| layer.w13_weight.is_shuffled = True | ||
| layer.w2_weight.is_shuffled = True | ||
|
|
||
| # shuffle scale | ||
| w13_scale_2d = layer.w13_weight_scale.reshape( | ||
| -1, layer.w13_weight_scale.shape[-1] | ||
| layer.w13_weight_scale = atom_parameter( | ||
| shuffle_scale( | ||
| layer.w13_weight_scale.data.reshape(E13 * N13, K13), | ||
| experts_cnt=E13, is_guinterleave=True, gate_up=True, | ||
| ).reshape(E13, N13, K13) | ||
| ) | ||
| w2_scale_2d = layer.w2_weight_scale.reshape(-1, layer.w2_weight_scale.shape[-1]) | ||
|
|
||
| shuffled_w13_scale = shuffle_scale( | ||
| w13_scale_2d, self.num_experts, self.is_guinterleave, True | ||
| E2, N2, K2 = layer.w2_weight_scale.data.shape | ||
| layer.w2_weight.data = shuffle_weight( | ||
| layer.w2_weight.data, layout=(16, 16), is_guinterleave=True, gate_up=False | ||
| ) | ||
| shuffled_w2_scale = shuffle_scale( | ||
| w2_scale_2d, self.num_experts, self.is_guinterleave, False | ||
| layer.w2_weight_scale = atom_parameter( | ||
| shuffle_scale( | ||
| layer.w2_weight_scale.data.reshape(E2 * N2, K2), | ||
| experts_cnt=E2, is_guinterleave=True, gate_up=False, | ||
| ).reshape(E2, N2, K2) | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
ROCm/aiter#3470 adds
mxfp4_moe, a native HIP MXFP4 (a4w4) MoE backend for gfx950 that beats the FlyDSL path. To use it, weights must be shuffled into the layout its kernels expect;fused_moedispatches byshuffle_kind, so old-layout weights never reach the HIP backend.Technical Details
Add a preshuffle pass for MXFP4 MoE weights (packed fp4 + e8m0 group scales) matching aiter's
mxfp4_moelayout, applied at model load and gated on gfx950 + an aiter build with #3470. Tagging weights with the matchingshuffle_kindis enough forfused_moeto auto-dispatch to the HIP backend.Test Plan
See ROCm/aiter#3470.
Test Result
See ROCm/aiter#3470.