feat: add max_training_steps config for CI smoke tests #1690
Open
dinhxuanvu wants to merge 4 commits into
Add an optional max_training_steps parameter to TrainerConfig (RL) and SFTConfig (SFT) that caps training early regardless of epochs or dataset size. This enables CI smoke tests to exercise the full pipeline (data loading, model init, forward/backward pass) while limiting wall-clock time to a few steps.

Applied consistently across all trainer paths:
- Sync RL trainer (RayPPOTrainer)
- Fully async RL trainer (FullyAsyncRayPPOTrainer)
- SFT trainer (SFTTrainer)

Unlike dummy_run_max_steps (which fabricates synthetic data for benchmarking), max_training_steps runs the real data path end-to-end and simply exits early, making it suitable for config validation.
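As a rough illustration of the cap described in this commit message, the sketch below resolves a step budget from epochs and then applies the limit. The function name resolve_num_steps and the steps_per_epoch argument are hypothetical and are not taken from the PR; the actual trainers may compute their step counts differently.

```python
from typing import Optional


def resolve_num_steps(
    num_epochs: int,
    steps_per_epoch: int,
    max_training_steps: Optional[int] = None,
) -> int:
    """Return the number of steps to train for (hypothetical helper)."""
    # Step count implied by the epoch setting and dataset size.
    planned = num_epochs * steps_per_epoch
    if max_training_steps is None:
        # No cap configured: train for the full planned duration.
        return planned
    # The cap only ever shortens training; it never extends it.
    return min(planned, max_training_steps)


if __name__ == "__main__":
    print(resolve_num_steps(num_epochs=3, steps_per_epoch=100, max_training_steps=2))
```

Using min means a cap larger than the planned duration leaves the run unchanged, which matches the description of the option as an early exit only.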
Both RL (validate_cfg) and SFT (validate_sft_cfg) now raise when max_training_steps is set to 0 or a negative value, preventing nonsensical behavior like negative num_steps or global_step.
Change assert to ValueError in validate_cfg for consistency with the SFT side and so the check still runs under python -O, which strips assert statements. Add test coverage for RL validate_cfg rejecting zero/negative max_training_steps.
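A minimal sketch of the kind of check these two commits describe, assuming None means "no cap". The helper name check_max_training_steps is hypothetical; the PR places the real checks inside validate_cfg and validate_sft_cfg, whose bodies are not shown here.

```python
from typing import Optional


def check_max_training_steps(max_training_steps: Optional[int]) -> None:
    """Reject zero or negative caps (hypothetical stand-in for the real validators)."""
    if max_training_steps is None:
        # Assumption: an unset value disables the cap entirely.
        return
    if max_training_steps <= 0:
        # Raised explicitly rather than via assert, so the check is not
        # removed when Python runs with the -O flag.
        raise ValueError(
            f"max_training_steps must be a positive integer, got {max_training_steps}"
        )


if __name__ == "__main__":
    check_max_training_steps(2)  # accepted
    try:
        check_max_training_steps(0)
    except ValueError as exc:
        print(f"rejected: {exc}")
```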
Code Review
This pull request introduces a max_training_steps configuration option across RL (sync/async) and SFT trainers, enabling early termination of training regardless of the number of epochs or dataset size. The changes include updates to configuration classes, validation logic, and trainer implementations, along with a new test suite. Review feedback identifies an off-by-one error in the early exit logic of both the synchronous and fully asynchronous trainers, where global_step is incremented before the limit check; this could lead to incorrect checkpointing and skipped steps upon resumption.
global_step is incremented before the max_training_steps check, so when the loop breaks it is max_training_steps + 1. Clamp it back to max_training_steps before exiting to ensure checkpoint metadata records the correct final step.
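The off-by-one described in this review comment can be reproduced with a toy loop. This is only a sketch under stated assumptions: run_training, its 1-indexed global_step, and the clamp flag are hypothetical and stand in for the real trainer loops, which are not shown in this conversation.

```python
from typing import Optional, Tuple


def run_training(
    num_batches: int,
    max_training_steps: Optional[int] = None,
    clamp: bool = True,
) -> Tuple[int, int]:
    """Toy loop returning (steps actually run, final global_step)."""
    global_step = 1  # assumption: counter names the step about to run
    steps_run = 0
    for _ in range(num_batches):
        steps_run += 1  # stand-in for one real training step
        global_step += 1  # incremented before the limit check
        if max_training_steps is not None and global_step > max_training_steps:
            if clamp:
                # Record the last completed step, not the next one.
                global_step = max_training_steps
            break
    return steps_run, global_step


if __name__ == "__main__":
    print(run_training(10, max_training_steps=2, clamp=False))  # (2, 3)
    print(run_training(10, max_training_steps=2, clamp=True))   # (2, 2)
```

In both calls two steps are executed; only the recorded counter differs, which is what would end up in checkpoint metadata.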
Summary
Adds a max_training_steps config field that caps training after N steps regardless of epochs or dataset size. This enables fast CI smoke tests that validate full training pipelines end-to-end without running to completion.
- Caps num_steps after resolution from num_epochs and propagates to the LR scheduler
- Caps total_training_steps at initialization and adds an early exit after step completion
Motivation
Running full training configs in nightly CI is prohibitively expensive. max_training_steps=2 exercises the core pipeline (data loading, model init, forward/backward, and logging) in minutes rather than hours or days, validating that the full stack initializes and trains without errors. Unlike dummy_run_max_steps (SFT-only, uses fabricated data), this uses real data and works across RL and SFT configs.
Test plan
- transformers unavailability)
- max_training_steps=2 (requires cluster)