From f9220f6c09e7fc18aa4584ccd09ddaad44bcb6ea Mon Sep 17 00:00:00 2001
From: ALJainProjects
Date: Sun, 8 Feb 2026 19:03:37 -0500
Subject: [PATCH 1/2] fix: correct distributed training shuffle, hash bias,
 and checkpoint resume

- Replace biased shuffle with proper Fisher-Yates using std::mt19937
- Upgrade hash_index() from weak FNV-1a to splitmix64 for uniform shard
  distribution
- Fix checkpoint resume for large shuffled datasets via seed+epoch
  regeneration
- Add tests for hash balance, shuffle permutation correctness, and epoch
  determinism
---
 src/distributed/distributed_dataloader.hpp | 11 ++--
 src/distributed/sharding_strategies.hpp    | 19 ++++--
 src/pipeline/checkpointing.hpp             | 19 +++++-
 tests/test_sharding_strategies.cpp         | 76 ++++++++++++++++++++++
 4 files changed, 111 insertions(+), 14 deletions(-)

diff --git a/src/distributed/distributed_dataloader.hpp b/src/distributed/distributed_dataloader.hpp
index 3c9ae4e..4110505 100644
--- a/src/distributed/distributed_dataloader.hpp
+++ b/src/distributed/distributed_dataloader.hpp
@@ -25,6 +25,7 @@
 #include
 #include
 #include
+#include <random>
 #include
 #include
 #include
@@ -116,15 +117,13 @@ class DistributedSampler {
             indices.push_back(i);
         }
 
-        // Shuffle if requested (deterministically based on epoch)
+        // Shuffle using Fisher-Yates with deterministic PRNG seeded by (seed + epoch)
         if (shuffle_) {
-            // Simple deterministic shuffle using epoch and seed
-            // In production, would use std::mt19937 with (seed_ + epoch)
-            size_t shuffle_seed = seed_ + epoch;
+            std::mt19937 rng(static_cast<uint32_t>(seed_ + epoch));
             for (size_t i = indices.size() - 1; i > 0; --i) {
-                size_t j = (shuffle_seed * (i + 1)) % (i + 1);
+                std::uniform_int_distribution<size_t> dist(0, i);
+                size_t j = dist(rng);
                 std::swap(indices[i], indices[j]);
-                shuffle_seed = (shuffle_seed * 1103515245 + 12345) & 0x7FFFFFFF;
             }
         }
 
diff --git a/src/distributed/sharding_strategies.hpp b/src/distributed/sharding_strategies.hpp
index 5f9754f..f390cd1 100644
--- a/src/distributed/sharding_strategies.hpp
+++ b/src/distributed/sharding_strategies.hpp
@@ -294,15 +294,20 @@ class HashBasedSharding : public ShardingStrategy {
 
     /**
      * @brief Get the hash value for a sample index
+     *
+     * Uses a strong mixing function (splitmix64-based) to ensure
+     * uniform distribution across ranks, avoiding the bias that
+     * simpler hashes produce with modulo assignment.
      */
     size_t hash_index(size_t idx) const {
-        // FNV-1a hash variant
-        uint64_t hash = seed_;
-        hash ^= idx;
-        hash *= 0x100000001b3ULL;
-        hash ^= (idx >> 32);
-        hash *= 0x100000001b3ULL;
-        return hash % world_size_;
+        // splitmix64 mixing function for high-quality uniform distribution
+        uint64_t x = seed_ ^ idx;
+        x ^= x >> 30;
+        x *= 0xbf58476d1ce4e5b9ULL;
+        x ^= x >> 27;
+        x *= 0x94d049bb133111ebULL;
+        x ^= x >> 31;
+        return x % world_size_;
     }
 
 private:
diff --git a/src/pipeline/checkpointing.hpp b/src/pipeline/checkpointing.hpp
index 88a6613..d406251 100644
--- a/src/pipeline/checkpointing.hpp
+++ b/src/pipeline/checkpointing.hpp
@@ -410,14 +410,21 @@ class StateTracker {
     /**
      * @brief Set shuffle order for current epoch
+     *
+     * For small datasets (<10M samples), stores the full permutation.
+     * For large datasets, the shuffle order is regenerated during resume
+     * from (rng_seed, epoch), both of which are always stored in the
+     * checkpoint state. This ensures reproducibility regardless of dataset size.
      */
     void set_shuffle_order(const std::vector<size_t>& order) {
         std::lock_guard lock(mutex_);
-        // Only store if reasonably small (< 10M samples)
         if (order.size() < 10000000) {
             state_.shuffle_order = order;
             state_.shuffle_order_stored = true;
         } else {
+            // For large datasets, rely on seed + epoch for regeneration.
+            // The seed (rng_seed) and epoch are always stored in the state,
+            // so the shuffle order can be deterministically regenerated.
             state_.shuffle_order.clear();
             state_.shuffle_order_stored = false;
         }
     }
@@ -751,6 +758,11 @@ class StatefulDataLoader {
 
     /**
      * @brief Load state dict for resumption
+     *
+     * Restores exact iteration position. For shuffled datasets:
+     * - If the shuffle order was stored (small datasets), it is restored directly.
+     * - If not stored (large datasets, >10M samples), it is regenerated from
+     *   seed+epoch to produce the identical permutation.
      */
     void load_state_dict(const PipelineState& state) {
         tracker_.load_state_dict(state);
@@ -759,7 +771,12 @@
         if (index_gen_) {
             index_gen_->set_epoch(state.epoch);
             if (state.shuffle_order_stored && !state.shuffle_order.empty()) {
+                // Small dataset: restore stored order directly
                 index_gen_->set_indices(state.shuffle_order);
+            } else if (state.shuffled) {
+                // Large dataset: regenerate from seed+epoch (deterministic)
+                index_gen_->set_shuffle(true);
+                index_gen_->set_epoch(state.epoch);
             }
             index_gen_->skip_to(state.samples_processed);
         }
diff --git a/tests/test_sharding_strategies.cpp b/tests/test_sharding_strategies.cpp
index ceadcd6..7f01949 100644
--- a/tests/test_sharding_strategies.cpp
+++ b/tests/test_sharding_strategies.cpp
@@ -352,6 +352,82 @@ TEST(ShardingStringTest, InvalidString) {
     EXPECT_THROW(string_to_sharding_type("INVALID"), std::invalid_argument);
 }
 
+// ============================================================================
+// Hash Distribution Balance Tests
+// ============================================================================
+
+TEST(HashBalanceTest, EvenDistribution) {
+    // With splitmix64 hash, distribution across ranks should be near-uniform
+    size_t total = 10000;
+    size_t world_size = 8;
+
+    std::vector<size_t> counts(world_size, 0);
+    for (size_t rank = 0; rank < world_size; ++rank) {
+        HashBasedSharding shard(total, world_size, rank);
+        counts[rank] = shard.size();
+    }
+
+    // Each rank should get roughly total/world_size = 1250 samples
+    // Allow ±5% deviation (1187 to 1312)
+    size_t expected = total / world_size;
+    for (size_t rank = 0; rank < world_size; ++rank) {
+        EXPECT_GT(counts[rank], expected * 95 / 100)
+            << "Rank " << rank << " has too few samples: " << counts[rank];
+        EXPECT_LT(counts[rank], expected * 105 / 100)
+            << "Rank " << rank << " has too many samples: " << counts[rank];
+    }
+}
+
+// ============================================================================
+// Distributed Sampler Shuffle Tests
+// ============================================================================
+
+TEST(DistributedShuffleTest, FisherYatesProducesPermutation) {
+    // Verify the shuffle produces a valid permutation (no duplicates, all present)
+    turboloader::distributed::DistributedConfig config;
+    config.world_size = 1;
+    config.rank = 0;
+    config.shuffle = true;
+    config.seed = 42;
+
+    turboloader::distributed::DistributedSampler sampler(100, config);
+    auto indices = sampler.get_indices(0);
+
+    // Should have all 100 indices
+    EXPECT_EQ(indices.size(), 100);
+
+    std::set<size_t> unique_indices(indices.begin(), indices.end());
+    EXPECT_EQ(unique_indices.size(), 100) << "Shuffle produced duplicates!";
+}
+
+TEST(DistributedShuffleTest, DifferentEpochsDifferentOrder) {
+    turboloader::distributed::DistributedConfig config;
+    config.world_size = 1;
+    config.rank = 0;
+    config.shuffle = true;
+    config.seed = 42;
+
+    turboloader::distributed::DistributedSampler sampler(100, config);
+    auto epoch0 = sampler.get_indices(0);
+    auto epoch1 = sampler.get_indices(1);
+
+    EXPECT_NE(epoch0, epoch1) << "Same order for different epochs";
+}
+
+TEST(DistributedShuffleTest, SameEpochSameOrder) {
+    turboloader::distributed::DistributedConfig config;
+    config.world_size = 1;
+    config.rank = 0;
+    config.shuffle = true;
+    config.seed = 42;
+
+    turboloader::distributed::DistributedSampler sampler(100, config);
+    auto run1 = sampler.get_indices(5);
+    auto run2 = sampler.get_indices(5);
+
+    EXPECT_EQ(run1, run2) << "Same epoch+seed should produce same order";
+}
+
 // ============================================================================
 // Edge Cases
 // ============================================================================

From fcb9a6002a6527d3044e9c093d164bf1e0484777 Mon Sep 17 00:00:00 2001
From: ALJainProjects
Date: Sun, 8 Feb 2026 19:38:18 -0500
Subject: [PATCH 2/2] Fix missing include for DistributedConfig/DistributedSampler
 in test

The shuffle tests reference turboloader::distributed::DistributedConfig and
DistributedSampler, which live in distributed_dataloader.hpp, not
sharding_strategies.hpp.
---
 tests/test_sharding_strategies.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_sharding_strategies.cpp b/tests/test_sharding_strategies.cpp
index 7f01949..b4b7956 100644
--- a/tests/test_sharding_strategies.cpp
+++ b/tests/test_sharding_strategies.cpp
@@ -5,6 +5,7 @@
 #include
 
 #include "../src/distributed/sharding_strategies.hpp"
+#include "../src/distributed/distributed_dataloader.hpp"
 #include
 #include
 #include
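
Note (reviewer illustration, not part of either patch): the large-dataset
resume path in PATCH 1 hinges on the permutation being a pure function of
(seed, epoch). A minimal standalone sketch of that property, using the same
Fisher-Yates + std::mt19937 construction as the patched sampler; make_order
is a hypothetical helper, not TurboLoader API:

#include <cassert>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

std::vector<size_t> make_order(size_t n, uint64_t seed, size_t epoch) {
    std::vector<size_t> idx(n);
    std::iota(idx.begin(), idx.end(), size_t{0});   // identity order 0..n-1
    std::mt19937 rng(static_cast<uint32_t>(seed + epoch));
    for (size_t i = n - 1; i > 0; --i) {            // Fisher-Yates, as in the patch
        std::uniform_int_distribution<size_t> dist(0, i);
        std::swap(idx[i], idx[dist(rng)]);
    }
    return idx;
}

int main() {
    // Two independent processes regenerate the identical epoch-5 order,
    // so a checkpoint only needs to store (seed, epoch), not the order.
    assert(make_order(1000, 42, 5) == make_order(1000, 42, 5));
    // Different epochs give different orders (overwhelmingly likely at n=1000).
    assert(make_order(1000, 42, 5) != make_order(1000, 42, 6));
    return 0;
}

The first assertion holds by construction; the second is what the
DifferentEpochsDifferentOrder test checks through the real sampler.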
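
A companion sketch for the hash change: splitmix64's multiply/xor-shift
rounds avalanche the input bits, so bucketing by % world_size comes out
near-uniform even for sequential indices (the failure mode of the old
FNV-1a variant). Constants match the patched hash_index(); the histogram
program below is illustrative only:

#include <cstdint>
#include <cstdio>
#include <vector>

// Same mixing rounds and constants as hash_index() in PATCH 1.
uint64_t splitmix64_mix(uint64_t x) {
    x ^= x >> 30; x *= 0xbf58476d1ce4e5b9ULL;
    x ^= x >> 27; x *= 0x94d049bb133111ebULL;
    x ^= x >> 31;
    return x;
}

int main() {
    const uint64_t seed = 42;
    const size_t world_size = 8;
    std::vector<size_t> counts(world_size, 0);
    for (uint64_t idx = 0; idx < 10000; ++idx) {
        ++counts[splitmix64_mix(seed ^ idx) % world_size];
    }
    // Each rank should land near 10000 / 8 = 1250, mirroring the
    // EvenDistribution test's ±5% bound.
    for (size_t r = 0; r < world_size; ++r) {
        std::printf("rank %zu: %zu samples\n", r, counts[r]);
    }
    return 0;
}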