Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 120 additions & 11 deletions examples/ops/dispatch_combine/test_dispatch_combine_internode.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,14 @@ def __init__(
quant_type="none",
dtype=torch.bfloat16,
hidden_dim=7168,
drop_rank=-1,
timeout_us=0,
):
self.rank = rank
self.gpu_per_node = gpu_per_node
self.world_size = world_size
self.drop_rank = drop_rank
self.timeout_us = timeout_us
self.config = mori.ops.EpDispatchCombineConfig(
data_type=dtype,
rank=self.rank,
Expand Down Expand Up @@ -100,6 +104,10 @@ def setup(self):
self.rng = torch.Generator(device=self.device)
self.rng.manual_seed(999)

self.active_ranks = torch.ones(
(self.world_size,), dtype=torch.int32, device=self.device
)

def cleanup(self):
mori.shmem.shmem_finalize()
dist.destroy_process_group()
Expand Down Expand Up @@ -256,6 +264,9 @@ def count_token_num(self, all_rank_indices):
)

for src_rank, indices in enumerate(all_rank_indices):
if src_rank == self.drop_rank:
continue

src_node = src_rank // self.config.gpu_per_node

# Map expert IDs to rank IDs
Expand Down Expand Up @@ -292,20 +303,63 @@ def count_token_num(self, all_rank_indices):
# print("Rank counts to other nodes:", rank_counts_remote_send)
return rank_counts, rank_counts_remote_recv, rank_counts_remote_send

def run_dispatch(self, op, token, weights, scales, indices, is_active=True):
    """Run one dispatch on ``op``, optionally simulating a dropped rank.

    Args:
        op: the EpDispatchCombine op under test.
        token: this rank's input tokens.
        weights: per-token expert weights.
        scales: per-token quantization scales (may be None/empty upstream).
        indices: per-token expert indices.
        is_active: when False, skip the dispatch call entirely (simulates
            this rank having dropped out for the elastic-EP test) and return
            empty placeholder tensors shaped like a real dispatch result.

    Returns:
        The tuple produced by ``op.dispatch`` / ``op.dispatch_send``:
        (output, weights, scales, indices, recv_num_token)-like structure.
    """
    # Elastic-EP kwargs are only forwarded when a timeout is configured.
    kwargs = {}
    if self.timeout_us > 0:
        kwargs["active_ranks"] = self.active_ranks
        kwargs["timeout_us"] = self.timeout_us

    if not is_active:
        # Simulate dropout by not calling dispatch; return empty tensors with
        # the same structure/dtypes as a real dispatch so callers can proceed.
        return (
            torch.empty(
                (0, self.config.hidden_dim),
                dtype=self.config.data_type,
                device=self.device,
            ),
            torch.empty(
                (0, op.config.num_experts_per_token),
                dtype=torch.float32,
                device=self.device,
            ),
            torch.empty((0,), dtype=torch.float32, device=self.device),
            torch.empty(
                (0, op.config.num_experts_per_token),
                dtype=torch.int32,
                device=self.device,
            ),
            torch.zeros((1,), dtype=torch.int32, device=self.device),
        )

    if op.config.kernel_type is mori.ops.EpDispatchCombineKernelType.AsyncLL:
        # NOTE(review): elastic-EP kwargs are not forwarded on the AsyncLL
        # path; confirm whether dispatch_send/dispatch_recv accept
        # active_ranks/timeout_us before enabling them here.
        ret = op.dispatch_send(token, weights, scales, indices)
        op.dispatch_recv()
    else:
        ret = op.dispatch(token, weights, scales, indices, **kwargs)
    return ret
Comment on lines 334 to 339
Copy link

Copilot AI Feb 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For AsyncLL kernel type, the elastic parameters are not passed to dispatch_send/dispatch_recv. While the AsyncLL kernels do support elastic EP (they have IsRankActive checks), the dispatch_send/combine_send/dispatch_recv/combine_recv methods already accept active_ranks and timeout_us parameters. Consider passing these parameters through the kwargs for AsyncLL as well to enable elastic EP support for this kernel type in the example.

Copilot uses AI. Check for mistakes.

def run_combine(self, op, token, weights, indices, is_active=True):
    """Run one combine on ``op``, optionally simulating a dropped rank.

    Args:
        op: the EpDispatchCombine op under test.
        token: tokens to combine (typically the dispatch output).
        weights: per-token weights, or None.
        indices: per-token expert indices (used here only to size the
            placeholder output when inactive).
        is_active: when False, skip the combine call entirely (simulates
            this rank having dropped out for the elastic-EP test) and return
            an empty-shaped output with a None weight.

    Returns:
        The (combine_output, combine_output_weight) pair from ``op.combine``
        / ``op.combine_send``.
    """
    # Elastic-EP kwargs are only forwarded when a timeout is configured.
    kwargs = {}
    if self.timeout_us > 0:
        kwargs["active_ranks"] = self.active_ranks
        kwargs["timeout_us"] = self.timeout_us

    if not is_active:
        # Simulate dropout by not calling combine; mirror the shape callers
        # expect (one row per original token) without touching the op.
        return (
            torch.empty(
                (indices.shape[0], self.config.hidden_dim),
                dtype=self.config.data_type,
                device=self.device,
            ),
            None,
        )

    if op.config.kernel_type is mori.ops.EpDispatchCombineKernelType.AsyncLL:
        # NOTE(review): elastic-EP kwargs are not forwarded on the AsyncLL
        # path; confirm whether combine_send/combine_recv accept
        # active_ranks/timeout_us before enabling them here.
        ret = op.combine_send(token, weights, indices)
        op.combine_recv()
    else:
        ret = op.combine(token, weights, indices, **kwargs)
    return ret

def run_test_once(self, op, test_data, error_round, round):
Expand All @@ -317,6 +371,8 @@ def run_test_once(self, op, test_data, error_round, round):
all_rank_scales,
) = test_data

is_active = self.rank != self.drop_rank

(
dispatch_output,
dispatch_weights,
Expand All @@ -329,9 +385,13 @@ def run_test_once(self, op, test_data, error_round, round):
all_rank_weights[self.rank],
all_rank_scales[self.rank],
all_rank_indices[self.rank],
is_active=is_active,
)
torch.cuda.synchronize()

if not is_active:
return

rank_counts, _, _ = self.count_token_num(all_rank_indices)

src_token_pos = op.get_dispatch_src_token_pos().tolist()
Expand All @@ -351,6 +411,9 @@ def run_test_once(self, op, test_data, error_round, round):
for i, src_token_id in enumerate(src_token_pos):
src_pe = src_token_id // max_num_token_to_send_per_rank
src_tok_id = src_token_id % max_num_token_to_send_per_rank
assert (
src_pe != self.drop_rank
), f"Should not receive tokens from dropped rank {self.drop_rank}"
if self.config.data_type is torch.float4_e2m1fn_x2:
is_pass = torch.equal(
dispatch_output[i].view(torch.uint8),
Expand Down Expand Up @@ -393,11 +456,13 @@ def run_test_once(self, op, test_data, error_round, round):
(idx // self.config.num_experts_per_rank)
for idx in all_rank_indices[self.rank][i].cpu().tolist()
]
unique_pes = len(set(pes))

valid_pes = [p for p in pes if p != self.drop_rank]
unique_pes = len(set(valid_pes))
unique_innode_pes = len(
[
pe
for pe in set(pes)
for pe in set(valid_pes)
if (pe // self.gpu_per_node == self.rank // self.gpu_per_node)
]
)
Expand Down Expand Up @@ -505,6 +570,7 @@ def stress_dispatch_combine(self):
all_rank_weights,
all_rank_scales,
) = test_data_list[i % num_test_data]
is_active = self.rank != self.drop_rank
(
dispatch_output,
dispatch_weights,
Expand All @@ -517,9 +583,14 @@ def stress_dispatch_combine(self):
all_rank_weights[self.rank],
all_rank_scales[self.rank],
all_rank_indices[self.rank],
is_active=is_active,
)
combine_output, combine_output_weight = self.run_combine(
op, dispatch_output, None, all_rank_indices[self.rank]
op,
dispatch_output,
None,
all_rank_indices[self.rank],
is_active=is_active,
)
if i % sync_interval == 0:
torch.cuda.synchronize()
Expand All @@ -539,6 +610,7 @@ def stress_dispatch_combine(self):
all_rank_scales,
) = test_data
g = torch.cuda.CUDAGraph()
is_active = self.rank != self.drop_rank
with torch.cuda.graph(g):
(
dispatch_output,
Expand All @@ -552,9 +624,14 @@ def stress_dispatch_combine(self):
all_rank_weights[self.rank],
all_rank_scales[self.rank],
all_rank_indices[self.rank],
is_active=is_active,
)
combine_output, combine_output_weight = self.run_combine(
op, dispatch_output, None, all_rank_indices[self.rank]
op,
dispatch_output,
None,
all_rank_indices[self.rank],
is_active=is_active,
)
torch.cuda.synchronize()

Expand All @@ -577,7 +654,9 @@ def run_bench_once(self, max_num_token, op, test_data, repeat=10):
all_rank_scales,
) = test_data

is_active = self.rank != self.drop_rank
warmup_rounds = 3
total_recv_num_token = 0
for i in range(warmup_rounds):
(
dispatch_output,
Expand All @@ -591,13 +670,19 @@ def run_bench_once(self, max_num_token, op, test_data, repeat=10):
all_rank_weights[self.rank],
all_rank_scales[self.rank],
all_rank_indices[self.rank],
is_active=is_active,
)
if i == warmup_rounds - 1:
# Read totalRecvTokenNum after dispatch but before combine resets it
torch.cuda.synchronize()
total_recv_num_token = dispatch_recv_num_token[0].item()
if is_active:
total_recv_num_token = dispatch_recv_num_token[0].item()
combine_output, combine_output_weight = self.run_combine(
op, dispatch_output, None, all_rank_indices[self.rank]
op,
dispatch_output,
None,
all_rank_indices[self.rank],
is_active=is_active,
)
torch.cuda.synchronize()
total_rdma_recv_num_token = (
Expand Down Expand Up @@ -630,10 +715,15 @@ def run_bench_once(self, max_num_token, op, test_data, repeat=10):
all_rank_weights[self.rank],
all_rank_scales[self.rank],
all_rank_indices[self.rank],
is_active=is_active,
)
events[2 * i + 1].record()
combine_output, combine_output_weight = self.run_combine(
op, dispatch_output, None, all_rank_indices[self.rank]
op,
dispatch_output,
None,
all_rank_indices[self.rank],
is_active=is_active,
)
events[2 * i + 2].record()
torch.cuda.synchronize()
Expand Down Expand Up @@ -977,6 +1067,8 @@ def test_dispatch_combine(
quant_type="none",
cmd="test",
sweep_token_interval=64,
drop_rank=-1,
timeout_us=0,
):
world_size = num_node * gpu_per_node
node_rank = int(os.environ["RANK"])
Expand All @@ -992,6 +1084,9 @@ def test_dispatch_combine(
num_qp,
quant_type,
dtype,
hidden_dim=7168,
drop_rank=drop_rank,
timeout_us=timeout_us,
)
test_case.setup()
if cmd == "test":
Expand Down Expand Up @@ -1075,6 +1170,18 @@ def test_dispatch_combine(
"'fp8_direct_cast' is the current BF16<->FP8 direct cast path."
),
)
parser.add_argument(
"--drop-rank",
type=int,
default=-1,
help="Rank ID to simulate dropout to test elastic EP mechanism",
)
parser.add_argument(
"--timeout-us",
type=int,
default=0,
help="Timeout in microseconds for elastic EP mechanism polling",
)
args_cli = parser.parse_args()

if __name__ == "__main__":
Expand All @@ -1095,6 +1202,8 @@ def test_dispatch_combine(
args_cli.quant_type,
args_cli.cmd,
args_cli.sweep_token_interval,
args_cli.drop_rank,
args_cli.timeout_us,
),
nprocs=gpu_per_node,
join=True,
Expand Down
75 changes: 71 additions & 4 deletions include/mori/core/transport/p2p/device_primitives.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -799,8 +799,9 @@ __forceinline__ __device__ void WarpCastBf16ToCombineInternalFp8(
// Note: when T != hip_bfloat16, this function is a no-op.
// Callers should guard with if constexpr or ensure T is hip_bfloat16.
#else
static_assert(!sizeof(T*), "WarpCastBf16ToCombineInternalFp8 requires FP8 type support "
"(MORI_FP8_TYPE_OCP_ENABLED or MORI_FP8_TYPE_FNUZ_ENABLED)");
static_assert(!sizeof(T*),
"WarpCastBf16ToCombineInternalFp8 requires FP8 type support "
"(MORI_FP8_TYPE_OCP_ENABLED or MORI_FP8_TYPE_FNUZ_ENABLED)");
#endif
}

Expand Down Expand Up @@ -1033,10 +1034,76 @@ __forceinline__ __device__ void WarpAccumCombineInternalFp8ToBf16(
// Note: when T != hip_bfloat16, this function is a no-op.
// Callers should guard with if constexpr or ensure T is hip_bfloat16.
#else
static_assert(!sizeof(T*), "WarpAccumCombineInternalFp8ToBf16 requires FP8 type support "
"(MORI_FP8_TYPE_OCP_ENABLED or MORI_FP8_TYPE_FNUZ_ENABLED)");
static_assert(!sizeof(T*),
"WarpAccumCombineInternalFp8ToBf16 requires FP8 type support "
"(MORI_FP8_TYPE_OCP_ENABLED or MORI_FP8_TYPE_FNUZ_ENABLED)");
#endif
}

/* ---------------------------------------------------------------------------------------------- */
/* Elastic EP Utilities */
/* ---------------------------------------------------------------------------------------------- */
// Liveness query for elastic EP: with no table supplied (nullptr) every rank
// is treated as alive; otherwise rank `rank` is alive iff its slot in
// `activeRanks` holds a non-zero value.
inline __device__ bool IsRankActive(const int32_t* activeRanks, int rank) {
  return (activeRanks == nullptr) ||
         (AtomicLoadRelaxedSystem(const_cast<int32_t*>(activeRanks) + rank) != 0);
}

// Record that `rank` has dropped out by zeroing its slot in the liveness
// table. A null table (elastic EP disabled) is a no-op.
inline __device__ void MarkRankInactive(int32_t* activeRanks, int rank) {
  if (activeRanks != nullptr) {
    AtomicStoreRelaxedSystem(&activeRanks[rank], int32_t{0});
  }
}

// Zero the liveness slots for the `count` consecutive ranks starting at
// `startRank`. A null table (elastic EP disabled) is a no-op.
inline __device__ void MarkRanksInactive(int32_t* activeRanks, int startRank, int count) {
  if (activeRanks == nullptr) return;
  for (int r = startRank; r < startRank + count; ++r) {
    AtomicStoreRelaxedSystem(&activeRanks[r], int32_t{0});
  }
}

// Spin until the value at `addr` rises strictly above `val`, the watched rank
// is (or becomes) marked inactive, or `timeoutTicks` ticks of wall_clock64()
// elapse. On timeout the watched rank is marked inactive in `activeRanks`.
// A negative `timeoutTicks` selects blocking mode: spin forever with no
// liveness or timeout checks. Returns the last value loaded from `addr`
// (callers must re-check the condition when a timeout/inactive exit is
// possible, since the returned value may still be <= val).
template <typename T>
inline __device__ T WaitUntilGreaterThanOrTimeoutSystem(T* addr, T val, int64_t timeoutTicks,
                                                        int32_t* activeRanks, int watchedRank) {
  // Blocking mode: no timeout, no liveness check — plain spin-wait.
  if (timeoutTicks < 0) {
    T got;
    do {
      got = AtomicLoadRelaxedSystem(addr);
    } while (got <= val);
    return got;
  }

  const unsigned long long start = wall_clock64();
  while (true) {
    // Check the value first so a satisfied condition wins over a concurrent
    // timeout/inactive transition.
    T got = AtomicLoadRelaxedSystem(addr);
    if (got > val) return got;
    // Another waiter may already have declared the watched rank dead.
    if (!IsRankActive(activeRanks, watchedRank)) return got;
    const unsigned long long now = wall_clock64();
    // Unsigned subtraction stays correct across counter wraparound.
    if (now - start > static_cast<unsigned long long>(timeoutTicks)) {
      MarkRankInactive(activeRanks, watchedRank);
      return got;
    }
  }
}

// Spin until the value at `addr` equals `val`, the watched rank is (or
// becomes) marked inactive, or `timeoutTicks` ticks of wall_clock64() elapse.
// On timeout the watched rank is marked inactive in `activeRanks`. A negative
// `timeoutTicks` selects blocking mode: spin forever with no liveness or
// timeout checks. Returns true iff equality was observed; false means the
// wait was abandoned (rank inactive or timed out).
template <typename T>
inline __device__ bool WaitUntilEqualsOrTimeoutSystem(T* addr, T val, int64_t timeoutTicks,
                                                      int32_t* activeRanks, int watchedRank) {
  // Blocking mode: no timeout, no liveness check — plain spin-wait.
  if (timeoutTicks < 0) {
    while (AtomicLoadRelaxedSystem(addr) != val) {
    }
    return true;
  }

  const unsigned long long start = wall_clock64();
  while (true) {
    // Check the value first so a satisfied condition wins over a concurrent
    // timeout/inactive transition.
    if (AtomicLoadRelaxedSystem(addr) == val) return true;
    // Another waiter may already have declared the watched rank dead.
    if (!IsRankActive(activeRanks, watchedRank)) return false;
    const unsigned long long now = wall_clock64();
    // Unsigned subtraction stays correct across counter wraparound.
    if (now - start > static_cast<unsigned long long>(timeoutTicks)) {
      MarkRankInactive(activeRanks, watchedRank);
      return false;
    }
  }
}

} // namespace core
} // namespace mori
Loading
Loading