diff --git a/examples/ops/dispatch_combine/test_dispatch_combine_internode.py b/examples/ops/dispatch_combine/test_dispatch_combine_internode.py index dfdc13824..d7bc071d3 100644 --- a/examples/ops/dispatch_combine/test_dispatch_combine_internode.py +++ b/examples/ops/dispatch_combine/test_dispatch_combine_internode.py @@ -50,10 +50,14 @@ def __init__( quant_type="none", dtype=torch.bfloat16, hidden_dim=7168, + drop_rank=-1, + timeout_us=0, ): self.rank = rank self.gpu_per_node = gpu_per_node self.world_size = world_size + self.drop_rank = drop_rank + self.timeout_us = timeout_us self.config = mori.ops.EpDispatchCombineConfig( data_type=dtype, rank=self.rank, @@ -100,6 +104,10 @@ def setup(self): self.rng = torch.Generator(device=self.device) self.rng.manual_seed(999) + self.active_ranks = torch.ones( + (self.world_size,), dtype=torch.int32, device=self.device + ) + def cleanup(self): mori.shmem.shmem_finalize() dist.destroy_process_group() @@ -256,6 +264,9 @@ def count_token_num(self, all_rank_indices): ) for src_rank, indices in enumerate(all_rank_indices): + if src_rank == self.drop_rank: + continue + src_node = src_rank // self.config.gpu_per_node # Map expert IDs to rank IDs @@ -292,20 +303,63 @@ def count_token_num(self, all_rank_indices): # print("Rank counts to other nodes:", rank_counts_remote_send) return rank_counts, rank_counts_remote_recv, rank_counts_remote_send - def run_dispatch(self, op, token, weights, scales, indices): + def run_dispatch(self, op, token, weights, scales, indices, is_active=True): + kwargs = {} + if self.timeout_us > 0: + kwargs["active_ranks"] = self.active_ranks + kwargs["timeout_us"] = self.timeout_us + + if not is_active: + # Simulate dropout by not calling dispatch + return ( + torch.empty( + (0, self.config.hidden_dim), + dtype=self.config.data_type, + device=self.device, + ), + torch.empty( + (0, op.config.num_experts_per_token), + dtype=torch.float32, + device=self.device, + ), + torch.empty((0,), dtype=torch.float32, device=self.device), + torch.empty( + (0, op.config.num_experts_per_token), + dtype=torch.int32, + device=self.device, + ), + torch.zeros((1,), dtype=torch.int32, device=self.device), + ) + if op.config.kernel_type is mori.ops.EpDispatchCombineKernelType.AsyncLL: ret = op.dispatch_send(token, weights, scales, indices) op.dispatch_recv() else: - ret = op.dispatch(token, weights, scales, indices) + ret = op.dispatch(token, weights, scales, indices, **kwargs) return ret - def run_combine(self, op, token, weights, indices): + def run_combine(self, op, token, weights, indices, is_active=True): + kwargs = {} + if self.timeout_us > 0: + kwargs["active_ranks"] = self.active_ranks + kwargs["timeout_us"] = self.timeout_us + + if not is_active: + # Simulate dropout by not calling combine + return ( + torch.empty( + (indices.shape[0], self.config.hidden_dim), + dtype=self.config.data_type, + device=self.device, + ), + None, + ) + if op.config.kernel_type is mori.ops.EpDispatchCombineKernelType.AsyncLL: ret = op.combine_send(token, weights, indices) op.combine_recv() else: - ret = op.combine(token, weights, indices) + ret = op.combine(token, weights, indices, **kwargs) return ret def run_test_once(self, op, test_data, error_round, round): @@ -317,6 +371,8 @@ def run_test_once(self, op, test_data, error_round, round): all_rank_scales, ) = test_data + is_active = self.rank != self.drop_rank + ( dispatch_output, dispatch_weights, @@ -329,9 +385,13 @@ def run_test_once(self, op, test_data, error_round, round): all_rank_weights[self.rank], all_rank_scales[self.rank], all_rank_indices[self.rank], + is_active=is_active, ) torch.cuda.synchronize() + if not is_active: + return + rank_counts, _, _ = self.count_token_num(all_rank_indices) src_token_pos = op.get_dispatch_src_token_pos().tolist() @@ -351,6 +411,9 @@ def run_test_once(self, op, test_data, error_round, round): for i, src_token_id in enumerate(src_token_pos): src_pe = src_token_id // max_num_token_to_send_per_rank src_tok_id = src_token_id % max_num_token_to_send_per_rank + assert ( + src_pe != self.drop_rank + ), f"Should not receive tokens from dropped rank {self.drop_rank}" if self.config.data_type is torch.float4_e2m1fn_x2: is_pass = torch.equal( dispatch_output[i].view(torch.uint8), @@ -393,11 +456,13 @@ def run_test_once(self, op, test_data, error_round, round): (idx // self.config.num_experts_per_rank) for idx in all_rank_indices[self.rank][i].cpu().tolist() ] - unique_pes = len(set(pes)) + + valid_pes = [p for p in pes if p != self.drop_rank] + unique_pes = len(set(valid_pes)) unique_innode_pes = len( [ pe - for pe in set(pes) + for pe in set(valid_pes) if (pe // self.gpu_per_node == self.rank // self.gpu_per_node) ] ) @@ -505,6 +570,7 @@ def stress_dispatch_combine(self): all_rank_weights, all_rank_scales, ) = test_data_list[i % num_test_data] + is_active = self.rank != self.drop_rank ( dispatch_output, dispatch_weights, @@ -517,9 +583,14 @@ def stress_dispatch_combine(self): all_rank_weights[self.rank], all_rank_scales[self.rank], all_rank_indices[self.rank], + is_active=is_active, ) combine_output, combine_output_weight = self.run_combine( - op, dispatch_output, None, all_rank_indices[self.rank] + op, + dispatch_output, + None, + all_rank_indices[self.rank], + is_active=is_active, ) if i % sync_interval == 0: torch.cuda.synchronize() @@ -539,6 +610,7 @@ def stress_dispatch_combine(self): all_rank_scales, ) = test_data g = torch.cuda.CUDAGraph() + is_active = self.rank != self.drop_rank with torch.cuda.graph(g): ( dispatch_output, @@ -552,9 +624,14 @@ def stress_dispatch_combine(self): all_rank_weights[self.rank], all_rank_scales[self.rank], all_rank_indices[self.rank], + is_active=is_active, ) combine_output, combine_output_weight = self.run_combine( - op, dispatch_output, None, all_rank_indices[self.rank] + op, + dispatch_output, + None, + all_rank_indices[self.rank], + is_active=is_active, ) torch.cuda.synchronize() @@ -577,7 +654,9 @@ def run_bench_once(self, max_num_token, op, test_data, repeat=10): all_rank_scales, ) = test_data + is_active = self.rank != self.drop_rank warmup_rounds = 3 + total_recv_num_token = 0 for i in range(warmup_rounds): ( dispatch_output, @@ -591,13 +670,19 @@ def run_bench_once(self, max_num_token, op, test_data, repeat=10): all_rank_weights[self.rank], all_rank_scales[self.rank], all_rank_indices[self.rank], + is_active=is_active, ) if i == warmup_rounds - 1: # Read totalRecvTokenNum after dispatch but before combine resets it torch.cuda.synchronize() - total_recv_num_token = dispatch_recv_num_token[0].item() + if is_active: + total_recv_num_token = dispatch_recv_num_token[0].item() combine_output, combine_output_weight = self.run_combine( - op, dispatch_output, None, all_rank_indices[self.rank] + op, + dispatch_output, + None, + all_rank_indices[self.rank], + is_active=is_active, ) torch.cuda.synchronize() total_rdma_recv_num_token = ( @@ -630,10 +715,15 @@ def run_bench_once(self, max_num_token, op, test_data, repeat=10): all_rank_weights[self.rank], all_rank_scales[self.rank], all_rank_indices[self.rank], + is_active=is_active, ) events[2 * i + 1].record() combine_output, combine_output_weight = self.run_combine( - op, dispatch_output, None, all_rank_indices[self.rank] + op, + dispatch_output, + None, + all_rank_indices[self.rank], + is_active=is_active, ) events[2 * i + 2].record() torch.cuda.synchronize() @@ -977,6 +1067,8 @@ def test_dispatch_combine( quant_type="none", cmd="test", sweep_token_interval=64, + drop_rank=-1, + timeout_us=0, ): world_size = num_node * gpu_per_node node_rank = int(os.environ["RANK"]) @@ -992,6 +1084,9 @@ def test_dispatch_combine( num_qp, quant_type, dtype, + hidden_dim=7168, + drop_rank=drop_rank, + timeout_us=timeout_us, ) test_case.setup() if cmd == "test": @@ -1075,6 +1170,18 @@ def test_dispatch_combine( "'fp8_direct_cast' is the current BF16<->FP8 direct cast path." ), ) +parser.add_argument( + "--drop-rank", + type=int, + default=-1, + help="Rank ID to simulate dropout to test elastic EP mechanism", +) +parser.add_argument( + "--timeout-us", + type=int, + default=0, + help="Timeout in microseconds for elastic EP mechanism polling", +) args_cli = parser.parse_args() if __name__ == "__main__": @@ -1095,6 +1202,8 @@ def test_dispatch_combine( args_cli.quant_type, args_cli.cmd, args_cli.sweep_token_interval, + args_cli.drop_rank, + args_cli.timeout_us, ), nprocs=gpu_per_node, join=True, diff --git a/include/mori/core/transport/p2p/device_primitives.hpp b/include/mori/core/transport/p2p/device_primitives.hpp index 27fb12b55..ea88d0536 100644 --- a/include/mori/core/transport/p2p/device_primitives.hpp +++ b/include/mori/core/transport/p2p/device_primitives.hpp @@ -799,8 +799,9 @@ __forceinline__ __device__ void WarpCastBf16ToCombineInternalFp8( // Note: when T != hip_bfloat16, this function is a no-op. // Callers should guard with if constexpr or ensure T is hip_bfloat16. #else - static_assert(!sizeof(T*), "WarpCastBf16ToCombineInternalFp8 requires FP8 type support " - "(MORI_FP8_TYPE_OCP_ENABLED or MORI_FP8_TYPE_FNUZ_ENABLED)"); + static_assert(!sizeof(T*), + "WarpCastBf16ToCombineInternalFp8 requires FP8 type support " + "(MORI_FP8_TYPE_OCP_ENABLED or MORI_FP8_TYPE_FNUZ_ENABLED)"); #endif } @@ -1033,10 +1034,76 @@ __forceinline__ __device__ void WarpAccumCombineInternalFp8ToBf16( // Note: when T != hip_bfloat16, this function is a no-op. // Callers should guard with if constexpr or ensure T is hip_bfloat16. #else - static_assert(!sizeof(T*), "WarpAccumCombineInternalFp8ToBf16 requires FP8 type support " - "(MORI_FP8_TYPE_OCP_ENABLED or MORI_FP8_TYPE_FNUZ_ENABLED)"); + static_assert(!sizeof(T*), + "WarpAccumCombineInternalFp8ToBf16 requires FP8 type support " + "(MORI_FP8_TYPE_OCP_ENABLED or MORI_FP8_TYPE_FNUZ_ENABLED)"); #endif } +/* ---------------------------------------------------------------------------------------------- */ +/* Elastic EP Utilities */ +/* ---------------------------------------------------------------------------------------------- */ +inline __device__ bool IsRankActive(const int32_t* activeRanks, int rank) { + if (activeRanks == nullptr) return true; + return AtomicLoadRelaxedSystem(const_cast(activeRanks) + rank) != 0; +} + +inline __device__ void MarkRankInactive(int32_t* activeRanks, int rank) { + if (activeRanks == nullptr) return; + AtomicStoreRelaxedSystem(activeRanks + rank, int32_t{0}); +} + +inline __device__ void MarkRanksInactive(int32_t* activeRanks, int startRank, int count) { + if (activeRanks == nullptr) return; + for (int i = 0; i < count; ++i) { + AtomicStoreRelaxedSystem(activeRanks + startRank + i, int32_t{0}); + } +} + +template +inline __device__ T WaitUntilGreaterThanOrTimeoutSystem(T* addr, T val, int64_t timeoutTicks, + int32_t* activeRanks, int watchedRank) { + if (timeoutTicks < 0) { + T got; + do { + got = AtomicLoadRelaxedSystem(addr); + } while (got <= val); + return got; + } + + const unsigned long long start = wall_clock64(); + while (true) { + T got = AtomicLoadRelaxedSystem(addr); + if (got > val) return got; + if (!IsRankActive(activeRanks, watchedRank)) return got; + const unsigned long long now = wall_clock64(); + if (now - start > static_cast(timeoutTicks)) { + MarkRankInactive(activeRanks, watchedRank); + return got; + } + } +} + +template +inline __device__ bool WaitUntilEqualsOrTimeoutSystem(T* addr, T val, int64_t timeoutTicks, + int32_t* activeRanks, int watchedRank) { + if (timeoutTicks < 0) { + while (AtomicLoadRelaxedSystem(addr) != val) { + } + return true; + } + + const unsigned long long start = wall_clock64(); + while (true) { + if (AtomicLoadRelaxedSystem(addr) == val) return true; + if (!IsRankActive(activeRanks, watchedRank)) return false; + const unsigned long long now = wall_clock64(); + if (now - start > static_cast(timeoutTicks)) { + MarkRankInactive(activeRanks, watchedRank); + return false; + } + } +} + } // namespace core } // namespace mori diff --git a/include/mori/ops/dispatch_combine/dispatch_combine.hpp b/include/mori/ops/dispatch_combine/dispatch_combine.hpp index c18818512..fd805d85d 100644 --- a/include/mori/ops/dispatch_combine/dispatch_combine.hpp +++ b/include/mori/ops/dispatch_combine/dispatch_combine.hpp @@ -114,6 +114,11 @@ class EpDispatchCombineHandle { EpDispatchCombineHandle(EpDispatchCombineConfig config); ~EpDispatchCombineHandle(); + void SetElasticState(int32_t* activeRanks, int64_t timeoutTicks) { + this->activeRanks = activeRanks; + this->timeoutTicks = timeoutTicks; + } + void PrepareInference(hipDataType inputType, void* input, void* output, float* weights, index_t* tokenIndices, index_t numToken) { this->inputType = inputType; @@ -180,13 +185,13 @@ class EpDispatchCombineHandle { void LaunchConvertDispatchOutputKernel(const void* dispatchOutX, const void* dispatchOutTopkIdx, void* packedRecvX, int* packedRecvCount, int* packedRecvSrcInfo, int64_t* packedRecvLayoutRange, - int blockNum = -1, int warpPerBlock = -1, - hipStream_t = 0, int hiddenDim = -1); + int blockNum = -1, int warpPerBlock = -1, hipStream_t = 0, + int hiddenDim = -1); void LaunchConvertCombineInputKernel(const void* packedRecvX, const void* packedRecvSrcInfo, const void* packedRecvLayoutRange, void* combineInput, mori::application::SymmMemObjPtr shmemCombineInpTokMemObj, - int blockNum = -1, int warpPerBlock = -1, - hipStream_t = 0, int hiddenDim = -1); + int blockNum = -1, int warpPerBlock = -1, hipStream_t = 0, + int hiddenDim = -1); #endif void LaunchDispatchRecv(KernelType, int blockNum = -1, int warpPerBlock = -1, hipStream_t = 0); @@ -213,6 +218,11 @@ class EpDispatchCombineHandle { index_t curRankNumToken{0}; index_t multiProcessorCount{0}; index_t maxThreads{0}; + int wallClockRateKHz{0}; + + // Elastic EP state (optional; if null/negative, elastic EP is disabled) + int32_t* activeRanks{nullptr}; + int64_t timeoutTicks{-1}; public: // Config @@ -325,6 +335,8 @@ struct EpDispatchCombineArgs { T* outTokenBuf{nullptr}; float* weightsBuf{nullptr}; uint8_t* scalesBuf{nullptr}; + int32_t* activeRanks{nullptr}; + int64_t timeoutTicks{-1}; mori::application::SymmMemObjPtr shmemDispatchInpTokMemObj; mori::application::SymmMemObjPtr shmemCombineInpTokMemObj; mori::application::SymmMemObjPtr shmemDispatchOutTokMemObj; @@ -401,6 +413,8 @@ EpDispatchCombineArgs GetEpDispatchCombineArgs(const EpDispatchCombineHandle& args.outTokenBuf = reinterpret_cast(handle.outTokenBuf); args.weightsBuf = handle.weightsBuf; args.scalesBuf = handle.scalesBuf; + args.activeRanks = handle.activeRanks; + args.timeoutTicks = handle.timeoutTicks; args.destPeTokenCounter = handle.destPeTokenCounter; args.localPeTokenCounter = handle.localPeTokenCounter; args.shmemDispatchInpTokMemObj = handle.shmemDispatchInpTokMemObj; diff --git a/include/mori/utils/hip_helper.hpp b/include/mori/utils/hip_helper.hpp index bf4b2e899..c45f1364d 100644 --- a/include/mori/utils/hip_helper.hpp +++ b/include/mori/utils/hip_helper.hpp @@ -48,16 +48,16 @@ inline int GetCurDeviceMaxThreads() { return GetMaxThreads(device); } -inline int GetDeviceWallClockFreqMhz(int device) { +inline int GetDeviceWallClockFreqKHz(int device) { int rate; HIP_RUNTIME_CHECK(hipDeviceGetAttribute(&rate, hipDeviceAttributeWallClockRate, device)); return rate; } -inline int GetCurDeviceWallClockFreqMhz() { +inline int GetCurDeviceWallClockFreqKHz() { int device = 0; HIP_RUNTIME_CHECK(hipGetDevice(&device)); - return GetDeviceWallClockFreqMhz(device); + return GetDeviceWallClockFreqKHz(device); } } // namespace mori diff --git a/python/mori/kernel_profiler/__init__.py b/python/mori/kernel_profiler/__init__.py index 9b04d2467..54fbb37af 100644 --- a/python/mori/kernel_profiler/__init__.py +++ b/python/mori/kernel_profiler/__init__.py @@ -146,7 +146,7 @@ def export_to_perfetto( sanitize_orphans=True, ): if not gpu_freq_ghz: - gpu_freq_ghz = mori_cpp.get_cur_device_wall_clock_freq_mhz() / 1e6 + gpu_freq_ghz = mori_cpp.get_cur_device_wall_clock_freq_khz() / 1e6 if slot_map is None: slot_map = _discover_all_slots() diff --git a/python/mori/ops/dispatch_combine.py b/python/mori/ops/dispatch_combine.py index a686dd7db..019c60af5 100644 --- a/python/mori/ops/dispatch_combine.py +++ b/python/mori/ops/dispatch_combine.py @@ -204,6 +204,8 @@ def dispatch( block_num: int = -1, rdma_block_num: int = -1, warp_per_block: int = -1, + active_ranks=None, + timeout_us=None, ): """Dispatch tokens to experts based on top-k indices. @@ -225,6 +227,8 @@ def dispatch( self.auto_block_num if self.auto_block_num else block_num, self.auto_rdma_block_num if self.auto_rdma_block_num else rdma_block_num, self.auto_warp_per_block if self.auto_warp_per_block else warp_per_block, + active_ranks, + timeout_us, ) def dispatch_send( @@ -235,6 +239,8 @@ def dispatch_send( indices: torch.Tensor, block_num: int = -1, warp_per_block: int = -1, + active_ranks=None, + timeout_us=None, ): return self.dispatch( input, @@ -243,18 +249,24 @@ def dispatch_send( indices, self.auto_block_num if self.auto_block_num else block_num, self.auto_warp_per_block if self.auto_warp_per_block else warp_per_block, + active_ranks=active_ranks, + timeout_us=timeout_us, ) def dispatch_recv( self, block_num: int = -1, warp_per_block: int = -1, + active_ranks=None, + timeout_us=None, ): return self._dispatch_recv_func( self._handle, self.config.kernel_type.value, self.auto_block_num if self.auto_block_num else block_num, self.auto_warp_per_block if self.auto_warp_per_block else warp_per_block, + active_ranks, + timeout_us, ) def combine( @@ -267,6 +279,8 @@ def combine( warp_per_block: int = -1, use_external_inp_buf: int = -1, call_reset: bool = False, + active_ranks=None, + timeout_us=None, ): """Combine tokens from experts back to original positions. @@ -291,6 +305,8 @@ def combine( self.auto_rdma_block_num if self.auto_rdma_block_num else rdma_block_num, self.auto_warp_per_block if self.auto_warp_per_block else warp_per_block, use_external_inp_buf, + active_ranks, + timeout_us, ) if call_reset: self._reset_func(self._handle) @@ -303,6 +319,8 @@ def combine_send( indices: torch.Tensor, block_num: int = -1, warp_per_block: int = -1, + active_ranks=None, + timeout_us=None, ): return self.combine( input, @@ -310,18 +328,24 @@ def combine_send( indices, self.auto_block_num if self.auto_block_num else block_num, self.auto_warp_per_block if self.auto_warp_per_block else warp_per_block, + active_ranks=active_ranks, + timeout_us=timeout_us, ) def combine_recv( self, block_num: int = -1, warp_per_block: int = -1, + active_ranks=None, + timeout_us=None, ): return self._combine_recv_func( self._handle, self.config.kernel_type.value, self.auto_block_num if self.auto_block_num else block_num, self.auto_warp_per_block if self.auto_warp_per_block else warp_per_block, + active_ranks, + timeout_us, ) def dispatch_standard_moe( diff --git a/src/ops/dispatch_combine/dispatch_combine.cpp b/src/ops/dispatch_combine/dispatch_combine.cpp index c6d36df83..1330defb6 100644 --- a/src/ops/dispatch_combine/dispatch_combine.cpp +++ b/src/ops/dispatch_combine/dispatch_combine.cpp @@ -21,12 +21,13 @@ // SOFTWARE. #include "mori/ops/dispatch_combine/dispatch_combine.hpp" -#include #include #include #include #include +#include + #include "mori/core/core.hpp" #include "mori/shmem/shmem.hpp" #include "mori/utils/hip_helper.hpp" @@ -65,6 +66,7 @@ EpDispatchCombineHandle::EpDispatchCombineHandle(EpDispatchCombineConfig config_ this->maxThreads = std::min(GetCurDeviceMaxThreads(), 1024); MORI_OPS_INFO("Device capability: multiProcessorCount=%d, maxThreads=%d", static_cast(this->multiProcessorCount), static_cast(this->maxThreads)); + this->wallClockRateKHz = GetCurDeviceWallClockFreqKHz(); } EpDispatchCombineHandle::~EpDispatchCombineHandle() { diff --git a/src/ops/dispatch_combine/internode_v1.cpp b/src/ops/dispatch_combine/internode_v1.cpp index 6c5dd174d..793612a7b 100644 --- a/src/ops/dispatch_combine/internode_v1.cpp +++ b/src/ops/dispatch_combine/internode_v1.cpp @@ -48,6 +48,11 @@ inline __device__ void DispatchIntraNodeBlock(EpDispatchCombineArgs& args, in DEF_COMMON_VARS; index_t tokenExpertId = tokenId * args.config.numExpertPerToken + expId; + if (!core::IsRankActive(args.activeRanks, destPe)) { + if (laneId == 0) args.dispDestTokIdMap[tokenExpertId] = nullTokenId; + return; + } + index_t destTokId = 0; if (laneId == 0) { // decide token id in dest pe @@ -143,6 +148,7 @@ inline __device__ void DispatchInterNodeSend(EpDispatchCombineArgs& args) { for (int i = warpId; i < nNodes; i += warpNum) { if (i == myNode) continue; int proxyPe = i * config.gpuPerNode + (config.rank % config.gpuPerNode); + const bool nodeActive = core::IsRankActive(args.activeRanks, i * config.gpuPerNode); if (DEDUP) { for (int tokenId = startTokenIdx + laneId; tokenId < endTokenIdx; tokenId += warpSize) { bool shouldSend = false; @@ -158,6 +164,7 @@ inline __device__ void DispatchInterNodeSend(EpDispatchCombineArgs& args) { uint64_t num = __popcll(mask); if (num == 0) continue; + if (!nodeActive) continue; index_t flag = 0; index_t flagSlotId = 0; @@ -211,6 +218,7 @@ inline __device__ void DispatchInterNodeSend(EpDispatchCombineArgs& args) { args.dispDestTokIdMap[tokenId * numExpertPerToken + e] = nullTokenId; } } + if (!nodeActive) continue; index_t flagSlotId = 0; if (laneId == 0) { @@ -245,9 +253,10 @@ inline __device__ void DispatchInterNodeSend(EpDispatchCombineArgs& args) { int proxyPe = laneId * config.gpuPerNode + (config.rank % config.gpuPerNode); index_t numTokenSignal = core::AtomicLoadRelaxed(args.blockFlagCounter + laneId) * warpSize + 1; - shmem::ShmemAtomicTypeNonFetchThread(args.nodeRecvTokenNumMemObj, - myNode * sizeof(uint64_t), numTokenSignal, - core::AMO_ADD, proxyPe); + if (core::IsRankActive(args.activeRanks, laneId * config.gpuPerNode)) + shmem::ShmemAtomicTypeNonFetchThread(args.nodeRecvTokenNumMemObj, + myNode * sizeof(uint64_t), numTokenSignal, + core::AMO_ADD, proxyPe); } if (laneId == 0) args.interNodeBlocksBarrier[0] = 0; } @@ -268,6 +277,7 @@ inline __device__ void DispatchInterNodeLLSend(EpDispatchCombineArgs& args) { for (int i = warpId; i < nNodes; i += warpNum) { if (i == myNode) continue; int proxyPe = i * config.gpuPerNode + (config.rank % config.gpuPerNode); + const bool nodeActive = core::IsRankActive(args.activeRanks, i * config.gpuPerNode); for (int tokenId = chunkStartTokenIdx + laneId; tokenId < chunkEndTokenIdx; tokenId += warpSize) { @@ -281,6 +291,8 @@ inline __device__ void DispatchInterNodeLLSend(EpDispatchCombineArgs& args) { } } + if (!nodeActive) continue; + index_t flagSlotId = 0; if (laneId == 0) { flagSlotId = atomicAdd(args.blockFlagCounter + i, 1); @@ -314,9 +326,10 @@ inline __device__ void DispatchInterNodeLLSend(EpDispatchCombineArgs& args) { int proxyPe = laneId * config.gpuPerNode + (config.rank % config.gpuPerNode); index_t numTokenSignal = core::AtomicLoadRelaxed(args.blockFlagCounter + laneId) * warpSize + 1; - shmem::ShmemAtomicTypeNonFetchThread(args.nodeRecvTokenNumMemObj, - myNode * sizeof(uint64_t), numTokenSignal, - core::AMO_ADD, proxyPe); + if (core::IsRankActive(args.activeRanks, laneId * config.gpuPerNode)) + shmem::ShmemAtomicTypeNonFetchThread(args.nodeRecvTokenNumMemObj, + myNode * sizeof(uint64_t), numTokenSignal, + core::AMO_ADD, proxyPe); } if (laneId == 0) args.interNodeBlocksBarrier[1] = 0; } @@ -348,7 +361,13 @@ inline __device__ void DispatchInterNodeRecv(EpDispatchCombineArgs& args) { uint64_t thisChunkTokenNum = 0; index_t nodeFlag = 0; if (laneId == 0) { + const int nodeBaseRank = node * config.gpuPerNode; + unsigned long long start = wall_clock64(); while (1) { + if (!core::IsRankActive(args.activeRanks, nodeBaseRank)) { + thisChunkTokenNum = 1; + break; + } thisChunkTokenNum = core::AtomicLoadRelaxedSystem(&chunkFlag[node * maxChunkNum + k]); if (thisChunkTokenNum > 0) break; @@ -357,6 +376,12 @@ inline __device__ void DispatchInterNodeRecv(EpDispatchCombineArgs& args) { thisChunkTokenNum = 1; break; } + if ((args.timeoutTicks >= 0) && + (wall_clock64() - start > static_cast(args.timeoutTicks))) { + core::MarkRanksInactive(args.activeRanks, nodeBaseRank, config.gpuPerNode); + thisChunkTokenNum = 1; + break; + } } } thisChunkTokenNum = __shfl(thisChunkTokenNum, 0) - 1; @@ -381,7 +406,8 @@ inline __device__ void DispatchInterNodeRecv(EpDispatchCombineArgs& args) { int destPe = __shfl(lanePe, e); int destNode = destPe / config.gpuPerNode; - bool shouldSkip = (destNode != myNode) || __any((laneId < e) && (destPe == lanePe)); + bool shouldSkip = (destNode != myNode) || __any((laneId < e) && (destPe == lanePe)) || + !core::IsRankActive(args.activeRanks, destPe); if (shouldSkip) { if (laneId == 0) args.interNodeDispDestTokIdMap[tokIdx * config.numExpertPerToken + e] = nullTokenId; @@ -452,7 +478,13 @@ inline __device__ void DispatchInterNodeLLRecv(EpDispatchCombineArgs& args) { index_t nodeFlag = 0; if (laneId == 0) { uint64_t barrierFlag = args.crossDeviceBarrierFlag[0]; + const int nodeBaseRank = node * config.gpuPerNode; + unsigned long long start = wall_clock64(); while (1) { + if (!core::IsRankActive(args.activeRanks, nodeBaseRank)) { + thisChunkTokenNum = 1; + break; + } thisChunkTokenNum = core::AtomicLoadRelaxedSystem(&chunkFlag[node * maxChunkNum + k]); if (thisChunkTokenNum > 0) break; @@ -461,6 +493,12 @@ inline __device__ void DispatchInterNodeLLRecv(EpDispatchCombineArgs& args) { thisChunkTokenNum = 1; break; } + if ((args.timeoutTicks >= 0) && + (wall_clock64() - start > static_cast(args.timeoutTicks))) { + core::MarkRanksInactive(args.activeRanks, nodeBaseRank, config.gpuPerNode); + thisChunkTokenNum = 1; + break; + } } } thisChunkTokenNum = __shfl(thisChunkTokenNum, 0) - 1; @@ -481,7 +519,8 @@ inline __device__ void DispatchInterNodeLLRecv(EpDispatchCombineArgs& args) { int destPe = __shfl(lanePe, expertId); int destNode = destPe / config.gpuPerNode; - bool shouldSkip = (destNode != myNode) || __any((laneId < expertId) && (destPe == lanePe)); + bool shouldSkip = (destNode != myNode) || __any((laneId < expertId) && (destPe == lanePe)) || + !core::IsRankActive(args.activeRanks, destPe); if (shouldSkip) { if (laneId == 0) args.interNodeDispDestTokIdMap[globalTokenId * config.numExpertPerToken + expertId] = @@ -534,21 +573,33 @@ inline __device__ void DispatchSync(EpDispatchCombineArgs& args) { if ((finishedWarp + 1) == globalWarpNum) { if (laneId < config.gpuPerNode) { int destPe = myNode * config.gpuPerNode + laneId; - index_t numTokenSignal = core::AtomicLoadSeqCstSystem(args.destPeTokenCounter + destPe) + 1; - index_t* signal = args.recvTokenNumMemObj->template GetAs(destPe) + myPe; - core::AtomicStoreSeqCstSystem(signal, numTokenSignal); + if (!core::IsRankActive(args.activeRanks, destPe)) { + // Skip signaling an inactive peer. + } else { + index_t numTokenSignal = core::AtomicLoadSeqCstSystem(args.destPeTokenCounter + destPe) + 1; + index_t* signal = args.recvTokenNumMemObj->template GetAs(destPe) + myPe; + const bool ok = core::WaitUntilEqualsOrTimeoutSystem(signal, index_t{0}, args.timeoutTicks, + args.activeRanks, destPe); + if (ok) core::AtomicStoreSeqCstSystem(signal, numTokenSignal); + } } if (laneId == 0) args.dispatchGridBarrier[0] = 0; index_t* recvTokenNums = args.recvTokenNumMemObj->template GetAs(); for (int destPe = nodePeOffset + laneId; destPe < (nodePeOffset + config.gpuPerNode); destPe += warpSize) { + if (!core::IsRankActive(args.activeRanks, destPe)) { + core::AtomicStoreSeqCstSystem(args.destPeTokenCounter + destPe, 0); + continue; + } index_t* signal = recvTokenNums + destPe; - index_t recvTokenNum = shmem::ShmemInt32WaitUntilGreaterThan(signal, 0) - 1; + index_t got = core::WaitUntilGreaterThanOrTimeoutSystem(signal, index_t{0}, args.timeoutTicks, + args.activeRanks, destPe); + index_t recvTokenNum = (got > 0) ? (got - 1) : 0; atomicAdd(args.totalRecvTokenNum, recvTokenNum); __threadfence_system(); // reset local counter - core::AtomicStoreSeqCstSystem(signal, 0); + if (got > 0) core::AtomicStoreSeqCstSystem(signal, 0); core::AtomicStoreSeqCstSystem(args.destPeTokenCounter + destPe, 0); } @@ -561,7 +612,8 @@ inline __device__ void DispatchSync(EpDispatchCombineArgs& args) { for (int i = globalWarpId; i < nNodes; i += globalWarpNum) { int proxyPe = i * config.gpuPerNode + (config.rank % config.gpuPerNode); - shmem::ShmemQuietThread(proxyPe); + if (core::IsRankActive(args.activeRanks, i * config.gpuPerNode)) + shmem::ShmemQuietThread(proxyPe); } } @@ -705,7 +757,7 @@ __forceinline__ __device__ void CombineIntraNodeTyped(EpDispatchCombineArgs& index_t destTokId = args.dispDestTokIdMap[tokenId * config.numExpertPerToken + laneId]; index_t destPe = destTokId / config.MaxNumTokensToRecv(); index_t destNode = destPe / config.gpuPerNode; - if (destNode == myNode) { + if ((destNode == myNode) && core::IsRankActive(args.activeRanks, destPe)) { index_t destLocalTokId = destTokId - destPe * config.MaxNumTokensToRecv(); srcPtrs[laneId] = args.shmemCombineInpTokMemObj->template GetAs(destPe) + destLocalTokId * config.hiddenDim; @@ -758,7 +810,7 @@ __forceinline__ __device__ void CombineIntraNodeLLTyped(EpDispatchCombineArgs index_t destTokId = args.dispDestTokIdMap[tokenId * config.numExpertPerToken + laneId]; index_t destPe = destTokId / config.MaxNumTokensToRecv(); index_t destNode = destPe / config.gpuPerNode; - if (destNode == myNode) { + if ((destNode == myNode) && core::IsRankActive(args.activeRanks, destPe)) { index_t destLocalTokId = destTokId - destPe * config.MaxNumTokensToRecv(); srcPtrs[laneId] = args.shmemCombineInpTokMemObj->template GetAs(destPe) + destLocalTokId * config.hiddenDim + hiddenDimOffset; @@ -785,6 +837,7 @@ __forceinline__ __device__ void CombineInterNodeTyped(EpDispatchCombineArgs& constexpr int numRecvBlock = 8; int maxChunkNum = core::CeilDiv(config.maxNumInpTokenPerRank, warpSize); + const unsigned long long start = (args.timeoutTicks >= 0) ? wall_clock64() : 0; uint64_t* chunkFlag = args.interNodeChunkFlagMemObj->template GetAs(); index_t* nodeRecvTokenNum = args.nodeRecvTokenNumMemObj->template GetAs(); @@ -828,11 +881,21 @@ __forceinline__ __device__ void CombineInterNodeTyped(EpDispatchCombineArgs& int startTokenIdx = k * warpSize; if (laneId == 0) { - thisChunkTokenNum = chunkFlag[node * maxChunkNum + k]; - if (thisChunkTokenNum == 0) { - index_t nodeFlag = core::AtomicLoadRelaxedSystem(&nodeRecvTokenNum[node]); - if ((nodeFlag > 0) && (startTokenIdx >= (nodeFlag - 1))) { - thisChunkTokenNum = 1; + const int nodeBaseRank = node * config.gpuPerNode; + if (!core::IsRankActive(args.activeRanks, nodeBaseRank)) { + thisChunkTokenNum = 1; + } else { + thisChunkTokenNum = core::AtomicLoadRelaxedSystem(&chunkFlag[node * maxChunkNum + k]); + if (thisChunkTokenNum == 0) { + index_t nodeFlag = core::AtomicLoadRelaxedSystem(&nodeRecvTokenNum[node]); + if ((nodeFlag > 0) && (startTokenIdx >= (nodeFlag - 1))) { + thisChunkTokenNum = 1; + } else if ((args.timeoutTicks >= 0) && + (wall_clock64() - start > + static_cast(args.timeoutTicks))) { + core::MarkRanksInactive(args.activeRanks, nodeBaseRank, config.gpuPerNode); + thisChunkTokenNum = 1; + } } } } @@ -853,7 +916,7 @@ __forceinline__ __device__ void CombineInterNodeTyped(EpDispatchCombineArgs& args.interNodeDispDestTokIdMap[tokIdx * config.numExpertPerToken + laneId]; index_t destPe = destTokId / config.MaxNumTokensToRecv(); index_t destNode = destPe / config.gpuPerNode; - if (destNode == myNode) { + if ((destNode == myNode) && core::IsRankActive(args.activeRanks, destPe)) { index_t destLocalTokId = destTokId - destPe * config.MaxNumTokensToRecv(); srcPtrs[laneId] = args.shmemCombineInpTokMemObj->template GetAs(destPe) + destLocalTokId * config.hiddenDim; @@ -889,15 +952,18 @@ __forceinline__ __device__ void CombineInterNodeTyped(EpDispatchCombineArgs& core::AtomicStoreRelaxedSystem( args.interNodeChunkFlagCombine + node * maxChunkNum + k, index_t{0}); } - int proxyPe = node * config.gpuPerNode + (config.rank % config.gpuPerNode); - int qpId = k % config.numQpPerPe; - shmem::ShmemPutTypeNbiWarp( - args.shmemStagingTokMemObj, - ((myNode + nNodes) * config.MaxNumTokensToRecvPerRank() + startTokenIdx) * - tokCombXferBytes, - args.shmemStagingTokMemObj, - (node * config.MaxNumTokensToRecvPerRank() + startTokenIdx) * tokCombXferBytes, - thisChunkTokenNum * tokCombXferBytes, proxyPe, qpId); + if ((thisChunkTokenNum > 0) && + core::IsRankActive(args.activeRanks, node * config.gpuPerNode)) { + int proxyPe = node * config.gpuPerNode + (config.rank % config.gpuPerNode); + int qpId = k % config.numQpPerPe; + shmem::ShmemPutTypeNbiWarp( + args.shmemStagingTokMemObj, + ((myNode + nNodes) * config.MaxNumTokensToRecvPerRank() + startTokenIdx) * + tokCombXferBytes, + args.shmemStagingTokMemObj, + (node * config.MaxNumTokensToRecvPerRank() + startTokenIdx) * tokCombXferBytes, + thisChunkTokenNum * tokCombXferBytes, proxyPe, qpId); + } } } processedMask |= (1u << relativeIdx); @@ -930,11 +996,13 @@ __forceinline__ __device__ void CombineInterNodeTyped(EpDispatchCombineArgs& } if ((laneId < nNodes) && (laneId != myNode)) { // avoid setting myNode, it will be set in intra node branch - int proxyPe = laneId * config.gpuPerNode + (config.rank % config.gpuPerNode); - for (int i = 0; i < config.numQpPerPe; i++) { - shmem::ShmemAtomicTypeNonFetchThread(args.crossDeviceBarrierMemObj, - args.config.rank * sizeof(uint64_t), 1, - core::AMO_ADD, proxyPe, i); + if (core::IsRankActive(args.activeRanks, laneId * config.gpuPerNode)) { + int proxyPe = laneId * config.gpuPerNode + (config.rank % config.gpuPerNode); + for (int i = 0; i < config.numQpPerPe; i++) { + shmem::ShmemAtomicTypeNonFetchThread(args.crossDeviceBarrierMemObj, + args.config.rank * sizeof(uint64_t), 1, + core::AMO_ADD, proxyPe, i); + } } } if (laneId == 0) args.interNodeBlocksBarrier[0] = 0; @@ -942,8 +1010,14 @@ __forceinline__ __device__ void CombineInterNodeTyped(EpDispatchCombineArgs& uint64_t* localBarrierPtr = args.crossDeviceBarrierMemObj->template GetAs(); if ((laneId < nNodes) && (laneId != myNode)) { int proxyPe = laneId * config.gpuPerNode + (config.rank % config.gpuPerNode); - while (core::AtomicLoadRelaxedSystem(localBarrierPtr + proxyPe) != - (barrierFlag * config.numQpPerPe)) { + if (core::IsRankActive(args.activeRanks, laneId * config.gpuPerNode)) { + const uint64_t expected = barrierFlag * config.numQpPerPe; + const bool ok = core::WaitUntilEqualsOrTimeoutSystem( + localBarrierPtr + proxyPe, expected, args.timeoutTicks, args.activeRanks, proxyPe); + if (!ok) { + core::MarkRanksInactive(args.activeRanks, laneId * config.gpuPerNode, config.gpuPerNode); + core::AtomicStoreRelaxedSystem(localBarrierPtr + proxyPe, expected); + } } } } @@ -970,6 +1044,7 @@ __forceinline__ __device__ void CombineInterNodeLLTyped(EpDispatchCombineArgs int rdmaWarpNum = args.rdmaBlockNum * warpNum; for (int n = 0; n < (nNodes - 1); n++) { int node = (myNode + n + 1) % nNodes; + if (!core::IsRankActive(args.activeRanks, node * config.gpuPerNode)) continue; uint64_t nodeCount = nodeRecvTokenNum[node]; if (nodeCount > 0) nodeCount -= 1; if (nodeCount == 0) continue; @@ -1000,7 +1075,7 @@ __forceinline__ __device__ void CombineInterNodeLLTyped(EpDispatchCombineArgs args.interNodeDispDestTokIdMap[globalTokenId * config.numExpertPerToken + laneId]; index_t destPe = destTokId / config.MaxNumTokensToRecv(); index_t destNode = destPe / config.gpuPerNode; - if (destNode == myNode) { + if ((destNode == myNode) && core::IsRankActive(args.activeRanks, destPe)) { index_t destLocalTokId = destTokId - destPe * config.MaxNumTokensToRecv(); srcPtrs[laneId] = args.shmemCombineInpTokMemObj->template GetAs(destPe) + destLocalTokId * config.hiddenDim + hiddenDimOffset; @@ -1032,15 +1107,18 @@ __forceinline__ __device__ void CombineInterNodeLLTyped(EpDispatchCombineArgs core::AtomicStoreRelaxedSystem(args.interNodeChunkFlagCombine + node * maxChunkNum + k, index_t{0}); } - int proxyPe = node * config.gpuPerNode + (config.rank % config.gpuPerNode); - int qpId = k % config.numQpPerPe; - shmem::ShmemPutTypeNbiWarp( - args.shmemStagingTokMemObj, - ((myNode + nNodes) * config.MaxNumTokensToRecvPerRank() + startTokenIdx) * - tokCombXferBytes, - args.shmemStagingTokMemObj, - (node * config.MaxNumTokensToRecvPerRank() + startTokenIdx) * tokCombXferBytes, - thisChunkTokenNum * tokCombXferBytes, proxyPe, qpId); + if ((thisChunkTokenNum > 0) && + core::IsRankActive(args.activeRanks, node * config.gpuPerNode)) { + int proxyPe = node * config.gpuPerNode + (config.rank % config.gpuPerNode); + int qpId = k % config.numQpPerPe; + shmem::ShmemPutTypeNbiWarp( + args.shmemStagingTokMemObj, + ((myNode + nNodes) * config.MaxNumTokensToRecvPerRank() + startTokenIdx) * + tokCombXferBytes, + args.shmemStagingTokMemObj, + (node * config.MaxNumTokensToRecvPerRank() + startTokenIdx) * tokCombXferBytes, + thisChunkTokenNum * tokCombXferBytes, proxyPe, qpId); + } } } } @@ -1065,13 +1143,15 @@ __forceinline__ __device__ void CombineInterNodeLLTyped(EpDispatchCombineArgs } if ((laneId < nNodes) && (laneId != myNode)) { // avoid setting myNode, it will be set in intra node branch - int proxyPe = laneId * config.gpuPerNode + (config.rank % config.gpuPerNode); - for (int i = 0; i < config.numQpPerPe; i++) { - shmem::ShmemAtomicTypeNonFetchThread(args.crossDeviceBarrierMemObj, - args.config.rank * sizeof(uint64_t), 1, - core::AMO_ADD, proxyPe, i); + if (core::IsRankActive(args.activeRanks, laneId * config.gpuPerNode)) { + int proxyPe = laneId * config.gpuPerNode + (config.rank % config.gpuPerNode); + for (int i = 0; i < config.numQpPerPe; i++) { + shmem::ShmemAtomicTypeNonFetchThread(args.crossDeviceBarrierMemObj, + args.config.rank * sizeof(uint64_t), 1, + core::AMO_ADD, proxyPe, i); + } + __threadfence_system(); } - __threadfence_system(); } if (laneId == 0) args.interNodeBlocksBarrier[0] = 0; @@ -1079,8 +1159,14 @@ __forceinline__ __device__ void CombineInterNodeLLTyped(EpDispatchCombineArgs uint64_t* localBarrierPtr = args.crossDeviceBarrierMemObj->template GetAs(); if ((laneId < nNodes) && (laneId != myNode)) { int proxyPe = laneId * config.gpuPerNode + (config.rank % config.gpuPerNode); - while (core::AtomicLoadRelaxedSystem(localBarrierPtr + proxyPe) != - (barrierFlag * config.numQpPerPe)) { + if (core::IsRankActive(args.activeRanks, laneId * config.gpuPerNode)) { + const uint64_t expected = barrierFlag * config.numQpPerPe; + const bool ok = core::WaitUntilEqualsOrTimeoutSystem( + localBarrierPtr + proxyPe, expected, args.timeoutTicks, args.activeRanks, proxyPe); + if (!ok) { + core::MarkRanksInactive(args.activeRanks, laneId * config.gpuPerNode, config.gpuPerNode); + core::AtomicStoreRelaxedSystem(localBarrierPtr + proxyPe, expected); + } } } } @@ -1194,6 +1280,7 @@ __forceinline__ __device__ void EpCombineAllInternalFp8(EpDispatchCombineArgs if (laneId < config.numExpertPerToken) { lanePe = (args.tokenIndices[tokenId * numExpertPerToken + laneId] / config.numExpertPerRank); laneNode = lanePe / config.gpuPerNode; + if (!core::IsRankActive(args.activeRanks, laneNode * config.gpuPerNode)) laneNode = -1; } if (laneId < nNodes) { @@ -1214,7 +1301,7 @@ __forceinline__ __device__ void EpCombineAllInternalFp8(EpDispatchCombineArgs T* out = args.shmemCombineOutTokMemObj->template GetAs() + tokenId * config.hiddenDim + hiddenDimOffset; core::WarpAccumCombineInternalFp8ToBf16(out, reinterpret_cast(srcPtrs), - nNodes, laneId, hiddenDimSize); + nNodes, laneId, hiddenDimSize); if (args.weightsBuf && (inTokenPartId == warpsPerToken - 1)) { core::WarpAccum(args.shmemCombineOutWeightsMemObj->template GetAs() + @@ -1249,6 +1336,7 @@ __forceinline__ __device__ void EpCombineAllGeneric(EpDispatchCombineArgs& ar if (laneId < config.numExpertPerToken) { lanePe = (args.tokenIndices[tokenId * numExpertPerToken + laneId] / config.numExpertPerRank); laneNode = lanePe / config.gpuPerNode; + if (!core::IsRankActive(args.activeRanks, laneNode * config.gpuPerNode)) laneNode = -1; } if (laneId < nNodes) { @@ -1334,10 +1422,17 @@ __global__ void EpCombineSyncBarrier(EpDispatchCombineArgs args) { uint64_t* localBarrierPtr = args.crossDeviceBarrierMemObj->template GetAs(); if (laneId < config.gpuPerNode) { int destPe = myNode * config.gpuPerNode + laneId; - core::AtomicStoreRelaxedSystem( - args.crossDeviceBarrierMemObj->template GetAs(destPe) + args.config.rank, - barrierFlag); - while (core::AtomicLoadRelaxedSystem(localBarrierPtr + destPe) != barrierFlag) { + if (!core::IsRankActive(args.activeRanks, destPe)) { + core::AtomicStoreRelaxedSystem(localBarrierPtr + destPe, barrierFlag); + } else { + core::AtomicStoreRelaxedSystem( + args.crossDeviceBarrierMemObj->template GetAs(destPe) + args.config.rank, + barrierFlag); + const bool ok = core::WaitUntilEqualsOrTimeoutSystem( + localBarrierPtr + destPe, barrierFlag, args.timeoutTicks, args.activeRanks, destPe); + if (!ok) { + core::AtomicStoreRelaxedSystem(localBarrierPtr + destPe, barrierFlag); + } } } } diff --git a/src/ops/dispatch_combine/intranode.hpp b/src/ops/dispatch_combine/intranode.hpp index 7da837a67..b32b742a2 100644 --- a/src/ops/dispatch_combine/intranode.hpp +++ b/src/ops/dispatch_combine/intranode.hpp @@ -55,16 +55,26 @@ inline __device__ void CrossDeviceBarrierIntraNodeKernel(EpDispatchCombineArgstemplate GetAs(globalThdId) + args.config.rank, - crossDeviceBarrierFlag); + if (core::IsRankActive(args.activeRanks, globalThdId)) { + core::AtomicStoreRelaxedSystem( + args.crossDeviceBarrierMemObj->template GetAs(globalThdId) + args.config.rank, + crossDeviceBarrierFlag); + } } if (globalThdId == 0) atomicAdd(args.crossDeviceBarrierFlag, 1); uint64_t* localBarrierPtr = args.crossDeviceBarrierMemObj->template GetAs(); if (thdId < args.config.worldSize) { - while (core::AtomicLoadRelaxedSystem(localBarrierPtr + thdId) != crossDeviceBarrierFlag) { + if (!core::IsRankActive(args.activeRanks, thdId)) { + core::AtomicStoreRelaxedSystem(localBarrierPtr + thdId, crossDeviceBarrierFlag); + } else { + const bool ok = + core::WaitUntilEqualsOrTimeoutSystem(localBarrierPtr + thdId, crossDeviceBarrierFlag, + args.timeoutTicks, args.activeRanks, thdId); + if (!ok) { + core::AtomicStoreRelaxedSystem(localBarrierPtr + thdId, crossDeviceBarrierFlag); + } } } __syncthreads(); @@ -116,6 +126,12 @@ __global__ void EpDispatchIntraNodeKernel(EpDispatchCombineArgs args) { continue; } + if (!core::IsRankActive(args.activeRanks, destPe)) { + // Treat as skipped so combine ignores this slot. + if (laneId == 0) args.dispDestTokIdMap[i] = config.worldSize * maxNumTokensToSend; + continue; + } + if (laneId == 0) { // decide token id in dest pe destTokId = atomicAdd(args.dispTokOffsetMemObj->template GetAs(destPe), 1); @@ -165,10 +181,14 @@ __global__ void EpDispatchIntraNodeKernel(EpDispatchCombineArgs args) { shmem::ShmemUint32WaitUntilEquals(args.dispatchGridBarrier, gridDim.x); args.dispatchGridBarrier[0] = 0; + if (!core::IsRankActive(args.activeRanks, destPe)) continue; + // Add 1 so that when token number == 0, receiver side still know the signal is sent index_t numTokenSignal = core::AtomicLoadRelaxed(args.destPeTokenCounter + destPe) + 1; index_t* signal = args.recvTokenNumMemObj->template GetAs(destPe) + myPe; - shmem::ShmemInt32WaitUntilEquals(signal, 0); + const bool ok = core::WaitUntilEqualsOrTimeoutSystem(signal, index_t{0}, args.timeoutTicks, + args.activeRanks, destPe); + if (!ok) continue; core::AtomicStoreRelaxedSystem(signal, numTokenSignal); } } @@ -178,9 +198,16 @@ __global__ void EpDispatchIntraNodeKernel(EpDispatchCombineArgs args) { index_t* recvTokenNums = args.recvTokenNumMemObj->template GetAs(); if (globalWarpId == 0) { for (int destPe = laneId; destPe < npes; destPe += warpSize) { + if (!core::IsRankActive(args.activeRanks, destPe)) { + // reset local counter + args.destPeTokenCounter[destPe] = 0; + continue; + } index_t* signal = recvTokenNums + destPe; - index_t recvTokenNum = shmem::ShmemInt32WaitUntilGreaterThan(signal, 0) - 1; - core::AtomicStoreRelaxedSystem(signal, 0); + index_t got = core::WaitUntilGreaterThanOrTimeoutSystem(signal, index_t{0}, args.timeoutTicks, + args.activeRanks, destPe); + index_t recvTokenNum = (got > 0) ? (got - 1) : 0; + if (got > 0) core::AtomicStoreRelaxedSystem(signal, 0); atomicAdd(args.totalRecvTokenNum, recvTokenNum); // reset local counter @@ -251,9 +278,8 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) { args.shmemCombineInpTokMemObj->template GetAs() + i * config.hiddenDim, args.inpTokenBuf + i * config.hiddenDim, config.hiddenDim, laneId); } else { - core::WarpCopy( - args.shmemCombineInpTokMemObj->template GetAs() + i * config.hiddenDim, - args.inpTokenBuf + i * config.hiddenDim, config.hiddenDim); + core::WarpCopy(args.shmemCombineInpTokMemObj->template GetAs() + i * config.hiddenDim, + args.inpTokenBuf + i * config.hiddenDim, config.hiddenDim); } } } @@ -269,14 +295,15 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) { index_t destTokId = args.dispTokIdToSrcTokIdMemObj->template GetAs(myPe)[tokenIdx]; index_t destPe = destTokId / config.MaxNumTokensToRecvPerRank(); index_t destLocalTokId = destTokId - destPe * config.MaxNumTokensToRecvPerRank(); + if (!core::IsRankActive(args.activeRanks, destPe)) continue; uint8_t* destStagingPtr = args.shmemCombineInpTokMemObj->template GetAs(destPe) + (myPe * config.MaxNumTokensToRecvPerRank() + destLocalTokId) * combXferBytes; if constexpr (!std::is_same_v && std::is_same_v) { // bf16 -> fp8 conversion - core::WarpCastBf16ToCombineInternalFp8( - reinterpret_cast(destStagingPtr), - args.inpTokenBuf + tokenIdx * config.hiddenDim, config.hiddenDim, laneId); + core::WarpCastBf16ToCombineInternalFp8(reinterpret_cast(destStagingPtr), + args.inpTokenBuf + tokenIdx * config.hiddenDim, + config.hiddenDim, laneId); } else { core::WarpCopy(reinterpret_cast(destStagingPtr), args.inpTokenBuf + tokenIdx * config.hiddenDim, config.hiddenDim); @@ -316,7 +343,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) { index_t destTokId = args.dispDestTokIdMap[tokenId * config.numExpertPerToken + j]; index_t destPe = destTokId / maxNumTokensToSend; - if (destPe < config.worldSize) { + if ((destPe < config.worldSize) && core::IsRankActive(args.activeRanks, destPe)) { if constexpr (UseP2PRead) { index_t destLocalTokId = destTokId - destPe * maxNumTokensToSend; srcPtrs[j] = args.shmemCombineInpTokMemObj->template GetAs(destPe) + @@ -340,8 +367,8 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) { } } - T* outPtr = args.shmemCombineOutTokMemObj->template GetAs() + - tokenId * config.hiddenDim + hiddenDimOffset; + T* outPtr = args.shmemCombineOutTokMemObj->template GetAs() + tokenId * config.hiddenDim + + hiddenDimOffset; int validAccumCount = config.numExpertPerToken; if (config.worldSize <= 4) { @@ -360,11 +387,10 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) { } } } - + if constexpr (!std::is_same_v && std::is_same_v) { - core::WarpAccumCombineInternalFp8ToBf16( - outPtr, reinterpret_cast(srcPtrs), - validAccumCount, laneId, hiddenDimSize); + core::WarpAccumCombineInternalFp8ToBf16(outPtr, reinterpret_cast(srcPtrs), + validAccumCount, laneId, hiddenDimSize); } else { core::WarpAccum(outPtr, srcPtrs, nullptr, validAccumCount, hiddenDimSize); } diff --git a/src/ops/dispatch_combine/low_latency_async.cpp b/src/ops/dispatch_combine/low_latency_async.cpp index 440d3e81c..5a72b993a 100644 --- a/src/ops/dispatch_combine/low_latency_async.cpp +++ b/src/ops/dispatch_combine/low_latency_async.cpp @@ -25,6 +25,7 @@ #include #include #include + #include #include "mori/core/core.hpp" @@ -60,9 +61,9 @@ __global__ void EpDispatchLowLatencyAsyncSend(EpDispatchCombineArgs args) { condition = destPe == (args.tokenIndices[srcTokId * config.numExpertPerToken + laneId] / config.numExpertPerRank); } - if (__any(condition)) { + if (__any(condition) || !core::IsRankActive(args.activeRanks, destPe)) { // Indicate that this token is already sent to the destination PE by setting an overflow - // token index + // token index, or the destination PE is inactive if (laneId == 0) args.dispDestTokIdMap[i] = config.worldSize * config.MaxNumTokensToSendPerRank(); continue; @@ -102,6 +103,7 @@ __global__ void EpDispatchLowLatencyAsyncSend(EpDispatchCombineArgs args) { uint64_t* recvTokenNums = args.recvTokenNumMemObj->template GetAs(); for (int destPe = blockId; destPe < npes; destPe += blockNum) { + if (!core::IsRankActive(args.activeRanks, destPe)) continue; for (int qpId = warpId; qpId < config.numQpPerPe; qpId += warpNum) { if (laneId == 0) shmem::ShmemUint32WaitUntilEquals(args.dispatchGridBarrier, globalWarpNum); int tokenNum = core::AtomicLoadRelaxed(args.destPeTokenCounter + destPe); @@ -134,21 +136,24 @@ __global__ void EpDispatchLowLatencyAsyncRecv(EpDispatchCombineArgs args) { int blocksPerPe = blockNum / npes; int destPe = blockId / blocksPerPe; + const bool peerActive = core::IsRankActive(args.activeRanks, destPe); // TODO(ditian12): index value is wrong when signal completion at send phase, hence we signal at // recv phase as a workaround at the cost of extra latency, we should still investigate the reason if ((blockId % blocksPerPe) == 0) { for (int qpId = warpId; qpId < config.numQpPerPe; qpId += warpNum) { if (laneId == 0) { - shmem::ShmemQuietThread(destPe, qpId); - int tokenNum = core::AtomicLoadRelaxed(args.destPeTokenCounter + destPe); - // TODO(ditian12): send atomic op right after quiet lead to hang issue, need to investigate - // shmem::ShmemAtomicTypeNonFetchThread( - // args.recvTokenNumMemObj, (myPe * config.numQpPerPe + qpId) * sizeof(uint64_t), - // static_cast(tokenNum + 1), core::AMO_ADD, destPe, qpId); - shmem::ShmemPutUint64ImmNbiThread(args.recvTokenNumMemObj, - (myPe * config.numQpPerPe + qpId) * sizeof(uint64_t), - static_cast(tokenNum + 1), destPe, qpId); + if (peerActive) { + shmem::ShmemQuietThread(destPe, qpId); + int tokenNum = core::AtomicLoadRelaxed(args.destPeTokenCounter + destPe); + // TODO(ditian12): send atomic op right after quiet lead to hang issue, need to + // investigate shmem::ShmemAtomicTypeNonFetchThread( + // args.recvTokenNumMemObj, (myPe * config.numQpPerPe + qpId) * sizeof(uint64_t), + // static_cast(tokenNum + 1), core::AMO_ADD, destPe, qpId); + shmem::ShmemPutUint64ImmNbiThread(args.recvTokenNumMemObj, + (myPe * config.numQpPerPe + qpId) * sizeof(uint64_t), + static_cast(tokenNum + 1), destPe, qpId); + } } } } @@ -156,9 +161,14 @@ __global__ void EpDispatchLowLatencyAsyncRecv(EpDispatchCombineArgs args) { uint64_t* recvTokenNums = args.recvTokenNumMemObj->template GetAs(); uint64_t recvTokenNum = 0; if (laneId < config.numQpPerPe) { - recvTokenNum = shmem::ShmemUint64WaitUntilGreaterThan( - recvTokenNums + destPe * config.numQpPerPe + laneId, 0) - - 1; + if (peerActive) { + uint64_t got = core::WaitUntilGreaterThanOrTimeoutSystem( + recvTokenNums + destPe * config.numQpPerPe + laneId, uint64_t{0}, args.timeoutTicks, + args.activeRanks, destPe); + recvTokenNum = (got > 0) ? (got - 1) : 0; + } else { + recvTokenNum = 0; + } } recvTokenNum = __shfl(recvTokenNum, 0); @@ -245,10 +255,9 @@ __global__ void EpCombineLowLatencyAsyncSend(EpDispatchCombineArgs args) { reinterpret_cast(stagingPtr + stagingTokId * tokHiddenBytes), args.inpTokenBuf + tokenId * config.hiddenDim, config.hiddenDim, laneId); } else { - core::WarpCopy(stagingPtr + stagingTokId * tokHiddenBytes, - reinterpret_cast(args.inpTokenBuf) + - tokenId * tokHiddenBytes, - tokHiddenBytes); + core::WarpCopy( + stagingPtr + stagingTokId * tokHiddenBytes, + reinterpret_cast(args.inpTokenBuf) + tokenId * tokHiddenBytes, tokHiddenBytes); } } if (laneId == 0) { @@ -257,6 +266,7 @@ __global__ void EpCombineLowLatencyAsyncSend(EpDispatchCombineArgs args) { uint64_t* recvTokenNums = args.recvTokenNumMemObj->template GetAs(); for (int destPe = blockId; destPe < npes; destPe += blockNum) { + if (!core::IsRankActive(args.activeRanks, destPe)) continue; for (int qpId = warpId; qpId < config.numQpPerPe; qpId += warpNum) { int tokenNum = 0; if (laneId == 0) { @@ -293,6 +303,7 @@ __global__ void EpCombineLowLatencyAsyncRecv(EpDispatchCombineArgs args) { "Fp8 direct cast combine currently only supports bf16 input"); for (int destPe = blockId; destPe < npes; destPe += blockNum) { + if (!core::IsRankActive(args.activeRanks, destPe)) continue; for (int qpId = warpId; qpId < config.numQpPerPe; qpId += warpNum) { if (laneId == 0) { shmem::ShmemQuietThread(destPe, qpId); @@ -308,11 +319,15 @@ __global__ void EpCombineLowLatencyAsyncRecv(EpDispatchCombineArgs args) { } for (int destPe = laneId; destPe < npes; destPe += warpSize) { + if (!core::IsRankActive(args.activeRanks, destPe)) continue; uint64_t barrierFlag = args.crossDeviceBarrierFlag[0]; - for (int i = 0; i < config.numQpPerPe; i++) - shmem::ShmemUint64WaitUntilEquals(args.crossDeviceBarrierMemObj->template GetAs() + - destPe * config.numQpPerPe + i, - barrierFlag); + for (int i = 0; i < config.numQpPerPe; i++) { + uint64_t* addr = args.crossDeviceBarrierMemObj->template GetAs() + + destPe * config.numQpPerPe + i; + const bool ok = core::WaitUntilEqualsOrTimeoutSystem(addr, barrierFlag, args.timeoutTicks, + args.activeRanks, destPe); + if (!ok) core::AtomicStoreRelaxedSystem(addr, barrierFlag); + } } extern __shared__ char sharedMem[]; @@ -335,22 +350,21 @@ __global__ void EpCombineLowLatencyAsyncRecv(EpDispatchCombineArgs args) { index_t destTokId = args.dispDestTokIdMap[tokenId * config.numExpertPerToken + j]; index_t destPe = destTokId / config.MaxNumTokensToSendPerRank(); - TokT* stagingPtr = - (destPe != myPe) ? args.shmemCombineInpTokMemObj->template GetAs() - : args.shmemStagingTokMemObj->template GetAs(); - if (destPe < npes) { + TokT* stagingPtr = (destPe != myPe) ? args.shmemCombineInpTokMemObj->template GetAs() + : args.shmemStagingTokMemObj->template GetAs(); + if ((destPe < npes) && core::IsRankActive(args.activeRanks, destPe)) { srcPtrs[j] = stagingPtr + destTokId * config.hiddenDim + hiddenDimOffset; } else { srcPtrs[j] = nullptr; } } - T* outPtr = args.shmemCombineOutTokMemObj->template GetAs() + - tokenId * config.hiddenDim + hiddenDimOffset; + T* outPtr = args.shmemCombineOutTokMemObj->template GetAs() + tokenId * config.hiddenDim + + hiddenDimOffset; if constexpr (UseFp8DirectCast) { - core::WarpAccumCombineInternalFp8ToBf16( - outPtr, reinterpret_cast(srcPtrs), config.numExpertPerToken, laneId, - hiddenDimSize); + core::WarpAccumCombineInternalFp8ToBf16(outPtr, + reinterpret_cast(srcPtrs), + config.numExpertPerToken, laneId, hiddenDimSize); } else { core::WarpAccum(outPtr, srcPtrs, nullptr, config.numExpertPerToken, hiddenDimSize); } @@ -391,29 +405,27 @@ __global__ void EpCombineLowLatencyAsyncRecv(EpDispatchCombineArgs args) { #endif // Macro to instantiate async kernels for all data types -#define INSTANTIATE_ASYNC_KERNEL(KernelName) \ - template __global__ void KernelName(EpDispatchCombineArgs \ - args); \ - MORI_FP8_FNUZ(template __global__ void KernelName<__hip_fp8_e4m3_fnuz>( \ - EpDispatchCombineArgs<__hip_fp8_e4m3_fnuz> args);) \ - MORI_FP8_OCP(template __global__ void KernelName<__hip_fp8_e4m3>( \ - EpDispatchCombineArgs<__hip_fp8_e4m3> args);) \ - template __global__ void KernelName(EpDispatchCombineArgs \ - args); \ +#define INSTANTIATE_ASYNC_KERNEL(KernelName) \ + template __global__ void KernelName(EpDispatchCombineArgs args); \ + MORI_FP8_FNUZ(template __global__ void KernelName<__hip_fp8_e4m3_fnuz>( \ + EpDispatchCombineArgs<__hip_fp8_e4m3_fnuz> args);) \ + MORI_FP8_OCP(template __global__ void KernelName<__hip_fp8_e4m3>( \ + EpDispatchCombineArgs<__hip_fp8_e4m3> args);) \ + template __global__ void KernelName( \ + EpDispatchCombineArgs args); \ template __global__ void KernelName(EpDispatchCombineArgs args); // Macro to instantiate async combine kernels (includes optional bf16->fp8 direct-cast path) -#define INSTANTIATE_ASYNC_COMBINE_KERNEL(KernelName) \ - template __global__ void KernelName(EpDispatchCombineArgs \ - args); \ - MORI_FP8_ANY(template __global__ void KernelName( \ - EpDispatchCombineArgs args);) \ - MORI_FP8_FNUZ(template __global__ void KernelName<__hip_fp8_e4m3_fnuz>( \ - EpDispatchCombineArgs<__hip_fp8_e4m3_fnuz> args);) \ - MORI_FP8_OCP(template __global__ void KernelName<__hip_fp8_e4m3>( \ - EpDispatchCombineArgs<__hip_fp8_e4m3> args);) \ - template __global__ void KernelName(EpDispatchCombineArgs \ - args); \ +#define INSTANTIATE_ASYNC_COMBINE_KERNEL(KernelName) \ + template __global__ void KernelName(EpDispatchCombineArgs args); \ + MORI_FP8_ANY(template __global__ void KernelName( \ + EpDispatchCombineArgs args);) \ + MORI_FP8_FNUZ(template __global__ void KernelName<__hip_fp8_e4m3_fnuz>( \ + EpDispatchCombineArgs<__hip_fp8_e4m3_fnuz> args);) \ + MORI_FP8_OCP(template __global__ void KernelName<__hip_fp8_e4m3>( \ + EpDispatchCombineArgs<__hip_fp8_e4m3> args);) \ + template __global__ void KernelName( \ + EpDispatchCombineArgs args); \ template __global__ void KernelName(EpDispatchCombineArgs args); INSTANTIATE_ASYNC_KERNEL(EpDispatchLowLatencyAsyncSend) diff --git a/src/pybind/mori.cpp b/src/pybind/mori.cpp index 92544af12..213a3eb91 100644 --- a/src/pybind/mori.cpp +++ b/src/pybind/mori.cpp @@ -48,12 +48,45 @@ /* ---------------------------------------------------------------------------------------------- */ namespace { +void MaybeUpdateElasticState(mori::moe::EpDispatchCombineHandle& handle, + const std::optional& activeRanks, + const std::optional& timeoutUs) { + int32_t* activeRanksPtr = handle.activeRanks; + int64_t timeoutTicks = handle.timeoutTicks; + + if (activeRanks.has_value()) { + TORCH_CHECK(activeRanks->is_cuda(), "active_ranks must be a CUDA tensor"); + TORCH_CHECK(activeRanks->is_contiguous(), "active_ranks must be contiguous"); + TORCH_CHECK(activeRanks->dim() == 1, "active_ranks must be a 1D tensor"); + TORCH_CHECK(activeRanks->scalar_type() == torch::kInt32, "active_ranks must be int32"); + TORCH_CHECK(activeRanks->numel() == handle.config.worldSize, "active_ranks must have shape (", + handle.config.worldSize, ")"); + activeRanksPtr = activeRanks->data_ptr(); + } + + if (timeoutUs.has_value()) { + if (timeoutUs.value() < 0) { + timeoutTicks = -1; + } else { + timeoutTicks = static_cast(handle.wallClockRateKHz) * timeoutUs.value() / 1000; + } + } + + if (activeRanks.has_value() || timeoutUs.has_value()) { + handle.SetElasticState(activeRanksPtr, timeoutTicks); + } +} + std::tuple, std::optional, torch::Tensor, torch::Tensor> LaunchDispatch(mori::moe::EpDispatchCombineHandle& handle, int kernelType, const torch::Tensor& input, const std::optional& weights, const std::optional& scales, const torch::Tensor& topkIds, - int blockNum = -1, int rdmaBlockNum = -1, int warpPerBlock = -1) { + int blockNum = -1, int rdmaBlockNum = -1, int warpPerBlock = -1, + const std::optional& activeRanks = std::nullopt, + const std::optional& timeoutUs = std::nullopt) { + MaybeUpdateElasticState(handle, activeRanks, timeoutUs); + TORCH_CHECK(input.is_contiguous(), "dispatch input must be contiguous"); TORCH_CHECK(topkIds.is_contiguous(), "dispatch topkIds must be contiguous"); const int hiddenDim = static_cast(input.size(1)); @@ -74,8 +107,8 @@ LaunchDispatch(mori::moe::EpDispatchCombineHandle& handle, int kernelType, if (scales.has_value() && (handle.config.scaleDim > 0)) { TORCH_CHECK(scales->is_contiguous(), "dispatch scales must be contiguous"); TORCH_CHECK(scales->element_size() == handle.config.scaleTypeSize, - "dispatch scales element size mismatch, expected ", - handle.config.scaleTypeSize, ", got ", scales->element_size()); + "dispatch scales element size mismatch, expected ", handle.config.scaleTypeSize, + ", got ", scales->element_size()); scalePtr = reinterpret_cast(scales->data_ptr()); } @@ -85,10 +118,9 @@ LaunchDispatch(mori::moe::EpDispatchCombineHandle& handle, int kernelType, handle.LaunchDispatch((mori::moe::KernelType)kernelType, blockNum, rdmaBlockNum, warpPerBlock, at::cuda::getCurrentHIPStream(), hiddenDim); - torch::Tensor out = - torch::from_blob(handle.shmemDispatchOutTokMemObj->Get(), - {handle.config.MaxNumTokensToRecv(), hiddenDim}, - torch::TensorOptions().dtype(input.scalar_type()).device(torch::kCUDA)); + torch::Tensor out = torch::from_blob( + handle.shmemDispatchOutTokMemObj->Get(), {handle.config.MaxNumTokensToRecv(), hiddenDim}, + torch::TensorOptions().dtype(input.scalar_type()).device(torch::kCUDA)); torch::Tensor outWeights = torch::from_blob( handle.shmemDispatchOutWeightsMemObj->Get(), @@ -121,7 +153,11 @@ LaunchDispatch(mori::moe::EpDispatchCombineHandle& handle, int kernelType, std::tuple> LaunchCombine( mori::moe::EpDispatchCombineHandle& handle, int kernelType, const torch::Tensor& input, const std::optional& weights, const torch::Tensor& topkIds, int blockNum = -1, - int rdmaBlockNum = -1, int warpPerBlock = -1, int useExternalInpBuf = -1) { + int rdmaBlockNum = -1, int warpPerBlock = -1, int useExternalInpBuf = -1, + const std::optional& activeRanks = std::nullopt, + const std::optional& timeoutUs = std::nullopt) { + MaybeUpdateElasticState(handle, activeRanks, timeoutUs); + TORCH_CHECK(input.is_contiguous(), "combine input must be contiguous"); TORCH_CHECK(topkIds.is_contiguous(), "combine topkIds must be contiguous"); const int hiddenDim = static_cast(input.size(1)); @@ -142,9 +178,8 @@ std::tuple> LaunchCombine( useExternalInpBuf, at::cuda::getCurrentHIPStream(), hiddenDim); auto options = torch::TensorOptions().dtype(input.scalar_type()).device(torch::kCUDA); - torch::Tensor out = - torch::from_blob(handle.shmemCombineOutTokMemObj->Get(), - {handle.config.maxNumInpTokenPerRank, hiddenDim}, options); + torch::Tensor out = torch::from_blob(handle.shmemCombineOutTokMemObj->Get(), + {handle.config.maxNumInpTokenPerRank, hiddenDim}, options); std::optional outWeights{std::nullopt}; if (weightsPtr) { @@ -259,9 +294,8 @@ std::tuple> LaunchCombineForStandard // Get output tensor from shmem buffer auto options = torch::TensorOptions().dtype(expertOutput.scalar_type()).device(torch::kCUDA); - torch::Tensor out = - torch::from_blob(handle.shmemCombineOutTokMemObj->Get(), - {handle.config.maxNumInpTokenPerRank, hiddenDim}, options); + torch::Tensor out = torch::from_blob(handle.shmemCombineOutTokMemObj->Get(), + {handle.config.maxNumInpTokenPerRank, hiddenDim}, options); std::optional outWeights{std::nullopt}; // TODO: do not support weights for standard MoE now @@ -347,23 +381,29 @@ torch::Tensor ConvertCombineInput(mori::moe::EpDispatchCombineHandle& handle, {handle.config.MaxNumTokensToRecv(), hidden}, options); // Note: packedRecvLayoutRange is not used in current implementation (passed as nullptr) - handle.LaunchConvertCombineInputKernel( - packedRecvX.data_ptr(), packedRecvSrcInfo.data_ptr(), nullptr, combineInput.data_ptr(), - handle.shmemCombineInpTokMemObj, blockNum, warpPerBlock, at::cuda::getCurrentHIPStream(), - hidden); + handle.LaunchConvertCombineInputKernel(packedRecvX.data_ptr(), packedRecvSrcInfo.data_ptr(), + nullptr, combineInput.data_ptr(), + handle.shmemCombineInpTokMemObj, blockNum, warpPerBlock, + at::cuda::getCurrentHIPStream(), hidden); return combineInput; } #endif // ENABLE_STANDARD_MOE_ADAPT void LaunchDispatchRecv(mori::moe::EpDispatchCombineHandle& handle, int kernelType, - int blockNum = -1, int warpPerBlock = -1) { + int blockNum = -1, int warpPerBlock = -1, + const std::optional& activeRanks = std::nullopt, + const std::optional& timeoutUs = std::nullopt) { + MaybeUpdateElasticState(handle, activeRanks, timeoutUs); handle.LaunchDispatchRecv((mori::moe::KernelType)kernelType, blockNum, warpPerBlock, at::cuda::getCurrentHIPStream()); } void LaunchCombineRecv(mori::moe::EpDispatchCombineHandle& handle, int kernelType, - int blockNum = -1, int warpPerBlock = -1) { + int blockNum = -1, int warpPerBlock = -1, + const std::optional& activeRanks = std::nullopt, + const std::optional& timeoutUs = std::nullopt) { + MaybeUpdateElasticState(handle, activeRanks, timeoutUs); handle.LaunchCombineRecv((mori::moe::KernelType)kernelType, blockNum, warpPerBlock, at::cuda::getCurrentHIPStream()); } @@ -407,10 +447,9 @@ torch::Tensor GetRegisteredCombineInputBuffer(mori::moe::EpDispatchCombineHandle TORCH_CHECK(actualHiddenDim > 0, "registered combine input hidden dim must be > 0"); TORCH_CHECK(actualHiddenDim <= handle.config.hiddenDim, "requested hidden dim ", actualHiddenDim, " exceeds config.hidden_dim ", handle.config.hiddenDim); - torch::Tensor out = - torch::from_blob(handle.shmemCombineInpTokMemObj->Get(), - {handle.config.MaxNumTokensToRecv(), actualHiddenDim}, - torch::TensorOptions().dtype(scalarType).device(torch::kCUDA)); + torch::Tensor out = torch::from_blob( + handle.shmemCombineInpTokMemObj->Get(), {handle.config.MaxNumTokensToRecv(), actualHiddenDim}, + torch::TensorOptions().dtype(scalarType).device(torch::kCUDA)); return out; } @@ -430,7 +469,7 @@ torch::Tensor GetDebugTimeOffset(mori::moe::EpDispatchCombineHandle& handle) { } #endif -int GetCurDeviceWallClockFreqMhz() { return mori::GetCurDeviceWallClockFreqMhz(); } +int GetCurDeviceWallClockFreqKHz() { return mori::GetCurDeviceWallClockFreqKHz(); } void DeclareEpDispatchCombineHandle(pybind11::module& m) { std::string className = std::string("EpDispatchCombineHandle"); @@ -644,8 +683,8 @@ void RegisterMoriOps(py::module_& m) { DeclareEpDispatchCombineHandle(m); - m.def("get_cur_device_wall_clock_freq_mhz", &GetCurDeviceWallClockFreqMhz, - "Returns clock frequency of current device's wall clock"); + m.def("get_cur_device_wall_clock_freq_khz", &GetCurDeviceWallClockFreqKHz, + "Returns clock frequency of current device's wall clock in KHz"); m.def("cast", &Cast, "cast a tensor from type A to type B"); } diff --git a/tests/python/ops/test_dispatch_combine.py b/tests/python/ops/test_dispatch_combine.py index 53da9b108..22fc70b0d 100644 --- a/tests/python/ops/test_dispatch_combine.py +++ b/tests/python/ops/test_dispatch_combine.py @@ -339,40 +339,174 @@ def _test_dispatch_combine( test_case.run_test_once(op, test_data) +def _test_dispatch_combine_elastic( + rank, + world_size, + data_type, + hidden_dim, + max_num_inp_token_per_rank, + num_experts_per_rank, + num_experts_per_token, + use_external_inp_buf, + timeout_us, +): + config = mori.ops.EpDispatchCombineConfig( + data_type=data_type, + rank=rank, + world_size=world_size, + hidden_dim=hidden_dim // 2 if _is_fp4x2_dtype(data_type) else hidden_dim, + scale_dim=0, + scale_type_size=4, + max_num_inp_token_per_rank=max_num_inp_token_per_rank, + num_experts_per_rank=num_experts_per_rank, + num_experts_per_token=num_experts_per_token, + max_token_type_size=4, + block_num=40, + warp_num_per_block=8, + use_external_inp_buf=use_external_inp_buf, + ) + op = mori.ops.EpDispatchCombineOp(config) + test_case = EpDispatchCombineTestCase(config) + test_data = test_case.gen_test_data(use_max_token_num=True) + + ( + all_rank_num_token, + all_rank_indices, + all_rank_input, + all_rank_weights, + _all_rank_scales, + ) = test_data + + inactive_rank = world_size - 1 + # Initialize all ranks as active. The C++ timeout mechanism will set + # active_ranks[inactive_rank] to 0 when it detects unresponsiveness! + active_ranks = torch.ones((world_size,), dtype=torch.int32, device=test_case.device) + + test_case.sync() + + # Force every token to route to one expert on every rank (K == world_size), + # so masking out a rank changes the expected combine result. + per_rank_expert0 = ( + torch.arange(world_size, device=test_case.device, dtype=torch.int32) + * num_experts_per_rank + ) + indices_template = per_rank_expert0.view(1, world_size) + for r in range(world_size): + all_rank_indices[r] = indices_template.expand( + all_rank_num_token[r].item(), -1 + ).contiguous() + + is_active = rank != inactive_rank + + if is_active: + ( + dispatch_output, + dispatch_weights, + _dispatch_scales, + dispatch_indices, + dispatch_recv_num_token, + ) = op.dispatch( + all_rank_input[rank], + all_rank_weights[rank], + None, + all_rank_indices[rank], + active_ranks=active_ranks, + timeout_us=timeout_us, + ) + else: + # Inactive rank does not call dispatch, simulating complete unresponsiveness. + pass + + test_case.sync() + + if is_active: + # Check that the timeout successfully marked the unresponsive rank as inactive directly in the tensor + assert int(active_ranks[inactive_rank].item()) == 0 + + # Validate we never receive tokens from the inactive source rank. + src_token_pos = op.get_dispatch_src_token_pos() + for pos in src_token_pos: + src_rank = int(pos) // max_num_inp_token_per_rank + assert src_rank != inactive_rank + assert int(dispatch_recv_num_token[0].item()) == int(src_token_pos.numel()) + + total_recv_num_token = int(dispatch_recv_num_token[0].item()) + combine_input = dispatch_output + if not use_external_inp_buf: + combine_input = op.get_registered_combine_input_buffer(config.data_type) + combine_input[:total_recv_num_token, :].copy_( + dispatch_output[:total_recv_num_token, :] + ) + + test_case.sync() + + if is_active: + combine_output, combine_output_weight = op.combine( + combine_input, + dispatch_weights, + dispatch_indices, + call_reset=False, + active_ranks=active_ranks, + timeout_us=timeout_us, + ) + + test_case.sync() + + expected_unique_pes = world_size - 1 + if not is_active or _is_fp4x2_dtype(config.data_type): + return None + + for i in range(all_rank_num_token[rank]): + got, expected = combine_output[i], ( + all_rank_input[rank][i].to(torch.float32) * expected_unique_pes + ).to(config.data_type) + assert torch.allclose(got.float(), expected.float(), atol=1e-2, rtol=1e-2) + + if combine_output_weight is not None: + got_weight, expected_weight = ( + combine_output_weight[i], + all_rank_weights[rank][i] * expected_unique_pes, + ) + assert torch.allclose(got_weight, expected_weight, atol=1e-5, rtol=1e-5) + + # TODO: create a sub process group so that we can test worlds size < 8 @pytest.mark.parametrize("world_size", (8,)) -@pytest.mark.parametrize("data_type", ( - [ - torch.bfloat16, - pytest.param( - torch.float8_e4m3fnuz, - marks=pytest.mark.skipif( - not data_type_supported(torch.float8_e4m3fnuz), - reason="Skip float8_e4m3fnuz, it is not supported", - ), - ), - pytest.param( - torch.float8_e4m3fn, - marks=pytest.mark.skipif( - not data_type_supported(torch.float8_e4m3fn), - reason="Skip float8_e4m3fn, it is not supported", - ), - ), - ] - + ( +@pytest.mark.parametrize( + "data_type", + ( [ + torch.bfloat16, pytest.param( - TORCH_FLOAT4_E2M1FN_X2, + torch.float8_e4m3fnuz, marks=pytest.mark.skipif( - not data_type_supported(TORCH_FLOAT4_E2M1FN_X2), - reason="Skip float4_e2m1fn_x2, it is not supported", + not data_type_supported(torch.float8_e4m3fnuz), + reason="Skip float8_e4m3fnuz, it is not supported", ), - ) + ), + pytest.param( + torch.float8_e4m3fn, + marks=pytest.mark.skipif( + not data_type_supported(torch.float8_e4m3fn), + reason="Skip float8_e4m3fn, it is not supported", + ), + ), ] - if TORCH_FLOAT4_E2M1FN_X2 is not None - else [] - ) -)) + + ( + [ + pytest.param( + TORCH_FLOAT4_E2M1FN_X2, + marks=pytest.mark.skipif( + not data_type_supported(TORCH_FLOAT4_E2M1FN_X2), + reason="Skip float4_e2m1fn_x2, it is not supported", + ), + ) + ] + if TORCH_FLOAT4_E2M1FN_X2 is not None + else [] + ) + ), +) @pytest.mark.parametrize("hidden_dim", (7168, 4096)) @pytest.mark.parametrize("scale_dim", (0, 32)) @pytest.mark.parametrize("scale_type_size", (1, 4)) @@ -430,3 +564,51 @@ def test_dispatch_combine( for result in results: if result is not None: pytest.assume(False, result) + + +@pytest.mark.parametrize("world_size", (8,)) +@pytest.mark.parametrize("data_type", (torch.bfloat16,)) +@pytest.mark.parametrize("hidden_dim", (4096,)) +@pytest.mark.parametrize("max_num_inp_token_per_rank", (16,)) +@pytest.mark.parametrize("num_experts_per_rank", (32,)) +@pytest.mark.parametrize("num_experts_per_token", (8,)) +@pytest.mark.parametrize("use_external_inp_buf", (True, False)) +def test_dispatch_combine_elastic_ep( + torch_dist_process_manager, + world_size, + data_type, + hidden_dim, + max_num_inp_token_per_rank, + num_experts_per_rank, + num_experts_per_token, + use_external_inp_buf, +): + timeout_us = 500_000 + for i in range(world_size): + torch_dist_process_manager.task_queue.put( + ( + _test_dispatch_combine_elastic, + [ + world_size, + data_type, + hidden_dim, + max_num_inp_token_per_rank, + num_experts_per_rank, + num_experts_per_token, + use_external_inp_buf, + timeout_us, + ], + ) + ) + + results = [] + for i in range(world_size): + ( + rank, + result, + ) = torch_dist_process_manager.result_queue.get() + results.append(result) + + for result in results: + if result is not None: + pytest.assume(False, result)