Skip to content

fix: enable TBO compute/comm overlap in Deepseek V4 and fix mooncake gather hang for PD prefill#1121

Open
ZhangLirong-amd wants to merge 2 commits into
mainfrom
zlr/prefill-routing-side-stream-overlap
Open

fix: enable TBO compute/comm overlap in Deepseek V4 and fix mooncake gather hang for PD prefill#1121
ZhangLirong-amd wants to merge 2 commits into
mainfrom
zlr/prefill-routing-side-stream-overlap

Conversation

@ZhangLirong-amd
Copy link
Copy Markdown
Contributor

@ZhangLirong-amd ZhangLirong-amd commented Jun 7, 2026

Summary

Two fixes for DeepSeek-V4-Pro PD-disaggregation prefill with --enable-tbo:

  1. Enable real compute/comm overlap.

  2. Fix mooncake gather hang under high-concurrency TBO prefill.

Trace results

DeepSeek-V4-Pro, TP=8, --enable-dp-attention --enable-tbo, c=256, ISL=8192:

TBO NCCL overlap with compute
before 3.8 %
after 91.7 %
image

Test plan

  • PD prefill bench (c=256, ISL=8192, OSL=1024) runs to completion without hang
  • Trace shows 91.7 % of TBO NCCL overlapping with compute on the new node
  • CI: existing model accuracy tests

Without this sync, the RDMA write can race the still-in-flight gather
kernel that fills the staging buffer, causing GPU page faults under
high-concurrency TBO prefill (DeepSeek-V4-Pro, mesh router, c=256).
Copilot AI review requested due to automatic review settings June 7, 2026 09:54
@ZhangLirong-amd ZhangLirong-amd changed the title fix: enable TBO compute/comm overlap and fix mooncake gather hang for PD prefill fix: enable TBO compute/comm overlap in Deepseek V4 and fix mooncake gather hang for PD prefill Jun 7, 2026
@ZhangLirong-amd ZhangLirong-amd force-pushed the zlr/prefill-routing-side-stream-overlap branch from 42be4f3 to 62031bb Compare June 7, 2026 09:58
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR targets DeepSeek-V4-Pro PD-disaggregation prefill stability and performance under TBO by (1) moving specific DP/TP collectives off the main compute stream to improve true compute/comm overlap, and (2) preventing a Mooncake RDMA race by synchronizing after staging-buffer gathers.

Changes:

  • Added a per-device auxiliary CUDA stream helper and routed DP input_ids all-gather + MoE combine_outputs TP all-reduce through it to avoid NCCL interleaving on the compute lane during TBO ping-pong.
  • Added an explicit CUDA stream synchronization after the staging-buffer gather in Mooncake block+slot transfers to prevent GPUDirect RDMA reads racing in-flight gathers.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description
atom/models/deepseek_v4.py Adds _run_on_tbo_comm_stream and uses it for DP id all-gather + TP all-reduce to improve TBO overlap behavior.
atom/kv_transfer/disaggregation/mooncake/mooncake_connector.py Synchronizes after _gather_slot(...) to avoid RDMA reading an incompletely-populated staging buffer.
Comments suppressed due to low confidence (1)

atom/models/deepseek_v4.py:54

  • _run_on_tbo_comm_stream introduces torch.cuda.stream(...) / wait_stream(...) control flow into the eager Python graph. DeepseekV4Model is @support_torch_compile, and when MoE dual-stream is disabled (common when shared experts are fused), combine_outputs() executes in the compiled forward path. Dynamo/Inductor generally cannot trace stream-context switching reliably, which can break compilation or produce unexpected graph breaks.

Consider explicitly marking this helper as non-compilable so the rest of the model can still compile, while the stream manipulation runs eagerly at runtime.

    tensor_model_parallel_all_reduce,
)
from aiter.dist.parallel_state import (
    get_tensor_model_parallel_world_size,
)
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.ops.topk import top_k_per_row_decode, top_k_per_row_prefill
from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits
from aiter.ops.triton.fusions.fused_clamp_act_mul import (
    fused_clamp_act_mul,
)
from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits
from atom.config import (

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

The DP input_ids all-gather (hoisted into DeepseekV4ForCausalLM.forward
by #1109) and the MoE combine_outputs TP all-reduce both run on the same
stream as the main TBO compute kernels. With NCCL interleaved on that
lane the kernel queue serializes compute and NCCL, blocking the
hardware-level overlap with TBO comm_stream during ping-pong.

Move both collectives onto a per-device auxiliary stream with wait_stream
sync at both ends so the main compute lane stays free of NCCL
interleaving and can overlap with TBO comm_stream NCCL.

Trace results on DeepSeek-V4-Pro (TP=8 --enable-dp-attention --enable-tbo,
c=256 ISL=8192):
  before: 3.8 % of TBO NCCL overlaps with compute
  after:  91.7 % of TBO NCCL overlaps with compute
Copilot AI review requested due to automatic review settings June 7, 2026 10:03
@ZhangLirong-amd ZhangLirong-amd force-pushed the zlr/prefill-routing-side-stream-overlap branch from 62031bb to be7bda7 Compare June 7, 2026 10:03
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

Comment on lines +127 to +132
main = torch.cuda.current_stream()
side.wait_stream(main)
with torch.cuda.stream(side):
result = fn(*args, **kwargs)
main.wait_stream(side)
return result
Comment on lines 1060 to +1064
self._gather_slot(src_slot, producer_pool_idx)
# Synchronize on the gather kernel before NIC starts reading the
# staging buffer. Without this, the RDMA can race the still-in-flight
# gather kernel on TBO prefill (page fault under high concurrency).
torch.cuda.current_stream().synchronize()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants