Skip to content

[Feat] Support input_rmsnorm_quant & qknorm_quant fused for Atom dpsk v2 model#1055

Open
qichu-yun wants to merge 5 commits into
mainfrom
atom_fuse
Open

[Feat] Support input_rmsnorm_quant & qknorm_quant fused for Atom dpsk v2 model#1055
qichu-yun wants to merge 5 commits into
mainfrom
atom_fuse

Conversation

@qichu-yun
Copy link
Copy Markdown
Contributor

@qichu-yun qichu-yun commented Jun 3, 2026

Motivation

Enable DeepSeek V2 FP4 input_rmsnorm_quant and qknorm_quant fusion in ATOM across both Triton and non-Triton GEMM paths. This reduces separate RMSNorm + quant kernel launches while keeping FP4 activation scale layouts and cached KV projection paths correct.

Technical Details

  • Enable qknorm quant fusion for FP8/FP4 without requiring Triton GEMM.
  • Route FP4 qknorm and input RMSNorm quant through fused RMSNorm+MXFP4 quant kernels.
  • Select the correct MXFP4 activation/scale layout for preshuffle Triton, non-shuffle Triton, and ASM GEMM paths.
  • Add FP4-safe cached KV projection gather fallback for MLA attention.
  • Align the SGLang MLA prepare path with the native ATOM qknorm quant fusion behavior.

Test Plan

server:

#!/bin/bash
export ATOM_USE_TRITON_MOE=1
export ATOM_ENABLE_DS_QKNORM_QUANT_FUSION=1
export ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION=1

rm -rf /root/.cache/ /tmp/torchinductor_root
model_path=/shared/data/amd_int/models/deepseek-ai/DeepSeek-R1-0528-MXFP4-v2
python -m atom.entrypoints.openai_server \
    --model $model_path \
    --host localhost \
    --server-port 8000 \
    --tensor-parallel-size 8 \
    --kv_cache_dtype fp8 \
    --gpu-memory-utilization 0.8 \
    --no-enable_prefix_caching

client:

#!/bin/bash
set -euo pipefail
 
addr=localhost
port=8000
url=http://${addr}:${port}/v1/completions
 
model_path="/shared/data/amd_int/models/deepseek-ai/DeepSeek-R1-0528-MXFP4-v2"
num_concurrent="${LM_EVAL_CONCURRENT:-16}"
max_gen_toks="${LM_EVAL_MAX_GEN_TOKS:-512}"


lm_eval --model local-completions \
    --model_args "{\"base_url\": \"${url}\", \"model\": \"${model_path}\", \"num_concurrent\": ${num_concurrent}, \"max_retries\": 1, \"max_gen_toks\": ${max_gen_toks}, \"tokenized_requests\": false}" \
    --tasks gsm8k \
    --batch_size auto \
    --num_fewshot 5 \
    --trust_remote_code \

Test Result

For preshuffle FP4 gemm kernel:
image

For non-preshuffle FP4 gemm kernel:
export ATOM_USE_FP4_NON_SHUFFLE_TRITON_GEMM=1
image

Accuracy:

Backend Model NON_SHUFFLE=0 NON_SHUFFLE=1 Delta
atom DeepSeek-R1-0528-MXFP4-v2 0.9439 0.9469 +0.0030
atom DeepSeek-R1-0528-MXFP4-MTP-MoEFP4 0.9507 0.9553 +0.0046
atom DeepSeek-R1-0528 0.9629 0.9629 +0.0000

Submission Checklist

@qichu-yun qichu-yun changed the title Atom fuse [Feat] Support input_rmsmorn_quant & qknorm_quant fused for Atom dpsk v2 model Jun 3, 2026
@qichu-yun qichu-yun changed the title [Feat] Support input_rmsmorn_quant & qknorm_quant fused for Atom dpsk v2 model [Feat] Support input_rmsnorm_quant & qknorm_quant fused for Atom dpsk v2 model Jun 3, 2026
Comment thread atom/model_ops/attention_mla.py Outdated
flat_cache = kv_cache.view(-1, self.kv_lora_rank + self.qk_rope_head_dim)
# ``kv_indices`` is a reusable workspace; only the current k/v output
# length is valid for this call.
valid_indices = kv_indices[: k_out.shape[0]].to(torch.long)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have deleted it....

@qichu-yun qichu-yun changed the title [Feat] Support input_rmsnorm_quant & qknorm_quant fused for Atom dpsk v2 model [Draft][Feat] Support input_rmsnorm_quant & qknorm_quant fused for Atom dpsk v2 model Jun 3, 2026
@zhuyuhua-v zhuyuhua-v marked this pull request as draft June 3, 2026 12:56
@qichu-yun qichu-yun changed the title [Draft][Feat] Support input_rmsnorm_quant & qknorm_quant fused for Atom dpsk v2 model [WIP][Feat] Support input_rmsnorm_quant & qknorm_quant fused for Atom dpsk v2 model Jun 3, 2026
@qichu-yun qichu-yun marked this pull request as ready for review June 4, 2026 08:26
@qichu-yun qichu-yun changed the title [WIP][Feat] Support input_rmsnorm_quant & qknorm_quant fused for Atom dpsk v2 model [Feat] Support input_rmsnorm_quant & qknorm_quant fused for Atom dpsk v2 model Jun 4, 2026
qichu-yun added 2 commits June 4, 2026 04:40
… v2 model

Support DeepSeek FP4 qknorm fusion and scoped input RMSNorm quant fusion through fused RMSNorm+MXFP4 kernels while preserving cached KV fallback handling, GEMM scale layouts, and mixed-quant/MTP model stability.
Route SGLang FP4 qknorm quant through the same activation scale layout and q_b_proj backend dispatch as native ATOM to avoid precision loss when the non-shuffle Triton GEMM path is disabled.
qichu-yun added 3 commits June 7, 2026 22:06
Use the extended gather_kv_b_proj path for cached MLA prefill so FP4 weights no longer fall back to Python gather/projection.
Limit the newly enabled non-Triton FP4 input RMSNorm quant path to DeepSeek V2/V3 models so wrapper models keep their existing Triton-gated behavior.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants