Skip to content

[Feat] Support DSv3.2 with sglang-ATOM plugin and sglang sparse MLA#1094

Open
qichu-yun wants to merge 6 commits into
mainfrom
support_dsv3.2
Open

[Feat] Support DSv3.2 with sglang-ATOM plugin and sglang sparse MLA#1094
qichu-yun wants to merge 6 commits into
mainfrom
support_dsv3.2

Conversation

@qichu-yun
Copy link
Copy Markdown
Contributor

@qichu-yun qichu-yun commented Jun 5, 2026

Motivation

This PR adds DeepSeek-V3.2 support for the SGLang ATOM plugin with FP8 KV cache and sparse MLA enabled.

The goal is to make the SGLang path run DeepSeek-V3.2 correctly with its sparse MLA indexer, including prefill and decode, while keeping SGLang-specific logic isolated from the vLLM plugin and avoiding unnecessary changes to other model paths.

Technical Details

  • Adds SGLang plugin integration for DeepSeek-V3.2 sparse MLA and FP8 KV cache.
  • Introduces a SGLang-specific sparse MLA indexer backend under atom/plugin/sglang/attention_backend/sparse_mla_indexer.py.
  • Fixes sparse MLA prefill KV mapping for SGLang’s request/token table layout.
  • Fixes FP8 sparse MLA decode by converting sparse top-k request-local indices into global paged KV indices.
  • Patches the SGLang DeepSeek MLA path to allocate and pass the expected 2D top-k index buffer.
  • Keeps sparse MLA decode logic scoped to the SGLang sparse backend instead of the shared full-attention backend.
  • Adds the minimal RoPE writeback needed for the non-fused DeepSeek indexer path.
  • Preserves the existing SGLang plugin fallback for max_num_seqs when SGLang does not pass max_running_requests.

Test Plan

server:

export AITER_QUICK_REDUCE_QUANTIZATION=INT4
export SGLANG_AITER_FP8_PREFILL_ATTN=0
export SGLANG_USE_AITER=1
export ATOM_ENABLE_DS_QKNORM_QUANT_FUSION=1

model_path=/shared/data/amd_int/models/DeepSeek-V3.2

export SGLANG_EXTERNAL_MODEL_PACKAGE=atom.plugin.sglang.models

TORCHINDUCTOR_COMPILE_THREADS=128 python3 -m sglang.launch_server \
    --model-path $model_path \
    --host localhost \
    --port 8000 \
    --trust-remote-code \
    --tp-size 8 \
    --mem-fraction-static 0.8 \
    --disable-radix-cache \
    --kv-cache-dtype fp8_e4m3 \

benchmark:

set -euo pipefail
 
MODEL=${MODEL:-/shared/data/amd_int/models/deepseek-ai/DeepSeek-R1-0528-MXFP4-v2}
BASE_URL=${BASE_URL:-http://localhost:9000}
 
RANGE_RATIO=${RANGE_RATIO:-1.0}
 
ISL=${ISL:-8192}
OSL=${OSL:-1}
CON=${CON:-64}
NUM=${NUM:-$(( CON * 3 ))}
RESULT_DIR=${RESULT_DIR:-./}
RESULT_FILENAME=${RESULT_FILENAME:-}
CONFIG_NAME=${CONFIG_NAME:-manual}
 
mkdir -p "${RESULT_DIR}"
 
result_filename_args=()
if [[ -n "${RESULT_FILENAME}" ]]; then
    result_filename_args+=(--result-filename="${RESULT_FILENAME}")
fi

metadata_args=(config=${CONFIG_NAME} isl=${ISL} osl=${OSL} concurrency=${CON})
if [[ -n "${SERVER_CONFIG:-}" ]]; then
    metadata_args+=(server_config=${SERVER_CONFIG})
fi
if [[ -n "${MODE:-}" ]]; then
    metadata_args+=(mode=${MODE})
fi
if [[ -n "${TOPOLOGY:-}" ]]; then
    metadata_args+=(topology=${TOPOLOGY})
fi

echo "Model=${MODEL}"
echo "ISL=${ISL}, OSL=${OSL}, NUM=${NUM}, CON=${CON} RANGE_RATIO=${RANGE_RATIO}"
echo "CONFIG_NAME=${CONFIG_NAME}, RESULT_DIR=${RESULT_DIR}, RESULT_FILENAME=${RESULT_FILENAME}"

sleep 2

echo "Starting bench with model: ${MODEL}"
python /shared/amdgpu/home/qichu_qle/qichu/bench_serving/benchmark_serving.py \
    --model=$MODEL \
    --backend=sglang \
    --base-url=${BASE_URL} \
    --dataset-name=random \
    --random-input-len=$ISL \
    --random-output-len=$OSL \
    --random-range-ratio ${RANGE_RATIO} \
    --num-prompts=${NUM} \
    --max-concurrency=${CON} \
    --num-warmups="$(( CON ))" \
    --request-rate=inf \
    --ignore-eos \
    --save-result \
    --metadata "${metadata_args[@]}" \
    --percentile-metrics="ttft,tpot,itl,e2el" \
    --result-dir=${RESULT_DIR} \
    "${result_filename_args[@]}"

Accuracy:

#!/bin/bash
set -euo pipefail
 
addr=localhost
port=8000
url=http://${addr}:${port}/v1/completions

model_path=/shared/data/amd_int/models/DeepSeek-V3.2
num_concurrent="${LM_EVAL_CONCURRENT:-64}"
max_gen_toks="${LM_EVAL_MAX_GEN_TOKS:-512}"

lm_eval --model local-completions \
    --model_args "{\"base_url\": \"${url}\", \"model\": \"${model_path}\", \"num_concurrent\": ${num_concurrent}, \"max_retries\": 1, \"max_gen_toks\": ${max_gen_toks}, \"tokenized_requests\": false}" \
    --tasks gsm8k \
    --batch_size auto \
    --num_fewshot 5 \
    --trust_remote_code \

Test Result

Accuracy Results

1.bf16 cache
image

2.--kv-cache-dtype fp8_e4m3
image

Performance Results

Input Len Output Len Concurrency Num Prompts Completed Duration (s) Req Throughput (req/s) Output Throughput (tok/s) Total Throughput (tok/s) Mean TTFT (ms) P99 TTFT (ms) Mean TPOT (ms) P99 TPOT (ms) Mean E2E (ms) P99 E2E (ms)
8192 1024 4 12 12 72.68 0.165 169.08 1521.73 2121.35 2267.85 21.61 22.03 24223.43 24241.20
8192 1024 16 48 48 108.53 0.442 452.88 4075.91 5900.30 8810.42 29.59 33.73 36170.59 36337.13
8192 1024 64 192 192 225.87 0.850 870.46 7834.10 19139.47 34960.45 54.87 71.61 75269.56 75533.80
1024 1024 4 12 12 62.58 0.192 196.37 392.74 412.34 415.15 19.99 20.03 20857.40 20896.80
1024 1024 16 48 48 75.33 0.637 652.45 1304.90 1098.67 1102.93 23.47 23.50 25107.53 25136.41
1024 1024 64 192 192 114.44 1.678 1717.94 3435.89 3335.43 4106.09 34.02 36.08 38137.20 38220.62

Submission Checklist

@qichu-yun qichu-yun marked this pull request as draft June 5, 2026 04:58
Keep DeepSeek-V3.2 sparse MLA decode logic in the SGLang sparse backend so the shared full-attention backend stays focused on common model paths.
@qichu-yun qichu-yun marked this pull request as ready for review June 5, 2026 08:52
Comment thread atom/models/deepseek_v2.py Outdated
k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
)
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe)
q[..., : self.rope_dim] = q_pe
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why need this? our rope are inplace op

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I check it. The current RoPE implementation uses rope_cached_positions_2c_fwd_inplace, and q_pe/k_pe are views into q/k, so the explicit writeback is indeed redundant. I removed these two lines. Thank you!

RoPE updates the sliced q/k tensors in place, so the explicit assignment back into the parent tensors is unnecessary.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants