From 451638b4a7ca812fb8ca0f9eb7544352c982ee95 Mon Sep 17 00:00:00 2001 From: Vu Dinh Date: Tue, 19 May 2026 10:59:50 -0400 Subject: [PATCH 1/4] feat: add max_training_steps config for CI smoke tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add an optional max_training_steps parameter to TrainerConfig (RL) and SFTConfig (SFT) that caps training early regardless of epochs or dataset size. This enables CI smoke tests to exercise the full pipeline (data loading, model init, forward/backward pass) while limiting wall-clock time to a few steps. Applied consistently across all trainer paths: - Sync RL trainer (RayPPOTrainer) - Fully async RL trainer (FullyAsyncRayPPOTrainer) - SFT trainer (SFTTrainer) Unlike dummy_run_max_steps (which fabricates synthetic data for benchmarking), max_training_steps runs the real data path end-to-end and simply exits early — making it suitable for config validation. --- skyrl/train/config/config.py | 3 + skyrl/train/config/sft_config.py | 5 + skyrl/train/fully_async_trainer.py | 17 ++ skyrl/train/sft_trainer.py | 14 +- skyrl/train/trainer.py | 14 + tests/train/test_max_training_steps.py | 386 +++++++++++++++++++++++++ 6 files changed, 437 insertions(+), 2 deletions(-) create mode 100644 tests/train/test_max_training_steps.py diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 04746ea87d..c6ec861d4d 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -624,6 +624,9 @@ class TrainerConfig(BaseConfig): """Path for exported artifacts (HF models, debug dumps, etc.).""" bf16: bool = True epochs: int = 1 + max_training_steps: Optional[int] = None + """If set, stop training after this many steps regardless of epochs or dataset size. + Useful for CI smoke tests and quick validation runs.""" update_epochs_per_batch: int = 1 """Number of gradient update passes over each training batch.""" train_batch_size: int = 1024 diff --git a/skyrl/train/config/sft_config.py b/skyrl/train/config/sft_config.py index 6264dd80bf..3cc3fd4000 100644 --- a/skyrl/train/config/sft_config.py +++ b/skyrl/train/config/sft_config.py @@ -184,6 +184,11 @@ def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SFTConfig": dummy_run_full_ctx: bool = False # Skip real data; fabricate full-context sequences dummy_run_max_steps: int = 5 # Number of steps to run in dummy mode + # ---- CI / smoke test support ---- + max_training_steps: Optional[int] = None + """If set, stop training after this many steps regardless of num_steps or num_epochs. + Useful for CI smoke tests and quick validation runs.""" + # --------------------------------------------------------------------------- # Bridge: SFTConfig -> SkyRLTrainConfig diff --git a/skyrl/train/fully_async_trainer.py b/skyrl/train/fully_async_trainer.py index 23078c7c1c..b7e8e70670 100644 --- a/skyrl/train/fully_async_trainer.py +++ b/skyrl/train/fully_async_trainer.py @@ -319,6 +319,8 @@ def _build_train_dataloader_and_compute_training_steps(self): self.train_dataloader = build_dataloader(self.cfg, self.train_dataset, is_train=True, is_fully_async=True) self.num_steps_per_epoch = len(self.train_dataloader) // self.mini_batch_size self.total_training_steps = self.num_steps_per_epoch * self.cfg.trainer.epochs + if self.cfg.trainer.max_training_steps is not None: + self.total_training_steps = min(self.total_training_steps, self.cfg.trainer.max_training_steps) logger.info(f"Length of train_dataloader: {len(self.train_dataloader)}") logger.info(f"Number of steps per epoch: {self.num_steps_per_epoch}") logger.info(f"Total training steps: {self.total_training_steps}") @@ -368,6 +370,7 @@ async def train(self): pbar = tqdm(total=self.total_training_steps, initial=self.global_step, desc="Training Step Progress") start_epoch = self.global_step // self.num_steps_per_epoch self.global_step += 1 # start training at global_step 1 + stop_training = False for epoch in range(start_epoch, self.cfg.trainer.epochs): # 0. Per-epoch prologue. Note that we do not do any cross-epoch asynchrony here. @@ -457,6 +460,17 @@ async def train(self): self.all_timings = {} self.global_step += 1 + if ( + self.cfg.trainer.max_training_steps is not None + and self.global_step > self.cfg.trainer.max_training_steps + ): + logger.info(f"Reached max_training_steps={self.cfg.trainer.max_training_steps}, stopping early.") + for t in generator_tasks: + t.cancel() + await asyncio.gather(*generator_tasks, return_exceptions=True) + stop_training = True + break + # 8. Notify generation workers that the capacity has increased, unblocking them. await self._staleness_manager.notify_capacity_change(self.global_step) steps_completed_in_epoch = (self.global_step - 1) % self.num_steps_per_epoch @@ -470,6 +484,9 @@ async def train(self): f"{actual_consumed_in_epoch} != {expected_consumed_in_epoch}" ) + if stop_training: + break + # 9. Per-epoch epilogue. if self.cfg.trainer.update_ref_every_epoch and self.ref_model is not None: with Timer("update_ref_with_policy", self.all_timings): diff --git a/skyrl/train/sft_trainer.py b/skyrl/train/sft_trainer.py index 52faa3acf5..27c411dcae 100644 --- a/skyrl/train/sft_trainer.py +++ b/skyrl/train/sft_trainer.py @@ -524,8 +524,14 @@ def _init_workers(self): num_training_steps = ( self.sft_cfg.dummy_run_max_steps if self.sft_cfg.dummy_run_full_ctx else self.sft_cfg.num_steps ) - # num_steps may be None when num_epochs is used; the worker will use its - # default (large value) for the LR scheduler in that case. + if self.sft_cfg.max_training_steps is not None: + num_training_steps = ( + self.sft_cfg.max_training_steps + if num_training_steps is None + else min(num_training_steps, self.sft_cfg.max_training_steps) + ) + # num_steps may be None when num_epochs is used; without an explicit cap, + # the worker will use its default large value for the LR scheduler. ray.get( actor_group.async_init_model( self.sft_cfg.model.path, @@ -991,6 +997,10 @@ def train(self): f"ceil({len(tokenized)} / {batch_size}) * {self.sft_cfg.num_epochs} = {num_steps} steps" ) + if self.sft_cfg.max_training_steps is not None: + num_steps = min(num_steps, self.sft_cfg.max_training_steps) + logger.info(f"Capping training at max_training_steps={self.sft_cfg.max_training_steps}") + # Early validation: dataset must have at least batch_size examples if len(tokenized) < batch_size: raise ValueError( diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 99da9a915a..d39314aeaa 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -154,6 +154,8 @@ def _build_train_dataloader_and_compute_training_steps(self): if self.train_dataset is not None: self.train_dataloader = build_dataloader(self.cfg, self.train_dataset, is_train=True) self.total_training_steps = len(self.train_dataloader) * self.cfg.trainer.epochs + if self.cfg.trainer.max_training_steps is not None: + self.total_training_steps = min(self.total_training_steps, self.cfg.trainer.max_training_steps) @torch.no_grad() async def eval(self) -> Dict[str, float]: @@ -215,6 +217,7 @@ async def train(self): pbar = tqdm(total=self.total_training_steps, initial=self.global_step, desc="Training Batches Processed") start_epoch = self.global_step // len(self.train_dataloader) self.global_step += 1 # start training at global_step 1 + stop_training = False for epoch in range(start_epoch, self.cfg.trainer.epochs): for _, rand_prompts in enumerate(self.train_dataloader): with Timer("step", self.all_timings): @@ -353,8 +356,19 @@ async def train(self): self.global_step += 1 + if ( + self.cfg.trainer.max_training_steps is not None + and self.global_step > self.cfg.trainer.max_training_steps + ): + logger.info(f"Reached max_training_steps={self.cfg.trainer.max_training_steps}, stopping early.") + stop_training = True + break + del training_input, generator_output + if stop_training: + break + pbar.close() if self.colocate_all: await self.inference_engine_client.sleep() diff --git a/tests/train/test_max_training_steps.py b/tests/train/test_max_training_steps.py new file mode 100644 index 0000000000..967b8a136a --- /dev/null +++ b/tests/train/test_max_training_steps.py @@ -0,0 +1,386 @@ +""" +Tests for the max_training_steps feature across all trainer paths. + +Verifies that setting max_training_steps correctly caps training +regardless of epochs or dataset size, for RL (sync/async) and SFT trainers. + +uv run --extra dev --extra skyrl-train pytest tests/train/test_max_training_steps.py -v +""" + +import importlib.util +import sys +from math import ceil +from types import ModuleType, SimpleNamespace + +import pytest +from omegaconf import OmegaConf + +from skyrl.train.config.config import SkyRLTrainConfig, TrainerConfig +from skyrl.train.config.sft_config import SFTConfig, validate_sft_cfg + +requires_transformers = pytest.mark.skipif( + importlib.util.find_spec("transformers") is None, + reason="transformers not available (Linux-only dependency)", +) + + +class _FakeTransformers(ModuleType): + def __getattr__(self, name): + value = type(name, (), {}) + setattr(self, name, value) + return value + + +def _install_fake_transformers(monkeypatch): + module = _FakeTransformers("transformers") + module.__file__ = "fake_transformers.py" + module.__path__ = [] + monkeypatch.setitem(sys.modules, "transformers", module) + flash_attention_utils = ModuleType("transformers.modeling_flash_attention_utils") + flash_attention_utils._flash_attention_forward = lambda *args, **kwargs: None + monkeypatch.setitem(sys.modules, "transformers.modeling_flash_attention_utils", flash_attention_utils) + + +def _install_lightweight_worker_modules(monkeypatch): + worker_module = ModuleType("skyrl.backends.skyrl_train.workers.worker") + worker_module.PPORayActorGroup = object + monkeypatch.setitem(sys.modules, "skyrl.backends.skyrl_train.workers.worker", worker_module) + + worker_dispatch_module = ModuleType("skyrl.backends.skyrl_train.workers.worker_dispatch") + worker_dispatch_module.WorkerDispatch = object + monkeypatch.setitem(sys.modules, "skyrl.backends.skyrl_train.workers.worker_dispatch", worker_dispatch_module) + + +# --------------------------------------------------------------------------- +# Config-level tests: TrainerConfig (RL) +# --------------------------------------------------------------------------- + + +class TestTrainerConfigMaxSteps: + """TrainerConfig.max_training_steps field behavior.""" + + def test_default_is_none(self): + cfg = TrainerConfig() + assert cfg.max_training_steps is None + + def test_set_via_constructor(self): + cfg = TrainerConfig(max_training_steps=10) + assert cfg.max_training_steps == 10 + + @requires_transformers + def test_set_via_from_dict_config(self): + cfg_dict = OmegaConf.create({"trainer": {"max_training_steps": 5}}) + cfg = SkyRLTrainConfig.from_dict_config(cfg_dict) + assert cfg.trainer.max_training_steps == 5 + + @requires_transformers + def test_cli_override(self): + cfg = SkyRLTrainConfig.from_cli_overrides(["trainer.max_training_steps=3"]) + assert cfg.trainer.max_training_steps == 3 + + @requires_transformers + def test_none_preserved_when_unset(self): + cfg_dict = OmegaConf.create({"trainer": {"epochs": 5}}) + cfg = SkyRLTrainConfig.from_dict_config(cfg_dict) + assert cfg.trainer.max_training_steps is None + + +# --------------------------------------------------------------------------- +# Config-level tests: SFTConfig +# --------------------------------------------------------------------------- + + +class TestSFTConfigMaxSteps: + """SFTConfig.max_training_steps field behavior.""" + + def test_default_is_none(self): + cfg = SFTConfig() + assert cfg.max_training_steps is None + + def test_set_via_constructor(self): + cfg = SFTConfig(max_training_steps=7) + assert cfg.max_training_steps == 7 + + def test_set_via_cli_overrides(self): + cfg = SFTConfig.from_cli_overrides(["max_training_steps=10"]) + assert cfg.max_training_steps == 10 + + def test_set_via_dict_overrides(self): + cfg = SFTConfig.from_cli_overrides({"max_training_steps": 12}) + assert cfg.max_training_steps == 12 + + def test_validation_passes_with_max_training_steps(self): + cfg = SFTConfig(max_training_steps=3) + validate_sft_cfg(cfg) + + def test_validation_passes_without_max_training_steps(self): + cfg = SFTConfig() + validate_sft_cfg(cfg) + + +# --------------------------------------------------------------------------- +# RL Trainer: total_training_steps capping logic (shared by sync and async) +# --------------------------------------------------------------------------- + + +class TestRLTrainerStepsCapping: + """Verify capping logic in _build_train_dataloader_and_compute_training_steps. + + Both sync and async RL trainers use the same pattern: + total = dataloader_steps * epochs + if max_training_steps: total = min(total, max_training_steps) + """ + + @pytest.mark.parametrize( + "dl_len,epochs,max_steps,expected", + [ + (50, 2, None, 100), + (50, 2, 10, 10), + (50, 2, 1000, 100), + (50, 2, 100, 100), + (50, 2, 1, 1), + ], + ids=["no_cap", "caps_smaller", "no_effect_larger", "boundary_equal", "single_step"], + ) + def test_capping(self, dl_len, epochs, max_steps, expected): + total = dl_len * epochs + if max_steps is not None: + total = min(total, max_steps) + assert total == expected + + +# --------------------------------------------------------------------------- +# RL Trainer: early exit condition (shared by sync and async) +# --------------------------------------------------------------------------- + + +class TestRLTrainerEarlyExit: + """Verify early exit condition: global_step > max_training_steps. + + Both sync and async trainers use the same check after incrementing global_step. + """ + + @pytest.mark.parametrize( + "global_step,max_steps,should_exit", + [ + (6, 5, True), + (5, 5, False), + (3, 5, False), + (9999, None, False), + (1, 1, False), + (2, 1, True), + ], + ids=["exceeded", "at_boundary", "below", "none_never_exits", "boundary_1", "exit_after_1"], + ) + def test_exit_condition(self, global_step, max_steps, should_exit): + exits = max_steps is not None and global_step > max_steps + assert exits is should_exit + + +# --------------------------------------------------------------------------- +# SFT Trainer: num_steps capping logic +# --------------------------------------------------------------------------- + + +class TestSFTTrainerMaxStepsCapping: + """Verify the capping logic used in SFTTrainer.train().""" + + @staticmethod + def _resolve_num_steps(num_steps=None, num_epochs=None, dataset_len=100, batch_size=4, max_training_steps=None): + if num_steps is not None: + resolved = num_steps + else: + resolved = ceil(dataset_len / batch_size) * num_epochs + if max_training_steps is not None: + resolved = min(resolved, max_training_steps) + return resolved + + @pytest.mark.parametrize( + "kwargs,expected", + [ + (dict(num_epochs=10, dataset_len=100, batch_size=4, max_training_steps=3), 3), + (dict(num_steps=50, max_training_steps=5), 5), + (dict(num_steps=50, max_training_steps=None), 50), + (dict(num_steps=10, max_training_steps=1000), 10), + (dict(num_epochs=2, dataset_len=100, batch_size=4, max_training_steps=None), 50), + (dict(num_epochs=100, dataset_len=1000, batch_size=1, max_training_steps=1), 1), + ], + ids=[ + "caps_epoch_derived", + "caps_explicit_num_steps", + "no_cap_when_none", + "no_effect_larger", + "epoch_resolution_no_cap", + "single_step", + ], + ) + def test_capping(self, kwargs, expected): + assert self._resolve_num_steps(**kwargs) == expected + + def test_worker_initialization_uses_capped_num_training_steps(self, monkeypatch): + _install_fake_transformers(monkeypatch) + _install_lightweight_worker_modules(monkeypatch) + from skyrl.train import sft_trainer as sft_module + + captured = {} + + class FakeActorGroup: + def __init__(self, *args, **kwargs): + pass + + def async_init_model(self, model_path, num_training_steps=None): + captured["model_path"] = model_path + captured["num_training_steps"] = num_training_steps + return None + + def async_run_ray_method(self, *args, **kwargs): + return None + + fake_worker_module = SimpleNamespace(PolicyWorker=object) + monkeypatch.setitem( + sys.modules, + "skyrl.backends.skyrl_train.workers.megatron.megatron_worker", + fake_worker_module, + ) + monkeypatch.setattr(sft_module, "placement_group", lambda *args, **kwargs: object()) + monkeypatch.setattr(sft_module, "get_ray_pg_ready_with_timeout", lambda *args, **kwargs: None) + monkeypatch.setattr(sft_module, "ResolvedPlacementGroup", lambda pg: pg) + monkeypatch.setattr(sft_module, "PPORayActorGroup", FakeActorGroup) + monkeypatch.setattr(sft_module, "WorkerDispatch", lambda *args, **kwargs: object()) + monkeypatch.setattr(sft_module.ray, "get", lambda result: result) + + trainer = object.__new__(sft_module.SFTTrainer) + trainer.sft_cfg = SFTConfig(num_steps=1000, max_training_steps=5) + trainer.sft_cfg.placement.num_gpus_per_node = 1 + trainer.cfg = SimpleNamespace( + trainer=SimpleNamespace( + policy=SimpleNamespace(sequence_parallel_size=1, record_memory=False), + ) + ) + trainer.tokenizer = SimpleNamespace(pad_token_id=0) + + trainer._init_workers() + + assert captured["num_training_steps"] == 5 + + +class TestRLTrainerFinalization: + @pytest.mark.asyncio + async def test_sync_trainer_max_steps_runs_finalization(self, monkeypatch): + _install_fake_transformers(monkeypatch) + _install_lightweight_worker_modules(monkeypatch) + from skyrl.train import trainer as trainer_module + + events = [] + + class FakeTrainingInput(dict): + def __init__(self): + super().__init__({"rewards": [1.0]}) + self.metadata = {"uids": ["uid-0"]} + + class FakeTracker: + def log(self, *args, **kwargs): + pass + + def finish(self): + events.append("tracker_finish") + + class FakeDispatch: + async def save_weights_for_sampler(self): + pass + + class FakeInferenceEngineClient: + async def sleep(self): + events.append("sleep") + + async def fake_generate(generator_input): + return {"response_ids": [[1]], "rewards": [1.0]} + + monkeypatch.setattr( + trainer_module, + "prepare_generator_input", + lambda *args, **kwargs: ({"prompts": ["prompt"]}, ["uid-0"]), + ) + monkeypatch.setattr(trainer_module, "get_sampling_params_for_backend", lambda *args, **kwargs: {}) + + trainer = object.__new__(trainer_module.RayPPOTrainer) + trainer.cfg = SimpleNamespace( + trainer=SimpleNamespace( + epochs=2, + max_training_steps=1, + eval_interval=0, + eval_before_train=False, + algorithm=SimpleNamespace(use_kl_in_reward=False, dynamic_sampling=SimpleNamespace(type=None)), + log_example_interval=0, + dump_data_batch=False, + ckpt_interval=1, + hf_save_interval=1, + update_ref_every_epoch=False, + ), + generator=SimpleNamespace( + n_samples_per_prompt=1, + step_wise_trajectories=False, + inference_engine=SimpleNamespace(backend="vllm"), + sampling_params=SimpleNamespace(), + ), + environment=SimpleNamespace(env_class="gsm8k"), + ) + trainer.colocate_all = True + trainer.tracker = FakeTracker() + trainer.tokenizer = SimpleNamespace(decode=lambda ids: "decoded") + trainer.train_dataloader = [["prompt-0"], ["prompt-1"]] + trainer.total_training_steps = 1 + trainer.resume_mode = trainer_module.ResumeMode.NONE + trainer.all_metrics = {} + trainer.all_timings = {} + trainer.global_step = 0 + trainer._vllm_metrics_scraper = None + trainer.dispatch = FakeDispatch() + trainer.inference_engine_client = FakeInferenceEngineClient() + trainer.ref_model = None + trainer.init_weight_sync_state = lambda: None + trainer._remove_tail_data = lambda rand_prompts: rand_prompts + trainer.generate = fake_generate + trainer.postprocess_generator_output = lambda generator_output, uids: (generator_output, uids) + trainer.convert_to_training_input = lambda generator_output, uids: FakeTrainingInput() + trainer.fwd_logprobs_values_reward = lambda training_input: training_input + trainer.compute_advantages_and_returns = lambda training_input: training_input + trainer.train_critic_and_policy = lambda training_input: {"loss": 0.0} + trainer.save_checkpoints = lambda: events.append("save_checkpoints") + trainer.save_models = lambda: events.append("save_models") + + await trainer.train() + + assert events[-4:] == ["sleep", "save_checkpoints", "save_models", "tracker_finish"] + + +# --------------------------------------------------------------------------- +# Integration: config round-trips +# --------------------------------------------------------------------------- + + +class TestConfigRoundTrips: + """max_training_steps survives config construction paths.""" + + @requires_transformers + def test_rl_config_from_dict_config_roundtrip(self): + cfg_dict = OmegaConf.create({"trainer": {"max_training_steps": 42, "epochs": 5, "train_batch_size": 8}}) + cfg = SkyRLTrainConfig.from_dict_config(cfg_dict) + assert cfg.trainer.max_training_steps == 42 + assert cfg.trainer.epochs == 5 + + def test_sft_config_from_dict_overrides(self): + cfg = SFTConfig.from_cli_overrides({"max_training_steps": 7, "num_epochs": 3}) + assert cfg.max_training_steps == 7 + assert cfg.num_epochs == 3 + + @requires_transformers + def test_rl_config_max_steps_coexists_with_epochs(self): + cfg = SkyRLTrainConfig.from_cli_overrides(["trainer.max_training_steps=10", "trainer.epochs=100"]) + assert cfg.trainer.max_training_steps == 10 + assert cfg.trainer.epochs == 100 + + def test_sft_max_steps_coexists_with_num_steps(self): + cfg = SFTConfig.from_cli_overrides(["max_training_steps=5", "num_steps=100"]) + assert cfg.max_training_steps == 5 + assert cfg.num_steps == 100 From 3bec795319c26e2542f6cf5873751b2ab378677d Mon Sep 17 00:00:00 2001 From: Vu Dinh Date: Tue, 19 May 2026 12:51:33 -0400 Subject: [PATCH 2/4] fix: reject zero/negative max_training_steps in validation Both RL (validate_cfg) and SFT (validate_sft_cfg) now raise when max_training_steps is set to 0 or a negative value, preventing nonsensical behavior like negative num_steps or global_step. --- skyrl/train/config/sft_config.py | 2 ++ skyrl/train/utils/utils.py | 5 +++++ tests/train/test_max_training_steps.py | 10 ++++++++++ 3 files changed, 17 insertions(+) diff --git a/skyrl/train/config/sft_config.py b/skyrl/train/config/sft_config.py index 3cc3fd4000..6bead10a3b 100644 --- a/skyrl/train/config/sft_config.py +++ b/skyrl/train/config/sft_config.py @@ -230,6 +230,8 @@ def validate_sft_cfg(cfg: SFTConfig) -> None: raise ValueError("model.path must be set") if cfg.dummy_run_full_ctx and cfg.dummy_run_max_steps <= 0: raise ValueError(f"dummy_run_max_steps must be > 0, got {cfg.dummy_run_max_steps}") + if cfg.max_training_steps is not None and cfg.max_training_steps <= 0: + raise ValueError(f"max_training_steps must be > 0, got {cfg.max_training_steps}") # Eval config if cfg.eval_interval < 0: diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index 4ac09c3eb3..7ce758c122 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -227,6 +227,11 @@ def validate_cfg(cfg: SkyRLTrainConfig): ) cfg.trainer.strategy = "fsdp" + if cfg.trainer.max_training_steps is not None: + assert ( + cfg.trainer.max_training_steps > 0 + ), f"max_training_steps must be > 0, got {cfg.trainer.max_training_steps}" + # Validate generation config separately validate_generator_cfg(cfg) from skyrl.backends.skyrl_train.utils.ppo_utils import ( diff --git a/tests/train/test_max_training_steps.py b/tests/train/test_max_training_steps.py index 967b8a136a..d19968f8ae 100644 --- a/tests/train/test_max_training_steps.py +++ b/tests/train/test_max_training_steps.py @@ -117,6 +117,16 @@ def test_validation_passes_without_max_training_steps(self): cfg = SFTConfig() validate_sft_cfg(cfg) + def test_validation_rejects_zero(self): + cfg = SFTConfig(max_training_steps=0) + with pytest.raises(ValueError, match="max_training_steps must be > 0"): + validate_sft_cfg(cfg) + + def test_validation_rejects_negative(self): + cfg = SFTConfig(max_training_steps=-1) + with pytest.raises(ValueError, match="max_training_steps must be > 0"): + validate_sft_cfg(cfg) + # --------------------------------------------------------------------------- # RL Trainer: total_training_steps capping logic (shared by sync and async) From ebc84a5b0bc88fc7ea312476aadf9f35112356d5 Mon Sep 17 00:00:00 2001 From: Vu Dinh Date: Tue, 19 May 2026 12:57:23 -0400 Subject: [PATCH 3/4] fix: use ValueError for RL max_training_steps validation + add tests Change assert to ValueError in validate_cfg for consistency with SFT side and safety against -O execution. Add test coverage for RL validate_cfg rejecting zero/negative max_training_steps. --- skyrl/train/utils/utils.py | 5 ++--- tests/train/test_max_training_steps.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index 7ce758c122..68fccd3d31 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -228,9 +228,8 @@ def validate_cfg(cfg: SkyRLTrainConfig): cfg.trainer.strategy = "fsdp" if cfg.trainer.max_training_steps is not None: - assert ( - cfg.trainer.max_training_steps > 0 - ), f"max_training_steps must be > 0, got {cfg.trainer.max_training_steps}" + if cfg.trainer.max_training_steps <= 0: + raise ValueError(f"max_training_steps must be > 0, got {cfg.trainer.max_training_steps}") # Validate generation config separately validate_generator_cfg(cfg) diff --git a/tests/train/test_max_training_steps.py b/tests/train/test_max_training_steps.py index d19968f8ae..b55840dbd7 100644 --- a/tests/train/test_max_training_steps.py +++ b/tests/train/test_max_training_steps.py @@ -128,6 +128,24 @@ def test_validation_rejects_negative(self): validate_sft_cfg(cfg) +class TestRLValidateCfgMaxSteps: + """validate_cfg rejects max_training_steps <= 0 for RL configs.""" + + def test_rejects_zero(self): + from skyrl.train.utils.utils import validate_cfg + + cfg = SimpleNamespace(trainer=TrainerConfig(max_training_steps=0)) + with pytest.raises(ValueError, match="max_training_steps must be > 0"): + validate_cfg(cfg) + + def test_rejects_negative(self): + from skyrl.train.utils.utils import validate_cfg + + cfg = SimpleNamespace(trainer=TrainerConfig(max_training_steps=-5)) + with pytest.raises(ValueError, match="max_training_steps must be > 0"): + validate_cfg(cfg) + + # --------------------------------------------------------------------------- # RL Trainer: total_training_steps capping logic (shared by sync and async) # --------------------------------------------------------------------------- From 8fd72e2a8c1c672803196329b3b9238121789630 Mon Sep 17 00:00:00 2001 From: Vu Dinh Date: Tue, 19 May 2026 14:03:40 -0400 Subject: [PATCH 4/4] fix: correct global_step off-by-one on max_training_steps exit global_step is incremented before the max_training_steps check, so when the loop breaks it is max_training_steps + 1. Clamp it back to max_training_steps before exiting to ensure checkpoint metadata records the correct final step. --- skyrl/train/fully_async_trainer.py | 1 + skyrl/train/trainer.py | 1 + 2 files changed, 2 insertions(+) diff --git a/skyrl/train/fully_async_trainer.py b/skyrl/train/fully_async_trainer.py index b7e8e70670..f73d6cd8df 100644 --- a/skyrl/train/fully_async_trainer.py +++ b/skyrl/train/fully_async_trainer.py @@ -465,6 +465,7 @@ async def train(self): and self.global_step > self.cfg.trainer.max_training_steps ): logger.info(f"Reached max_training_steps={self.cfg.trainer.max_training_steps}, stopping early.") + self.global_step = self.cfg.trainer.max_training_steps for t in generator_tasks: t.cancel() await asyncio.gather(*generator_tasks, return_exceptions=True) diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index d39314aeaa..6cfb1d6447 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -361,6 +361,7 @@ async def train(self): and self.global_step > self.cfg.trainer.max_training_steps ): logger.info(f"Reached max_training_steps={self.cfg.trainer.max_training_steps}, stopping early.") + self.global_step = self.cfg.trainer.max_training_steps stop_training = True break