Skip to content

EP+pad support for Step-3.5-Flash-FP8#1091

Draft
LJ-underdog wants to merge 10 commits into
mainfrom
feat-ep-pad-clean
Draft

EP+pad support for Step-3.5-Flash-FP8#1091
LJ-underdog wants to merge 10 commits into
mainfrom
feat-ep-pad-clean

Conversation

@LJ-underdog
Copy link
Copy Markdown

@LJ-underdog LJ-underdog commented Jun 5, 2026

Motivation

Enable end-to-end serving of stepfun Step-3.5-Flash-FP8 (FP8 block-quantized MoE) on ATOM at TP8 on AMD gfx942 (MI308X), across the two production parallelism paths: Expert-Parallel (EP) and pure Tensor-Parallel with padding.

Technical Details

This branch adds the support required to load and run Step-3.5-Flash-FP8:

  • Step-3.5-Flash model support: new step3p5 model definition and MoE weight shuffling for the routed-expert layout.
  • FP8 block-quantized MoE inference: per-1x128 block scaling (block_shape = [128, 128]), with correct inter_dim padding-alignment across all TP configurations, plumbed through Fp8MoEMethod.get_fused_moe_quant_config.
  • TP8 FP8 weight loading: handle D < tp_size in the fp8 _load_w13 / _load_w2 shard loaders so the routed-expert weights load correctly at TP8.
  • Sliding-window attention (SWA) per-layer KV-head workspace: Step-3.5-Flash enables sliding_window. ModelConfig.get_num_kv_heads() returns the model-level value (32), but the per-layer attention modules use 4 KV heads; the SWA workspace is now sized from the true per-layer num_kv_heads (max across layers), so the workspace head dimension is correct. Touches only atom/plugin/attention.py.
  • Two production parallelism paths:
    • EP (--enable-expert-parallel): experts sharded, inter_dim = 1280.
    • Pure-TP + pad: inter_dim sharded 1280/8 = 160, padded to 256.

Test Plan

End-to-end correctness via the ATOM-native simple_inference example on 8×gfx942 (TP8, FP8 block-quantized MoE, FP8 KV cache), over the example's 4 prompts, for both production paths:

  • EP: --enable-expert-parallel.
  • Pure-TP + pad: default (no expert-parallel).

Test Result

Both paths 4/4 coherent (exit 0), non-garbled, with natural EOS and no faults:

  • EP (inter=1280): 4/4 prompts coherent; arithmetic correct (1 + 2 + 3 = 6).
  • Pure-TP + pad (inter=256): 4/4 prompts coherent; Engine Core fully initialized (44/44 shards); arithmetic correct (1 + 2 + 3 = 6).
  • No GPU faults; VRAM returned cleanly to baseline after each run (no leak).

Submission Checklist

  • I have read and followed the contributing guidelines.
  • This PR targets the default integration branch (main).
  • The code builds successfully.
  • I have included the log of a successful test run (EP + pure-TP-pad e2e, 4/4 coherent — logs available on request).
  • This PR does not break existing test cases.
  • Dependent changes are identified below (aiter production PR + CK pin).

Related

LJ-underdog and others added 10 commits April 23, 2026 21:20
Add Step3p5ForCausalLM model support for the Step-3.5-Flash architecture,
and fix a critical MoE correctness bug on gfx950 (MI350X).

Core MoE fix (atom/model_ops/moe.py):
  Previously skipped shuffle_weights() for gfx950 BF16 g1u1 based on the
  incorrect assumption that the CK 2-stage preshuffle_off (NSwizzle=0)
  kernel expects un-shuffled weights. Verified: preshuffle_off GEMM is
  wrong on gfx950; preshuffle_on (NSwizzle=1) is correct. Always call
  shuffle_weights() so the correct kernel path is selected.

Step-3.5-Flash model support (atom/models/step3p5.py):
  - Mixed full/sliding window attention (per layer_types config)
  - 288 routed + 1 shared expert MoE with sigmoid routing
  - Per-layer SwigluStep activation: layers with swiglu_limits[i]>0 use
    ActivationType.SwigluStep (CK kernel applies silu(g).clamp(7)*up.clamp(±7));
    other layers use plain Silu. Shared expert at SwigluStep layers is kept
    on the dense MLP path (kernel clamp is routed-expert-only).
  - Fused expert loading (flat [E,I,H] checkpoint format)
  - clamp_limit applied to dense MLP and shared expert via Step3p5MLP

atom/model_engine/model_runner.py:
  - Register Step3p5ForCausalLM architecture
  - Handle num_attention_groups config key (Step-3.5 uses this instead of
    num_key_value_heads) in KV head count calculations

atom/model_loader/loader.py:
  - Fix fused expert detection order: check before packed_modules_mapping
    to prevent moe.gate_proj being matched as gate_up_proj

atom/model_ops/attentions/aiter_attention.py:
  - Handle num_attention_groups config key for KV head count

atom/examples/simple_inference.py:
  - Add --max-tokens arg and trust_remote_code support

Verified: tp=2 Step-3.5-Flash inference, 4 prompts, no NaN/crash,
coherent output (with ATOM_STEP3P5_NO_SLIDING=1 workaround for
pa_decode_gluon bug on gfx950, tracked separately).

Co-Authored-By: Jun Lin <junlin12@amd.com>
CK 2-stage MoE kernel (gemm_moe_ck2stages.cu L98) computes stage1 N
as w1.size(1)/2 = inter_dim. The stage1 dispatch selects NPerBlock
based on inter_dim range:
  - inter <= 192: NPerBlock = 64  -> need inter % 64 == 0
  - inter >  192: NPerBlock = 128 -> need inter % 128 == 0

Step-3.5-Flash with tp=4 gives inter=320 (320%128=64 != 0, crash)
and with tp=8 gives inter=160 (160%64=32 != 0, crash).

Fix: in process_weights_after_loading, pad inter_dim before
shuffle_weights() using alignment = 64 if inter<=192 else 128:
  - inter=160 -> 192  (tp=8, 192%64=0)
  - inter=320 -> 384  (tp=4, 384%128=0, 384%64=0)

Zero-padding is safe: padded rows carry zero weight so contribute
nothing to fused_moe output.

Verified 2026-04-24 on gfx950 (MI350X):
  - cos_sim >= 0.9999 vs torch reference (M=1..256)
  - tp=4 inference: 4 prompts complete, no crash, output correct

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The else branch in get_fused_moe_quant_config was shared between
block_quant (per_1x128/per_1x32) and per_tensor paths, hardcoding
block_shape=None for all. Block-quantized FP8 models should receive
block_shape=[128,128] (per_1x128) or [1,32] (per_1x32) to correctly
configure the quant config, particularly for EP paths.

Split the else branch into explicit per_1x128/per_1x32/fallback cases
and unify the fp8_w8a8_moe_quant_config call.
Three coordinated fixes in Fp8MoEMethod for per_1x128 block scale:

1. create_weights: make ValueError check padding-aware
   Compute padded_inter = ceil(inter/block_n)*block_n and check against
   padded_inter instead of raw inter, allowing tp=4 (inter=320) to pass
   while preserving the guard for truly unaligned cases.

2. _process_block_quant: zero-pad weights before shuffle_weights
   After normalize and before shuffle, zero-pad w13 from [E,2*320,H] to
   [E,2*384,H] and w2 from [E,H,320] to [E,H,384], mirroring the BF16
   approach in UnquantizedFusedMoEMethod.process_weights_after_loading.
   Padding zeros contribute 0 to GEMM output (dequant(0, scale)=0).
   Scale tensors already use ceil(inter/block_n) and need no change.

3. _load_w13 / _load_w2: fix scale TP sharding floor→ceil (root cause)
   The per_1x128 scale for full inter=1280 has 10 N-blocks. TP=4 sharding
   with floor gives 10//4=2 blocks per rank; the 3rd (partial) block is
   never copied and stays at the torch.ones() init value of 1.0.
   With scale=1.0 instead of ~0.0002, dequant amplifies by ~5000×
   causing complete garbage output despite correct weight loading.
   Fix: use ceil division and add narrow() bounds protection for the
   last rank which may have fewer elements than the ceil size.
   Safe for tp=2 (10/2=5 exact, ceil==floor) and tp=1 (no sharding).

Verification:
  FP8 tp=4: 4 prompts, TTFT=92ms, TPOT=14ms, coherent output ✅
  BF16 tp=4 regression: TTFT=76-77ms, coherent output ✅
  FP8 tp=2 regression: TTFT=86ms, coherent output ✅
…ding

With NPerBlock=64 CK kernel support, inter_dim=320 (tp=4) is 64-aligned
and no longer requires zero-padding to 384. Changed align from
'64 if inter<=192 else block_n' to always 64, so:
- tp=4 (inter=320): 320%64=0 -> no padding (was 320->384, saved 17% compute)
- tp=8 (inter=160): 160%64=32 -> pad to 192 (unchanged)
- tp=2 (inter=640): 640%64=0 -> no padding (unchanged)

Scale tensor shape (ceil(320/128)=3) unchanged; no re-quantization needed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Stage2 KPerBlock=64 is not compilable on gfx950 (FP8 mfma KPack=32
constraint). Since stage1 output and stage2 weight K must match,
both w13 and w2 require the same inter_dim padding. Restoring:
  align = 64 if inter_dim <= 192 else block_n (=128)

Added comment explaining why full no-padding is currently blocked.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…onfigs

_process_block_quant used 'align = 64 if inter_dim <= 192 else block_n',
copied from the BF16 path. For FP8 blockscale this is wrong:
- FP8 stage2 only has KPerBlock=128 (KPack=32 mfma constraint prevents KPerBlock=64)
- align=64 gives inter_pad=192 for tp=8 (inter=160), but 192 % 128 = 64 != 0
- device_moe_gemm_blockscale.hpp L448 rejects K % KPerBlock != 0 → kernel fails

Fix: always use align = block_n (=128 for per_1x128), so inter_pad is always
a multiple of 128 and stage2 KPerBlock=128 dispatch succeeds:
  tp=2: inter=640 → 640 (no padding, unchanged)
  tp=4: inter=320 → 384 (unchanged)
  tp=8: inter=160 → 256 (was 192, now correctly aligned)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When the per_1x128 scale block count is smaller than tp_size (observed
on Step-3.5-Flash-FP8 at tp=8 with inter_dim=1280 → D=10), the ceil
split leaves trailing ranks with start >= D so narrow(start, size) hits
size<0 and crashes weight load. Skip narrow + copy_ for those ranks.

For fp8 scale tensors (torch.ones() initialised in
Fp8MoEMethod._create_weights), additionally zero the rank's slot before
the early return. Otherwise the downstream fp8 dequant multiplies the
(uninitialised) fp8 weight by stale 1.0 instead of the correct
quantization scale, contaminating the column gather / row reduction and
producing garbled output. Matches MXFP4 scale init (moe.py:776,813).

Verified on stepfun-ai/Step-3.5-Flash-FP8 (gfx942 / MI308X):
- tp=8 A1/A2/A4 PASS — 4/4 prompts coherent (was: weight-load crash
  pre-patch; was: garbled output with early-return-only)
- tp=2/tp=4 A1/A2/A3 PASS — no regression, zero-trigger confirmed
  (D=10, starts=[0,3,6,9] for tp=4, starts=[0,5] for tp=2 — all < D)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- black: reformat 6 PR-touched files
- ruff (atom/models/step3p5.py): drop unused `Any` import,
  un-f-string two no-placeholder debug prints, remove unused
  `clamp_limit` and `swiglu_limits` locals in
  `Step3p5DecoderLayer.__init__`
…ds_kv 32 -> per-layer 4); fixes EP/TP e2e on PR#641
@LJ-underdog LJ-underdog changed the title Enable Step3.5 FP8 EP+pad support for Step-3.5-Flash-FP8 Jun 5, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants