feat: Optimize SFT dataloader with slice-based tokenization and caching #1695
a25e19d to c937e47
Adds multiprocessing-based parallel tokenization with slice-based HF loading to eliminate pickle overhead. Includes tokenized dataset caching (pickle) with NFS support for multi-node training. New config options: num_workers, cache_dir, force_recache, disable_cache.

Co-Authored-By: SumanthRH <sumanthrh@anyscale.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: jlee-lila <jlee@lila.ai>
c937e47 to dc25e14
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
I've made some improvements:
Code Review
This pull request implements parallel tokenization and on-disk caching for SFT datasets to optimize training startup. It introduces configuration settings for worker counts and cache paths, implements multiprocessing worker functions for data slicing, and adds logic to store tokenized datasets in an arrow-backed format. The PR also includes comprehensive tests for the new parallel processing and caching functionality. Feedback points out concurrency risks in the cache-saving process on shared filesystems and suggests a more robust approach for generating cache keys to prevent potential collisions.
```python
temp_path = cache_path + ".tmp"
# Clean up any stale temp dir from an interrupted prior run.
if os.path.isdir(temp_path):
    shutil.rmtree(temp_path)
dataset.save_to_disk(temp_path)
# If a previous cache exists at the final path, drop it before
# rename so the swap is the only visible state change.
if os.path.isdir(cache_path):
    shutil.rmtree(cache_path)
os.rename(temp_path, cache_path)
```
The current cache saving logic is not safe for multi-node training on shared filesystems (NFS).

1. `temp_path` is not unique across processes or nodes, which can lead to data corruption if multiple workers attempt to write to the same cache simultaneously.
2. The `shutil.rmtree(temp_path)` at the start of the write process can delete a directory that another process is currently using.
3. The `shutil.rmtree(cache_path)` followed by `os.rename` is not atomic for directories on many systems, creating a race condition where one process might delete the successful output of another.

A safer approach is to use a unique temporary directory and an atomic `os.rename`, ensuring that concurrent writes do not interfere with each other.
```python
temp_path = f"{cache_path}.tmp.{random.getrandbits(64):x}"
try:
    dataset.save_to_disk(temp_path)
    if os.path.isdir(cache_path):
        shutil.rmtree(cache_path, ignore_errors=True)
    os.rename(temp_path, cache_path)
finally:
    if os.path.isdir(temp_path):
        shutil.rmtree(temp_path, ignore_errors=True)
```
This is a rare scenario: it only arises with concurrent training runs that use the same model and dataset name (i.e., the same cache key).
feat: Optimize SFT dataloader with slice-based tokenization and caching
Summary
Optimizes the SFT dataloader with two major improvements:
- Slice-based parallel tokenization
- Tokenized dataset caching
Performance Improvements
Slice-Based Parallel Tokenization
Key benefits:
Tokenized Dataset Caching
Key benefits:
- `cache_dir`, `force_recache`, `disable_cache`

Changes
1. Slice-Based Tokenization (`skyrl/train/sft_trainer.py`)
- `_tokenize_chat_slice_worker()` - worker for chat format with slice loading
- `_tokenize_alpaca_slice_worker()` - worker for Alpaca format with slice loading
- `_parse_dataset_split()` - parses split strings like `"train[:100000]"` into base split + indices
- `_load_and_tokenize()` - uses slice-based loading when `num_workers > 0`

How it works:
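A minimal sketch of the split parsing and slice assignment listed above, not the PR's actual code. The names `parse_split` and `worker_ranges` are hypothetical, and the sketch assumes only absolute-index slices such as `train[:100000]` or `train[10:20]`; percentage slices are rejected.

```python
import re
from typing import List, Optional, Tuple

# Hypothetical stand-in for _parse_dataset_split(); the real helper may differ.
_SPLIT_RE = re.compile(r"^(?P<base>[^\[\]]+)(?:\[(?P<start>\d*):(?P<end>\d*)\])?$")


def parse_split(split: str) -> Tuple[str, Optional[int], Optional[int]]:
    """Split 'train[:100000]' into ('train', None, 100000)."""
    m = _SPLIT_RE.match(split.strip())
    if m is None:
        raise ValueError(f"Unsupported split string: {split!r}")
    start = int(m.group("start")) if m.group("start") else None
    end = int(m.group("end")) if m.group("end") else None
    return m.group("base"), start, end


def worker_ranges(start: int, end: int, num_workers: int) -> List[Tuple[int, int]]:
    """Divide [start, end) into contiguous, near-equal index ranges."""
    chunk, extra = divmod(end - start, num_workers)
    ranges = []
    lo = start
    for i in range(num_workers):
        # The first `extra` workers take one additional row each.
        hi = lo + chunk + (1 if i < extra else 0)
        if hi > lo:  # skip empty ranges when there are more workers than rows
            ranges.append((lo, hi))
        lo = hi
    return ranges
```

Under this scheme each worker receives only a pair of integers and can load its own rows, for example by passing a split string of the form `f"{base}[{lo}:{hi}]"` to `datasets.load_dataset`, so no dataset rows need to be pickled and sent to the worker.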
2. Dataset Caching (`skyrl/train/sft_trainer.py`, `skyrl/train/config/sft_config.py`)
- `_compute_cache_key()` - deterministic hash of dataset + tokenization params
- `_get_cache_path()`, `_load_from_cache()`, `_save_to_cache()` - cache I/O
- `_load_and_tokenize()` - checks cache before tokenizing
- `cache_dir`, `force_recache`, `disable_cache`

Cache key includes:
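As a rough illustration of a deterministic key over dataset and tokenization parameters, the sketch below hashes a canonical JSON serialization. This is an assumption about one reasonable approach, not the PR's `_compute_cache_key()`; the function name and the field names in the usage example are hypothetical.

```python
import hashlib
import json


def compute_cache_key(params: dict) -> str:
    """Return a short, stable hex digest for a dict of cache-relevant settings."""
    # sort_keys makes the serialization independent of insertion order;
    # JSON quoting keeps field boundaries unambiguous, unlike plain string
    # concatenation where ("ab", "c") and ("a", "bc") could collide.
    payload = json.dumps(params, sort_keys=True, separators=(",", ":"), default=str)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]


# Hypothetical fields, for illustration only.
key = compute_cache_key(
    {"dataset_name": "my_dataset", "split": "train[:100000]", "tokenizer": "my_tokenizer", "max_length": 4096}
)
```

Any setting that changes the tokenized output needs to be part of `params`, otherwise a stale cache entry could be reused silently.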
NFS-safe atomic writes:
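A minimal sketch of a write-then-rename pattern with a unique temporary directory, in the spirit of the review suggestion. It is not the PR's `_save_to_cache()`: the name `save_dir_atomically` and the `write_fn` callback are hypothetical, and unlike the PR it keeps an existing cache rather than replacing it. Rename semantics on a particular NFS deployment should be verified separately.

```python
import os
import shutil
import tempfile
from typing import Callable


def save_dir_atomically(final_path: str, write_fn: Callable[[str], None]) -> bool:
    """Write into a unique temp dir, then rename it into place.

    Returns True if this call published final_path, False if a non-empty
    final_path already existed (another writer finished first).
    """
    parent = os.path.dirname(os.path.abspath(final_path))
    os.makedirs(parent, exist_ok=True)
    # The temp dir lives next to the target so the rename stays on one filesystem.
    tmp = tempfile.mkdtemp(prefix=os.path.basename(final_path) + ".tmp.", dir=parent)
    try:
        write_fn(tmp)  # e.g. lambda d: dataset.save_to_disk(d)
        try:
            os.rename(tmp, final_path)
            return True
        except OSError:
            # Renaming onto an existing non-empty directory fails; keep the
            # existing cache and discard our copy in the finally block.
            return False
    finally:
        if os.path.isdir(tmp):
            shutil.rmtree(tmp, ignore_errors=True)
```

Because no writer ever deletes the final directory or another writer's temporary directory, concurrent runs with the same cache key cannot remove each other's output; the cost is that a forced re-cache would need a separate, explicit removal step.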
Configuration
Slice-Based Tokenization
Caching
Use Cases
Testing
Tested with:
Test scripts available in commit history.
Breaking Changes
None - all changes are backward compatible. Default behavior unchanged.
Notes
Related Issues
Addresses performance concerns with SFT dataset loading at scale.