
Conversation


hodgesds commented Feb 6, 2026

Under system saturation, sched_yield has multi-millisecond tail latency (up to 4ms at 100% CPU).

Replace immediate sched_yield with time-based spinning (see the sketch below):

  • Spin with CPU pause instruction for configurable duration (default 1µs)
  • Only yield after spin timeout, then reset timer
  • Configurable via environment variables

Changes:

  • utils.h: Add ncclCpuRelax() cross-platform CPU pause helper
  • proxy.cc: Time-based spin in freeOps pool wait loop
  • proxy.cc: Time-based spin in progress loop idle path
  • doca_gpunetio.cpp: Replace yield with pause in service mainloop

Environment variables:

  • NCCL_PROXY_SPIN_TIME_NS: freeOps wait spin duration (default 1000)
  • NCCL_PROXY_PROGRESS_SPIN_TIME_NS: progress loop spin duration (default 1000)
  • Set to 0 to restore original always-yield behavior

The pause instruction (~43 cycles on x86) allows hyperthreads to run while avoiding syscall overhead.

Signed-off-by: Daniel Hodges <hodgesd@meta.com>
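
For reference, the wait-loop shape this change introduces, as a minimal
sketch: ncclCpuRelax() is the helper named in the changelist above, but the
loop body here is illustrative, not the literal proxy.cc patch.

#include <atomic>
#include <cstdint>
#include <sched.h>
#include <time.h>

// Cross-platform CPU pause hint (sketch of the usual implementation).
static inline void ncclCpuRelax() {
#if defined(__x86_64__) || defined(__i386__)
    __asm__ __volatile__("pause" ::: "memory");
#elif defined(__aarch64__)
    __asm__ __volatile__("yield" ::: "memory");
#else
    __asm__ __volatile__("" ::: "memory");
#endif
}

static inline uint64_t nowNs() {
    struct timespec ts;
    clock_gettime(CLOCK_MONOTONIC, &ts);
    return (uint64_t)ts.tv_sec * 1000000000ULL + ts.tv_nsec;
}

// Wait for `cond`: spin with pause for spinTimeNs (NCCL_PROXY_SPIN_TIME_NS),
// then fall back to sched_yield and reset the timer. spinTimeNs == 0
// restores the original always-yield behavior.
static void waitWithTimedSpin(std::atomic<bool>& cond, int64_t spinTimeNs) {
    uint64_t t0 = nowNs();
    while (!cond.load(std::memory_order_acquire)) {
        if (spinTimeNs > 0 && nowNs() - t0 < (uint64_t)spinTimeNs) {
            ncclCpuRelax();
        } else {
            sched_yield();
            t0 = nowNs();
        }
    }
}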

hodgesds commented Feb 6, 2026

NCCL sched_yield Optimization Benchmarks

Test Environment

CPU

  • Processor: 2x AMD EPYC 9654 96-Core Processor
  • Cores: 192 physical cores (384 threads with SMT)
  • Frequency: 2.4 GHz base
  • Architecture: x86_64 (Zen 4, Genoa)
  • Features: AVX-512, SMT-2

NUMA Layout

Node 0: CPUs 0-95, 192-287 (96 cores + 96 hyperthreads)
        Memory: 1.1 TB
Node 1: CPUs 96-191, 288-383 (96 cores + 96 hyperthreads)
        Memory: 1.1 TB

NUMA distances:
       Node0  Node1
Node0:   10     32
Node1:   32     10

GPU Topology

  • GPUs: 8x NVIDIA H100 (96GB HBM3 each)
  • GPU memory: 97,871 MiB (96 GB) per GPU
  • PCIe: Gen5 x16 per GPU
  • Interconnect: NVLink 4.0

NVLink Configuration:
- 18 links per GPU @ 26.562 GB/s each
- Total per GPU: 478 GB/s per direction (~900 GB/s bidirectional)
- Full mesh: All 8 GPUs directly connected via NVLink

GPU NUMA Affinity:
GPU0-3: NUMA node 0, CPUs 0-95,192-287
GPU4-7: NUMA node 1, CPUs 96-191,288-383

Topology Matrix (nvidia-smi topo -m):
        GPU0  GPU1  GPU2  GPU3  GPU4  GPU5  GPU6  GPU7
GPU0     X    NV18  NV18  NV18  NV18  NV18  NV18  NV18
GPU1    NV18   X    NV18  NV18  NV18  NV18  NV18  NV18
GPU2    NV18  NV18   X    NV18  NV18  NV18  NV18  NV18
GPU3    NV18  NV18  NV18   X    NV18  NV18  NV18  NV18
GPU4    NV18  NV18  NV18  NV18   X    NV18  NV18  NV18
GPU5    NV18  NV18  NV18  NV18  NV18   X    NV18  NV18
GPU6    NV18  NV18  NV18  NV18  NV18  NV18   X    NV18
GPU7    NV18  NV18  NV18  NV18  NV18  NV18  NV18   X

NV18 = 18 NVLink connections between GPU pair

Software

  • NCCL Version: 2.29.3+cuda12.9
  • CUDA: 12.9

1. Per-Operation Overhead

Measured cost of individual operations:

Operation            Cycles    Time        Notes
pause                24-36     10-15ns     CPU spin-wait hint
clock_gettime        48-72     20-30ns     vDSO, very cheap
sched_yield (idle)   264-338   110-140ns   10x more than pause

2. sched_yield Under CPU Load

CPU Load             Avg Cycles   p99 Cycles   Max Cycles   Max Time
Idle (0%)            338          360          22K          9µs
50% (192 threads)    686          744          27K          11µs
100% (384 threads)   2,815        936          18.4M        7.7ms
200% (768 threads)   4,104        936          37.7M        15.7ms

Under full CPU load, sched_yield has:

  • 8x higher average latency (338 → 2,815 cycles)
  • Multi-millisecond tail latency (up to 15.7ms)

3. Wake-up Latency Comparison

Time from flag set to consumer thread noticing (cycles):

Wait Mode                    Avg    p50    p99     Max
Spin with pause              611    528    936     102K
Immediate yield (old NCCL)   688    672    1,032   overflow
Timed spin 1µs (new NCCL)    866    816    1,416   17K
Timed spin 5µs               626    552    1,632   42K

Timed spinning provides consistent latency with bounded max.


4. Spin Instruction Comparison

Comparing different spin-wait approaches:

Instruction    Per-call (cycles)   Wake-up avg   Wake-up max
pause          43                  751           102K
rep nop        43                  603           100K
lfence         15                  596           27K
barrier only   5                   893           51K
nothing        5                   772           141K

pause balances power efficiency, hyperthread fairness, and latency.
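
The benchmark source for this comparison isn't included below; the spin
bodies were presumably along these lines (a sketch; note that on x86,
pause and rep; nop assemble to the same bytes, F3 90, so differences
between those two rows are likely run-to-run noise):

// Spin-wait loop bodies compared above (x86-64, GCC/Clang inline asm).
static inline void spin_pause()   { __asm__ __volatile__("pause"    ::: "memory"); }
static inline void spin_rep_nop() { __asm__ __volatile__("rep; nop" ::: "memory"); } // same encoding as pause
static inline void spin_lfence()  { __asm__ __volatile__("lfence"   ::: "memory"); }
static inline void spin_barrier() { __asm__ __volatile__(""         ::: "memory"); } // compiler barrier only
// "nothing": an empty loop body; the atomic load in the wait loop still
// forces the flag to be re-read each iteration.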


5. NCCL AllReduce Benchmark

8x H100 GPUs, 64MB buffers, 500 iterations:

Version               Time/iter   Bandwidth     Notes
Unpatched (25368a7)   0.374 ms    292.18 GB/s   Original
Patched               0.375 ms    291.67 GB/s   With optimization

Performance is nearly identical on a single NVLink-connected host with no
additional load because:

  • NVLink is extremely fast (~900 GB/s)
  • Operations complete quickly, minimal waiting
  • Little contention in the proxy thread pool

The optimization benefits are expected on:

  • Multi-node setups with InfiniBand/RoCE (slower, more waiting)
  • Systems under high CPU load (avoids tail latency)
  • Workloads with many concurrent NCCL communicators

6. Yield Contention Benchmark

Simulating NCCL proxy free-ops pool contention pattern (64 threads, 10 seconds):

Mode                    Ops/sec     Yields/sec   Reduction
Immediate yield         3,796,307   14,707,915   baseline
Spin 1µs before yield   1,937,157   14,493,397   ~1.5% fewer yields

This user-space simulation shows only a modest reduction in yield calls.
Real NCCL workloads that spend time waiting on the network would show
larger reductions.
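
The contention benchmark source isn't reproduced here; the simulated pattern
is roughly the following sketch, under assumed details (a small shared slot
pool, workers taking and returning slots, and the same spin-then-yield wait
as the patch):

#include <atomic>
#include <cstdint>
#include <sched.h>
#include <time.h>

static std::atomic<bool>     running{true};
static std::atomic<int>      freeOps{16};  // small pool -> heavy contention
static std::atomic<uint64_t> ops{0}, yields{0};

static inline uint64_t nowNs() {
    struct timespec ts;
    clock_gettime(CLOCK_MONOTONIC, &ts);
    return (uint64_t)ts.tv_sec * 1000000000ULL + ts.tv_nsec;
}

// One of the 64 worker threads: take a slot, do a little work, return it.
// spinTimeNs == 0 gives the immediate-yield baseline row.
static void worker(int64_t spinTimeNs) {
    while (running.load(std::memory_order_relaxed)) {
        uint64_t t0 = nowNs();
        int n;
        // Wait for a free slot, spinning with pause before yielding.
        while ((n = freeOps.load(std::memory_order_acquire)) <= 0 ||
               !freeOps.compare_exchange_weak(n, n - 1,
                                              std::memory_order_acq_rel)) {
            if (spinTimeNs > 0 && nowNs() - t0 < (uint64_t)spinTimeNs) {
                __asm__ __volatile__("pause" ::: "memory");
            } else {
                sched_yield();
                yields.fetch_add(1, std::memory_order_relaxed);
                t0 = nowNs();
            }
        }
        for (volatile int i = 0; i < 200; i++) {}        // simulated op
        freeOps.fetch_add(1, std::memory_order_release); // return the slot
        ops.fetch_add(1, std::memory_order_relaxed);
    }
}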


Configuration

Environment variables for tuning:

# Spin duration for freeOps pool wait (default: 1000ns)
export NCCL_PROXY_SPIN_TIME_NS=1000

# Spin duration for progress loop idle (default: 1000ns)
export NCCL_PROXY_PROGRESS_SPIN_TIME_NS=1000

# Disable spinning (original behavior)
export NCCL_PROXY_SPIN_TIME_NS=0
export NCCL_PROXY_PROGRESS_SPIN_TIME_NS=0
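
For completeness, a stand-alone sketch of reading one of these knobs with a
plain getenv and a default; the actual change presumably wires them through
NCCL's internal parameter machinery instead:

#include <cstdint>
#include <cstdlib>

// Parse an int64 tuning knob from the environment, falling back to the
// default when unset or unparsable. The helper name is illustrative.
static int64_t envInt64(const char* name, int64_t defaultValue) {
    const char* s = getenv(name);
    if (s == nullptr || *s == '\0') return defaultValue;
    char* end = nullptr;
    long long v = strtoll(s, &end, 10);
    return (end != nullptr && *end == '\0') ? (int64_t)v : defaultValue;
}

// e.g. at proxy init:
//   int64_t spinNs = envInt64("NCCL_PROXY_SPIN_TIME_NS", 1000);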


Benchmark Source Code

Latency Microbenchmark (latency_bench.cpp)

Measures per-operation overhead and wake-up latency.

/*
 * Microbenchmark to measure:
 * 1. Per-operation overhead: pause, sched_yield, clock_gettime
 * 2. Wake-up latency: time from flag set to consumer noticing
 *
 * Build:
 *   g++ -O2 -pthread -o latency_bench latency_bench.cpp
 *
 * Run:
 *   ./latency_bench
 *
 */

#include <atomic>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <thread>
#include <vector>
#include <sched.h>
#include <time.h>
#include <x86intrin.h>  // For __rdtsc

// Number of iterations for averaging
static constexpr int WARMUP_ITERS = 1000;
static constexpr int MEASURE_ITERS = 100000;
static constexpr int LATENCY_ITERS = 10000;

// CPU pause instruction
static inline void cpu_pause() {
#if defined(__x86_64__) || defined(__i386__)
    __asm__ __volatile__("pause" ::: "memory");
#elif defined(__aarch64__)
    __asm__ __volatile__("yield" ::: "memory");
#else
    __asm__ __volatile__("" ::: "memory");
#endif
}

// Get TSC cycles
static inline uint64_t rdtsc() {
    return __rdtsc();
}

// Get monotonic time in nanoseconds
static inline uint64_t clock_nano() {
    struct timespec ts;
    clock_gettime(CLOCK_MONOTONIC, &ts);
    return static_cast<uint64_t>(ts.tv_sec) * 1000000000ULL + ts.tv_nsec;
}

// Measure CPU frequency
double measure_cpu_freq_ghz() {
    uint64_t t0 = clock_nano();
    uint64_t c0 = rdtsc();

    // Busy wait for 100ms
    while (clock_nano() - t0 < 100000000ULL) {
        cpu_pause();
    }

    uint64_t t1 = clock_nano();
    uint64_t c1 = rdtsc();

    double elapsed_ns = static_cast<double>(t1 - t0);
    double cycles = static_cast<double>(c1 - c0);
    return cycles / elapsed_ns;  // GHz
}

struct OpStats {
    double min_cycles;
    double max_cycles;
    double avg_cycles;
    double p50_cycles;
    double p99_cycles;
};
OpStats compute_stats(std::vector<uint64_t>& samples) {
    std::sort(samples.begin(), samples.end());

    OpStats stats;
    stats.min_cycles = static_cast<double>(samples.front());
    stats.max_cycles = static_cast<double>(samples.back());

    uint64_t sum = 0;
    for (auto s : samples) sum += s;
    stats.avg_cycles = static_cast<double>(sum) / samples.size();

    stats.p50_cycles = static_cast<double>(samples[samples.size() / 2]);
    stats.p99_cycles = static_cast<double>(samples[samples.size() * 99 / 100]);

    return stats;
}

void measure_pause() {
    printf("\n=== Measuring pause instruction ===\n");

    std::vector<uint64_t> samples;
    samples.reserve(MEASURE_ITERS);

    for (int i = 0; i < WARMUP_ITERS; i++) {
        cpu_pause();
    }

    for (int i = 0; i < MEASURE_ITERS; i++) {
        uint64_t c0 = rdtsc();
        cpu_pause();
        uint64_t c1 = rdtsc();
        samples.push_back(c1 - c0);
    }

    OpStats stats = compute_stats(samples);
    printf("  min: %.0f cycles\n", stats.min_cycles);
    printf("  avg: %.0f cycles\n", stats.avg_cycles);
    printf("  p50: %.0f cycles\n", stats.p50_cycles);
    printf("  p99: %.0f cycles\n", stats.p99_cycles);
    printf("  max: %.0f cycles\n", stats.max_cycles);
}

void measure_clock_gettime() {
    printf("\n=== Measuring clock_gettime(CLOCK_MONOTONIC) ===\n");

    std::vector<uint64_t> samples;
    samples.reserve(MEASURE_ITERS);

    for (int i = 0; i < WARMUP_ITERS; i++) {
        clock_nano();
    }

    for (int i = 0; i < MEASURE_ITERS; i++) {
        uint64_t c0 = rdtsc();
        clock_nano();
        uint64_t c1 = rdtsc();
        samples.push_back(c1 - c0);
    }

    OpStats stats = compute_stats(samples);
    printf("  min: %.0f cycles\n", stats.min_cycles);
    printf("  avg: %.0f cycles\n", stats.avg_cycles);
    printf("  p50: %.0f cycles\n", stats.p50_cycles);
    printf("  p99: %.0f cycles\n", stats.p99_cycles);
    printf("  max: %.0f cycles\n", stats.max_cycles);
}

void measure_sched_yield() {
    printf("\n=== Measuring sched_yield() ===\n");

    std::vector<uint64_t> samples;
    samples.reserve(MEASURE_ITERS);

    for (int i = 0; i < WARMUP_ITERS; i++) {
        sched_yield();
    }

    for (int i = 0; i < MEASURE_ITERS; i++) {
        uint64_t c0 = rdtsc();
        sched_yield();
        uint64_t c1 = rdtsc();
        samples.push_back(c1 - c0);
    }

    OpStats stats = compute_stats(samples);
    printf("  min: %.0f cycles\n", stats.min_cycles);
    printf("  avg: %.0f cycles\n", stats.avg_cycles);
    printf("  p50: %.0f cycles\n", stats.p50_cycles);
    printf("  p99: %.0f cycles\n", stats.p99_cycles);
    printf("  max: %.0f cycles\n", stats.max_cycles);
}

struct alignas(64) SharedState {
    std::atomic<uint64_t> flag{0};
    std::atomic<bool> ready{false};
    std::atomic<bool> done{false};
    uint64_t consumer_noticed;
    char padding[64];
};

enum class WaitMode { SPIN_PAUSE, SPIN_YIELD, SPIN_TIMED };

void consumer_thread(SharedState* state, WaitMode mode, int64_t spin_time_ns) {
    while (!state->done.load(std::memory_order_relaxed)) {
        while (!state->ready.load(std::memory_order_acquire)) {
            cpu_pause();
        }

        uint64_t flag_value = 0;

        if (mode == WaitMode::SPIN_PAUSE) {
            while ((flag_value = state->flag.load(std::memory_order_acquire)) == 0) {
                cpu_pause();
            }
        } else if (mode == WaitMode::SPIN_YIELD) {
            while ((flag_value = state->flag.load(std::memory_order_acquire)) == 0) {
                sched_yield();
            }
        } else {
            uint64_t t0 = clock_nano();
            while ((flag_value = state->flag.load(std::memory_order_acquire)) == 0) {
                if (clock_nano() - t0 < static_cast<uint64_t>(spin_time_ns)) {
                    cpu_pause();
                } else {
                    sched_yield();
                    t0 = clock_nano();
                }
            }
        }

        state->consumer_noticed = rdtsc();
        state->ready.store(false, std::memory_order_release);
    }
}

void measure_wakeup_latency(WaitMode mode, const char* mode_name, int64_t spin_time_ns = 0) {
    printf("\n=== Measuring wake-up latency: %s ===\n", mode_name);

    SharedState state;
    std::thread consumer(consumer_thread, &state, mode, spin_time_ns);

    std::vector<uint64_t> latencies;
    latencies.reserve(LATENCY_ITERS);

    for (int i = 0; i < 100; i++) {
        state.ready.store(true, std::memory_order_release);
        for (int j = 0; j < 100; j++) cpu_pause();
        uint64_t t0 = rdtsc();
        state.flag.store(t0, std::memory_order_release);
        while (state.ready.load(std::memory_order_acquire)) {
            cpu_pause();
        }
        state.flag.store(0, std::memory_order_relaxed);
    }

    for (int i = 0; i < LATENCY_ITERS; i++) {
        state.ready.store(true, std::memory_order_release);
        int delay = (i * 7) % 1000;
        for (int j = 0; j < delay; j++) cpu_pause();
        uint64_t t0 = rdtsc();
        state.flag.store(t0, std::memory_order_release);
        while (state.ready.load(std::memory_order_acquire)) {
            cpu_pause();
        }
        uint64_t latency = state.consumer_noticed - t0;
        latencies.push_back(latency);
        state.flag.store(0, std::memory_order_relaxed);
    }

    state.done.store(true, std::memory_order_release);
    state.ready.store(true, std::memory_order_release);
    state.flag.store(1, std::memory_order_release);
    consumer.join();

    OpStats stats = compute_stats(latencies);
    printf("  min: %.0f cycles\n", stats.min_cycles);
    printf("  avg: %.0f cycles\n", stats.avg_cycles);
    printf("  p50: %.0f cycles\n", stats.p50_cycles);
    printf("  p99: %.0f cycles\n", stats.p99_cycles);
    printf("  max: %.0f cycles\n", stats.max_cycles);
}

std::atomic<bool> g_stress_running{true};

void stress_worker() {
    volatile uint64_t x = 0;
    while (g_stress_running.load(std::memory_order_relaxed)) {
        x = x * 7 + 13;
    }
}

void measure_sched_yield_under_load(int num_stress_threads) {
    printf("\n=== Measuring sched_yield() under load (%d stress threads) ===\n",
           num_stress_threads);

    std::vector<std::thread> stress_threads;
    g_stress_running.store(true);
    for (int i = 0; i < num_stress_threads; i++) {
        stress_threads.emplace_back(stress_worker);
    }

    std::this_thread::sleep_for(std::chrono::milliseconds(100));

    std::vector<uint64_t> samples;
    samples.reserve(MEASURE_ITERS);

    for (int i = 0; i < MEASURE_ITERS; i++) {
        uint64_t c0 = rdtsc();
        sched_yield();
        uint64_t c1 = rdtsc();
        samples.push_back(c1 - c0);
    }

    g_stress_running.store(false);
    for (auto& t : stress_threads) {
        t.join();
    }

    OpStats stats = compute_stats(samples);
    printf("  min: %.0f cycles\n", stats.min_cycles);
    printf("  avg: %.0f cycles\n", stats.avg_cycles);
    printf("  p50: %.0f cycles\n", stats.p50_cycles);
    printf("  p99: %.0f cycles\n", stats.p99_cycles);
    printf("  max: %.0f cycles\n", stats.max_cycles);
}

int main(int argc, char* argv[]) {
    printf("Latency Microbenchmark\n");
    printf("======================\n");

    double cpu_ghz = measure_cpu_freq_ghz();
    printf("\nCPU frequency: %.2f GHz\n", cpu_ghz);
    printf("(1 cycle = %.2f ns)\n", 1.0 / cpu_ghz);

    printf("\n\n### PART 1: Per-operation overhead ###\n");
    measure_pause();
    measure_clock_gettime();
    measure_sched_yield();

    printf("\n\n### PART 2: Wake-up latency ###\n");
    measure_wakeup_latency(WaitMode::SPIN_PAUSE, "spin with pause");
    measure_wakeup_latency(WaitMode::SPIN_YIELD, "immediate yield (old NCCL)");
    measure_wakeup_latency(WaitMode::SPIN_TIMED, "timed spin 1us (new NCCL)", 1000);
    measure_wakeup_latency(WaitMode::SPIN_TIMED, "timed spin 5us", 5000);

    printf("\n\n### PART 3: sched_yield under CPU load ###\n");
    int num_cpus = std::thread::hardware_concurrency();
    printf("System has %d CPUs\n", num_cpus);

    measure_sched_yield_under_load(num_cpus / 2);
    measure_sched_yield_under_load(num_cpus);
    measure_sched_yield_under_load(num_cpus * 2);

    printf("\n\nDone.\n");
    return 0;
}

NCCL AllReduce Benchmark (nccl_yield_bench.cu)

Tests actual NCCL collective operations.

/*
 * NCCL AllReduce Benchmark to reproduce sched_yield contention
 *
 * Build (from NCCL root):
 *   nvcc -O2 -o nccl_yield_bench nccl_yield_bench.cu \
 *        -I./build/include -L./build/lib -lnccl -lcudart -lpthread
 *
 * Run:
 *   LD_LIBRARY_PATH=./build/lib ./nccl_yield_bench
 *
 * Profile:
 *   LD_LIBRARY_PATH=./build/lib perf stat -e syscalls:sys_enter_sched_yield \
 *     ./nccl_yield_bench
 */

#include <cuda_runtime.h>
#include <nccl.h>

#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <cstring>  // strcmp/strncmp used in argument parsing
#include <thread>
#include <vector>

#define CUDACHECK(cmd) do {                         \
  cudaError_t err = cmd;                            \
  if (err != cudaSuccess) {                         \
    fprintf(stderr, "CUDA error %s:%d '%s'\n",      \
        __FILE__, __LINE__, cudaGetErrorString(err)); \
    exit(EXIT_FAILURE);                             \
  }                                                 \
} while(0)

#define NCCLCHECK(cmd) do {                         \
  ncclResult_t res = cmd;                           \
  if (res != ncclSuccess) {                         \
    fprintf(stderr, "NCCL error %s:%d '%s'\n",      \
        __FILE__, __LINE__, ncclGetErrorString(res)); \
    exit(EXIT_FAILURE);                             \
  }                                                 \
} while(0)

struct BenchConfig {
    size_t buffer_size = 64 * 1024 * 1024;  // 64MB per GPU
    int num_iterations = 1000;
    int warmup_iterations = 100;
    bool sync_each_iter = true;
};

void run_benchmark(int num_gpus, const BenchConfig& config) {
    printf("Running NCCL benchmark with %d GPUs\n", num_gpus);
    printf("  Buffer size: %zu MB\n", config.buffer_size / (1024 * 1024));
    printf("  Iterations: %d\n", config.num_iterations);
    printf("  Sync each iteration: %s\n", config.sync_each_iter ? "yes" : "no");

    const char* spin_env = getenv("NCCL_PROXY_SPIN_TIME_NS");
    if (spin_env) {
        printf("  NCCL_PROXY_SPIN_TIME_NS: %s\n", spin_env);
    } else {
        printf("  NCCL_PROXY_SPIN_TIME_NS: (default)\n");
    }
    printf("\n");

    std::vector<ncclComm_t> comms(num_gpus);
    std::vector<cudaStream_t> streams(num_gpus);
    std::vector<float*> send_buffers(num_gpus);
    std::vector<float*> recv_buffers(num_gpus);

    ncclUniqueId id;
    NCCLCHECK(ncclGetUniqueId(&id));

    for (int i = 0; i < num_gpus; i++) {
        CUDACHECK(cudaSetDevice(i));
        CUDACHECK(cudaStreamCreate(&streams[i]));
        CUDACHECK(cudaMalloc(&send_buffers[i], config.buffer_size));
        CUDACHECK(cudaMalloc(&recv_buffers[i], config.buffer_size));
        CUDACHECK(cudaMemset(send_buffers[i], i + 1, config.buffer_size));
    }

    NCCLCHECK(ncclGroupStart());
    for (int i = 0; i < num_gpus; i++) {
        CUDACHECK(cudaSetDevice(i));
        NCCLCHECK(ncclCommInitRank(&comms[i], num_gpus, id, i));
    }
    NCCLCHECK(ncclGroupEnd());

    size_t count = config.buffer_size / sizeof(float);
    
    printf("Warming up...\n");
    for (int iter = 0; iter < config.warmup_iterations; iter++) {
        NCCLCHECK(ncclGroupStart());
        for (int i = 0; i < num_gpus; i++) {
            CUDACHECK(cudaSetDevice(i));
            NCCLCHECK(ncclAllReduce(send_buffers[i], recv_buffers[i], count,
                                    ncclFloat, ncclSum, comms[i], streams[i]));
        }
        NCCLCHECK(ncclGroupEnd());

        if (config.sync_each_iter) {
            for (int i = 0; i < num_gpus; i++) {
                CUDACHECK(cudaSetDevice(i));
                CUDACHECK(cudaStreamSynchronize(streams[i]));
            }
        }
    }

    for (int i = 0; i < num_gpus; i++) {
        CUDACHECK(cudaSetDevice(i));
        CUDACHECK(cudaDeviceSynchronize());
    }

    printf("Running benchmark...\n");
    auto start = std::chrono::high_resolution_clock::now();

    for (int iter = 0; iter < config.num_iterations; iter++) {
        NCCLCHECK(ncclGroupStart());
        for (int i = 0; i < num_gpus; i++) {
            CUDACHECK(cudaSetDevice(i));
            NCCLCHECK(ncclAllReduce(send_buffers[i], recv_buffers[i], count,
                                    ncclFloat, ncclSum, comms[i], streams[i]));
        }
        NCCLCHECK(ncclGroupEnd());

        if (config.sync_each_iter) {
            for (int i = 0; i < num_gpus; i++) {
                CUDACHECK(cudaSetDevice(i));
                CUDACHECK(cudaStreamSynchronize(streams[i]));
            }
        }
    }

    for (int i = 0; i < num_gpus; i++) {
        CUDACHECK(cudaSetDevice(i));
        CUDACHECK(cudaDeviceSynchronize());
    }

    auto end = std::chrono::high_resolution_clock::now();
    double elapsed_ms = std::chrono::duration<double, std::milli>(end - start).count();

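    // Bus-bandwidth convention for allreduce: each GPU moves
    // 2*(N-1)/N * S bytes per iteration, so despite the name this
    // holds bytes per iteration, scaled by iteration count below.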
    double algo_bw = 2.0 * (num_gpus - 1.0) / num_gpus * config.buffer_size;
    double total_bytes = algo_bw * config.num_iterations;
    double bw_gbps = (total_bytes / (elapsed_ms / 1000.0)) / (1024.0 * 1024.0 * 1024.0);

    printf("\nResults:\n");
    printf("  Total time: %.2f ms\n", elapsed_ms);
    printf("  Time per iteration: %.3f ms\n", elapsed_ms / config.num_iterations);
    printf("  Throughput: %.2f iterations/sec\n", config.num_iterations / (elapsed_ms / 1000.0));
    printf("  Algorithm bandwidth: %.2f GB/s\n", bw_gbps);

    for (int i = 0; i < num_gpus; i++) {
        CUDACHECK(cudaSetDevice(i));
        NCCLCHECK(ncclCommDestroy(comms[i]));
        CUDACHECK(cudaStreamDestroy(streams[i]));
        CUDACHECK(cudaFree(send_buffers[i]));
        CUDACHECK(cudaFree(recv_buffers[i]));
    }
}

int main(int argc, char* argv[]) {
    int num_gpus;
    CUDACHECK(cudaGetDeviceCount(&num_gpus));

    if (num_gpus < 1) {
        fprintf(stderr, "No GPUs found\n");
        return 1;
    }

    printf("Found %d GPUs\n", num_gpus);
    for (int i = 0; i < num_gpus; i++) {
        cudaDeviceProp prop;
        CUDACHECK(cudaGetDeviceProperties(&prop, i));
        printf("  GPU %d: %s\n", i, prop.name);
    }
    printf("\n");

    BenchConfig config;

    for (int i = 1; i < argc; i++) {
        if (strncmp(argv[i], "--size=", 7) == 0) {
            config.buffer_size = atoll(argv[i] + 7) * 1024 * 1024;
        } else if (strncmp(argv[i], "--iters=", 8) == 0) {
            config.num_iterations = atoi(argv[i] + 8);
        } else if (strncmp(argv[i], "--warmup=", 9) == 0) {
            config.warmup_iterations = atoi(argv[i] + 9);
        } else if (strcmp(argv[i], "--no-sync") == 0) {
            config.sync_each_iter = false;
        } else if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) {
            printf("Usage: %s [options]\n", argv[0]);
            printf("Options:\n");
            printf("  --size=N     Buffer size in MB (default: 64)\n");
            printf("  --iters=N    Number of iterations (default: 1000)\n");
            printf("  --warmup=N   Warmup iterations (default: 100)\n");
            printf("  --no-sync    Don't sync between iterations\n");
            return 0;
        }
    }

    run_benchmark(num_gpus, config);
    return 0;
}

hodgesds commented Feb 6, 2026

bpftrace script for tracking yield latency:

#!/usr/bin/env bpftrace

/*
 * yield_latency.bt - Measure time between sched_yield() and reschedule
 *
 * Tracks how long a task waits after calling yield() before being
 * scheduled back onto a CPU.
 *
 * Usage:
 *   sudo bpftrace yield_latency.bt              # no stacks
 *   sudo bpftrace yield_latency.bt 1            # with stacks
 */


/* Record timestamp and optionally stacks when task calls sched_yield */
tracepoint:syscalls:sys_enter_sched_yield
{
    @yield_start[tid] = nsecs;
    if ($1) {
        @yield_ustack[tid] = ustack(raw);
        @yield_kstack[tid] = kstack(raw);
    }
}

/* When task is scheduled back in, calculate latency */
tracepoint:sched:sched_switch
{
    $next_pid = args->next_pid;

    if (@yield_start[$next_pid]) {
        $latency_us = (nsecs - @yield_start[$next_pid]) / 1000;

        /* Output in Strobelight format for Scuba */
        if ($1) {
            $x = ("pid", args->next_pid,
                  "tid", $next_pid,
                  "comm", args->next_comm,
                  "ustack", @yield_ustack[$next_pid],
                  "kstack", @yield_kstack[$next_pid],
                  "latency_us", $latency_us);
            print($x);

            delete(@yield_ustack[$next_pid]);
            delete(@yield_kstack[$next_pid]);
        } else {
            $x = ("pid", args->next_pid,
                  "tid", $next_pid,
                  "comm", args->next_comm,
                  "latency_us", $latency_us);
            print($x);
        }

        @latency_hist = hist($latency_us);
        @total_yields = count();
        @avg_latency_us = avg($latency_us);
        @max_latency_us = max($latency_us);
        @min_latency_us = min($latency_us);

        delete(@yield_start[$next_pid]);
    }
}

END
{
    clear(@yield_start);
    clear(@yield_ustack);
    clear(@yield_kstack);
}

and output:

(pid, 573917, tid, 573917, comm, StatsAgg, latency_us, 29)
(pid, 577399, tid, 577399, comm, StatsAgg, latency_us, 9)
(pid, 573948, tid, 573948, comm, StatsAgg, latency_us, 232269)
^C

@avg_latency_us: 22089
@latency_hist:
[4, 8)                 1 |@@@                                                 |
[8, 16)               17 |@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@|
[16, 32)               8 |@@@@@@@@@@@@@@@@@@@@@@@@                            |
[32, 64)               5 |@@@@@@@@@@@@@@@                                     |
[64, 128)              4 |@@@@@@@@@@@@                                        |
[128, 256)             6 |@@@@@@@@@@@@@@@@@@                                  |
[256, 512)             0 |                                                    |
[512, 1K)              0 |                                                    |
[1K, 2K)               0 |                                                    |
[2K, 4K)               0 |                                                    |
[4K, 8K)               0 |                                                    |
[8K, 16K)              0 |                                                    |
[16K, 32K)             0 |                                                    |
[32K, 64K)             0 |                                                    |
[64K, 128K)            0 |                                                    |
[128K, 256K)           1 |@@@                                                 |
[256K, 512K)           2 |@@@@@@                                              |

@max_latency_us: 368859
@min_latency_us: 7
@total_yields: 44
