Skip to content

Optimize non-circular buffered TMA loads #5858

@liqiangxl

Description

@liqiangxl

Issue: Each non-circular-buffered TMA load is handled separately, in eight steps: mbarrier allocation, init, sync, arriveExpectTX (set expected transaction bytes), TMA load, mbarrier wait, sync, and mbarrier invalidation (inval):

  uint64_t* T16 = reinterpret_cast<uint64_t*>(array + smem_offset + 4224);
  mbarrier::init(toSmem(T16), 1U);
  __syncthreads();
  if ((Hopper::electSync(4294967295U) && b16)) {
    uint64_t i18;
    i18 = mbarrier::arriveExpectTX(toSmem(T16), 4096U);
    Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr5, a7, toSmem(T16) }), toSmem(T10));
    mbarrier::wait(toSmem(T16), i18);
  }
  __syncthreads();
  mbarrier::inval(toSmem(T16));

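When there are N non-circular-buffered TMA loads, these eight steps are repeated N times, which is inefficient. Each load pays for its own mbarrier setup and teardown and two `__syncthreads()` barriers. Also, the elected thread waits on each load's mbarrier before issuing the next TMA load, so the N transfers are serialized instead of overlapping. With 4 inputs, the achieved bandwidth is only 84% of SOL (speed of light, i.e. peak memory bandwidth). The corresponding CUDA code is: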

  uint64_t* T16 = reinterpret_cast<uint64_t*>(array + smem_offset + 4224);
  mbarrier::init(toSmem(T16), 1U);
  __syncthreads();
  if ((Hopper::electSync(4294967295U) && b16)) {
    uint64_t i18;
    i18 = mbarrier::arriveExpectTX(toSmem(T16), 4096U);
    Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr5, a7, toSmem(T16) }), toSmem(T10));
    mbarrier::wait(toSmem(T16), i18);
  }
  __syncthreads();
  mbarrier::inval(toSmem(T16));
  uint64_t* T17 = reinterpret_cast<uint64_t*>(array + smem_offset + 4096);
  mbarrier::init(toSmem(T17), 1U);
  __syncthreads();
  if ((Hopper::electSync(4294967295U) && b16)) {
    uint64_t i19;
    i19 = mbarrier::arriveExpectTX(toSmem(T17), 4096U);
    Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr8, a7, toSmem(T17) }), toSmem(T9));
    mbarrier::wait(toSmem(T17), i19);
  }
  __syncthreads();
  mbarrier::inval(toSmem(T17));
  uint64_t* T18 = reinterpret_cast<uint64_t*>(array + smem_offset + 12544);
  mbarrier::init(toSmem(T18), 1U);
  __syncthreads();
  if ((Hopper::electSync(4294967295U) && b16)) {
    uint64_t i20;
    i20 = mbarrier::arriveExpectTX(toSmem(T18), 4096U);
    Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr9, a7, toSmem(T18) }), toSmem(T8));
    mbarrier::wait(toSmem(T18), i20);
  }
  __syncthreads();
  mbarrier::inval(toSmem(T18));
  uint64_t* T19 = reinterpret_cast<uint64_t*>(array + smem_offset + 12416);
  mbarrier::init(toSmem(T19), 1U);
  __syncthreads();
  if ((Hopper::electSync(4294967295U) && b16)) {
    uint64_t i21;
    i21 = mbarrier::arriveExpectTX(toSmem(T19), 4096U);
    Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr10, a7, toSmem(T19) }), toSmem(T7));
    mbarrier::wait(toSmem(T19), i21);
  }
  __syncthreads();
  mbarrier::inval(toSmem(T19));

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions