Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/distributed/distributed_dataloader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <algorithm>
#include <cstddef>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>
Expand Down Expand Up @@ -116,15 +117,13 @@ class DistributedSampler {
indices.push_back(i);
}

// Shuffle if requested (deterministically based on epoch)
// Shuffle using Fisher-Yates with deterministic PRNG seeded by (seed + epoch)
if (shuffle_) {
// Simple deterministic shuffle using epoch and seed
// In production, would use std::mt19937 with (seed_ + epoch)
size_t shuffle_seed = seed_ + epoch;
std::mt19937 rng(static_cast<uint32_t>(seed_ + epoch));
for (size_t i = indices.size() - 1; i > 0; --i) {
size_t j = (shuffle_seed * (i + 1)) % (i + 1);
std::uniform_int_distribution<size_t> dist(0, i);
size_t j = dist(rng);
std::swap(indices[i], indices[j]);
shuffle_seed = (shuffle_seed * 1103515245 + 12345) & 0x7FFFFFFF;
}
}

Expand Down
19 changes: 12 additions & 7 deletions src/distributed/sharding_strategies.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,15 +294,20 @@ class HashBasedSharding : public ShardingStrategy {

/**
 * @brief Get the hash value for a sample index
 *
 * Combines seed_ with the index (XOR) and runs the result through the
 * splitmix64 finalizer (alternating xor-shift and multiply rounds) so that
 * nearby indices are spread over the full 64-bit range before the modulo
 * reduction. This avoids the bias that simpler hashes produce with modulo
 * assignment.
 *
 * NOTE(review): assumes world_size_ is non-zero (presumably validated at
 * construction - confirm); a zero value would divide by zero here.
 *
 * @param idx Sample index to hash
 * @return Value in [0, world_size_) identifying the bucket for idx
 */
size_t hash_index(size_t idx) const {
    // splitmix64 mixing function for high-quality uniform distribution
    uint64_t x = seed_ ^ idx;
    x ^= x >> 30;
    x *= 0xbf58476d1ce4e5b9ULL;
    x ^= x >> 27;
    x *= 0x94d049bb133111ebULL;
    x ^= x >> 31;
    return static_cast<size_t>(x % world_size_);
}

private:
Expand Down
19 changes: 18 additions & 1 deletion src/pipeline/checkpointing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,14 +410,21 @@ class StateTracker {

/**
* @brief Set shuffle order for current epoch
*
* For small datasets (<10M samples), stores the full permutation.
* For large datasets, the shuffle order is regenerated from (rng_seed, epoch)
* during resume, which is always stored in the checkpoint state.
* This ensures reproducibility regardless of dataset size.
*/
void set_shuffle_order(const std::vector<size_t>& order) {
std::lock_guard<std::mutex> lock(mutex_);
// Only store if reasonably small (< 10M samples)
if (order.size() < 10000000) {
state_.shuffle_order = order;
state_.shuffle_order_stored = true;
} else {
// For large datasets, rely on seed + epoch for regeneration.
// The seed (rng_seed) and epoch are always stored in the state,
// so the shuffle order can be deterministically regenerated.
state_.shuffle_order.clear();
state_.shuffle_order_stored = false;
}
Expand Down Expand Up @@ -751,6 +758,11 @@ class StatefulDataLoader {

/**
* @brief Load state dict for resumption
*
* Restores exact iteration position. For shuffled datasets:
* - If shuffle order was stored (small datasets), restores it directly.
* - If not stored (large datasets, >= 10M samples), regenerates from seed+epoch
* to produce the identical permutation.
*/
void load_state_dict(const PipelineState& state) {
tracker_.load_state_dict(state);
Expand All @@ -759,7 +771,12 @@ class StatefulDataLoader {
if (index_gen_) {
index_gen_->set_epoch(state.epoch);
if (state.shuffle_order_stored && !state.shuffle_order.empty()) {
// Small dataset: restore stored order directly
index_gen_->set_indices(state.shuffle_order);
} else if (state.shuffled) {
// Large dataset: regenerate from seed+epoch (deterministic)
index_gen_->set_shuffle(true);
index_gen_->set_epoch(state.epoch);
}
index_gen_->skip_to(state.samples_processed);
}
Expand Down
77 changes: 77 additions & 0 deletions tests/test_sharding_strategies.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <gtest/gtest.h>
#include "../src/distributed/sharding_strategies.hpp"
#include "../src/distributed/distributed_dataloader.hpp"
#include <algorithm>
#include <numeric>
#include <set>
Expand Down Expand Up @@ -352,6 +353,82 @@ TEST(ShardingStringTest, InvalidString) {
EXPECT_THROW(string_to_sharding_type("INVALID"), std::invalid_argument);
}

// ============================================================================
// Hash Distribution Balance Tests
// ============================================================================

TEST(HashBalanceTest, EvenDistribution) {
    // Shard 10000 samples over 8 ranks with the hash-based strategy and
    // require every shard size to fall strictly inside 95%..105% of the
    // ideal share.
    size_t total = 10000;
    size_t world_size = 8;

    // Phase 1: record how many samples each rank's shard reports.
    std::vector<size_t> per_rank(world_size, 0);
    for (size_t r = 0; r < world_size; ++r) {
        HashBasedSharding shard(total, world_size, r);
        per_rank[r] = shard.size();
    }

    // Phase 2: ideal share is total/world_size = 1250, so the accepted
    // window is the open interval (1187, 1312).
    const size_t ideal = total / world_size;
    const size_t lower_bound = ideal * 95 / 100;
    const size_t upper_bound = ideal * 105 / 100;

    size_t r = 0;
    for (const size_t count : per_rank) {
        EXPECT_GT(count, lower_bound)
            << "Rank " << r << " has too few samples: " << count;
        EXPECT_LT(count, upper_bound)
            << "Rank " << r << " has too many samples: " << count;
        ++r;
    }
}

// ============================================================================
// Distributed Sampler Shuffle Tests
// ============================================================================

TEST(DistributedShuffleTest, FisherYatesProducesPermutation) {
    // Verify the shuffle produces a valid permutation of [0, kNumSamples):
    // correct length, no duplicates, and no value outside the range.
    // The sampler is configured with world_size = 1 and rank = 0.
    constexpr size_t kNumSamples = 100;

    turboloader::distributed::DistributedConfig config;
    config.world_size = 1;
    config.rank = 0;
    config.shuffle = true;
    config.seed = 42;

    turboloader::distributed::DistributedSampler sampler(kNumSamples, config);
    auto indices = sampler.get_indices(0);

    // Should have all 100 indices (size_t constant avoids a signed/unsigned
    // comparison inside the gtest macro).
    EXPECT_EQ(indices.size(), kNumSamples);

    std::set<size_t> unique_indices(indices.begin(), indices.end());
    EXPECT_EQ(unique_indices.size(), kNumSamples) << "Shuffle produced duplicates!";

    // kNumSamples distinct values that are all below kNumSamples are exactly
    // 0..kNumSamples-1; std::set is ordered, so checking the maximum suffices.
    if (!unique_indices.empty()) {
        EXPECT_LT(*unique_indices.rbegin(), kNumSamples)
            << "Shuffle produced an out-of-range index";
    }
}

TEST(DistributedShuffleTest, DifferentEpochsDifferentOrder) {
    // With shuffling on and a fixed seed, epoch 0 and epoch 1 must not yield
    // the same index sequence.
    namespace dist = turboloader::distributed;

    dist::DistributedConfig cfg;
    cfg.seed = 42;
    cfg.shuffle = true;
    cfg.rank = 0;
    cfg.world_size = 1;

    dist::DistributedSampler sampler(100, cfg);
    const auto first_epoch = sampler.get_indices(0);
    const auto second_epoch = sampler.get_indices(1);

    EXPECT_NE(first_epoch, second_epoch) << "Same order for different epochs";
}

TEST(DistributedShuffleTest, SameEpochSameOrder) {
    // Requesting the same epoch twice from one sampler (same seed) must give
    // back the identical index sequence.
    namespace dist = turboloader::distributed;

    dist::DistributedConfig cfg;
    cfg.seed = 42;
    cfg.shuffle = true;
    cfg.rank = 0;
    cfg.world_size = 1;

    dist::DistributedSampler sampler(100, cfg);
    const auto first_call = sampler.get_indices(5);
    const auto second_call = sampler.get_indices(5);

    EXPECT_EQ(first_call, second_call) << "Same epoch+seed should produce same order";
}

// ============================================================================
// Edge Cases
// ============================================================================
Expand Down