EpDispatchCombineOp crashes with SIGSEGV / OOM on MI355X (sglang-0.5.9-rocm720-mi35x-mori-0227-2 container) #210

@sunway513

Description

Running the Mori EP dispatch/combine microbenchmark on MI355X fails with either an out-of-heap-memory error or a SIGSEGV, depending on the configured shmem heap size.

Environment

  • Container: rocm/sgl-dev:sglang-0.5.9-rocm720-mi35x-mori-0227-2
  • GPU: 8x AMD Instinct MI355X
  • Mori: installed at /sgl-workspace/mori/python/mori/

Reproduction

#!/usr/bin/env python3
import os, torch, torch.distributed as dist, torch.multiprocessing as mp
import mori

os.environ["MORI_SHMEM_HEAP_SIZE"] = "8G"

def worker(rank, world_size):
    os.environ.update({"MASTER_ADDR":"localhost","MASTER_PORT":"29850"})
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)
    dist.init_process_group(
        backend="cpu:gloo,cuda:nccl", rank=rank,
        world_size=world_size, device_id=device,
    )
    world_group = torch.distributed.group.WORLD
    torch._C._distributed_c10d._register_process_group("default", world_group)
    mori.shmem.shmem_torch_process_group_init("default")

    config = mori.ops.EpDispatchCombineConfig(
        data_type=torch.bfloat16, rank=rank, world_size=world_size,
        hidden_dim=7168, scale_dim=0,
        scale_type_size=torch.tensor([], dtype=torch.float8_e4m3fnuz).element_size(),
        max_token_type_size=2,
        max_num_inp_token_per_rank=4096,
        num_experts_per_rank=32,  # 256 experts / 8 GPUs
        num_experts_per_token=8,
    )
    op = mori.ops.EpDispatchCombineOp(config)

    # DeepSeek-V3-style shapes: 128 tokens, hidden_dim=7168, top-8 routing over 256 experts
    x = torch.randn(128, 7168, dtype=torch.bfloat16, device=device)
    topk_ids = torch.randint(0, 256, (128, 8), device=device, dtype=torch.int32)
    topk_weights = torch.randn(128, 8, dtype=torch.float32, device=device).softmax(dim=-1)

    out = op.dispatch(x, topk_weights, torch.ones(128, 1, dtype=torch.float32, device=device), topk_ids)
    print(f"rank {rank}: dispatch ok")

    mori.shmem.shmem_finalize()
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(8,), nprocs=8, join=True)

Error

With MORI_SHMEM_HEAP_SIZE=4G:

[application] [error] Out of heap memory! Requested: 1879048192 bytes (aligned), Current heap size: 4294967296 bytes.
[shmem] [error] Out of static heap memory! Requested: 1879048192 bytes.
[dispatch_combine.cpp:79] hip failed with invalid argument

With MORI_SHMEM_HEAP_SIZE=8G:

process 5 terminated with signal SIGSEGV
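
For what it's worth, the failing request size lines up exactly with a simple staging-buffer accounting. This is my guess at where the number comes from, not Mori's documented allocation logic:

```python
# Hypothetical accounting for the "Requested: 1879048192 bytes" error above.
# Assumption: the op stages max_num_inp_token_per_rank tokens from every rank,
# so one staging buffer is
#   max_num_inp_token_per_rank * world_size * hidden_dim * max_token_type_size
# bytes, and four such buffers match the reported request exactly.
max_num_inp_token_per_rank = 4096
world_size = 8
hidden_dim = 7168
max_token_type_size = 2  # bf16

one_buffer = max_num_inp_token_per_rank * world_size * hidden_dim * max_token_type_size
print(one_buffer)      # 469762048 (~448 MiB)
print(4 * one_buffer)  # 1879048192 -- matches the aligned request in the log
```

If this accounting is right, the 4G heap fails not because a single request exceeds 4 GiB but because earlier allocations already consumed most of the heap before this 1.75 GiB request arrived.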

Notes

  • The examples/ops/dispatch_combine/test_dispatch_combine.py script passes its correctness checks (token routing), but emits no performance-benchmarking output.
  • The dispatch/combine API used here (op.dispatch(input, weights, scales, indices)) differs from the Python-level test, which uses a different wrapper class. Documentation on the correct C++ binding API for benchmarking would be helpful.

Expected

Mori EP dispatch/combine should run without SIGSEGV at MORI_SHMEM_HEAP_SIZE=8G for the DeepSeek-V3 config (256 experts, hidden_dim=7168, top-8, 4096 tokens per rank).
