Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions skyrl/train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,9 @@ class TrainerConfig(BaseConfig):
"""Path for exported artifacts (HF models, debug dumps, etc.)."""
bf16: bool = True
epochs: int = 1
max_training_steps: Optional[int] = None
"""If set, stop training after this many steps regardless of epochs or dataset size.
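Acts as an upper bound only: training still ends earlier if epochs and dataset size yield fewer steps.
Useful for CI smoke tests and quick validation runs."""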
update_epochs_per_batch: int = 1
"""Number of gradient update passes over each training batch."""
train_batch_size: int = 1024
Expand Down
7 changes: 7 additions & 0 deletions skyrl/train/config/sft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SFTConfig":
dummy_run_full_ctx: bool = False # Skip real data; fabricate full-context sequences
dummy_run_max_steps: int = 5 # Number of steps to run in dummy mode

# ---- CI / smoke test support ----
max_training_steps: Optional[int] = None
"""If set, stop training after this many steps regardless of num_steps or num_epochs.
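Acts as an upper bound only: training still ends earlier if num_steps or num_epochs yields fewer steps.
Useful for CI smoke tests and quick validation runs."""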


# ---------------------------------------------------------------------------
# Bridge: SFTConfig -> SkyRLTrainConfig
Expand Down Expand Up @@ -225,6 +230,8 @@ def validate_sft_cfg(cfg: SFTConfig) -> None:
raise ValueError("model.path must be set")
if cfg.dummy_run_full_ctx and cfg.dummy_run_max_steps <= 0:
raise ValueError(f"dummy_run_max_steps must be > 0, got {cfg.dummy_run_max_steps}")
if cfg.max_training_steps is not None and cfg.max_training_steps <= 0:
raise ValueError(f"max_training_steps must be > 0, got {cfg.max_training_steps}")

# Eval config
if cfg.eval_interval < 0:
Expand Down
18 changes: 18 additions & 0 deletions skyrl/train/fully_async_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ def _build_train_dataloader_and_compute_training_steps(self):
self.train_dataloader = build_dataloader(self.cfg, self.train_dataset, is_train=True, is_fully_async=True)
self.num_steps_per_epoch = len(self.train_dataloader) // self.mini_batch_size
self.total_training_steps = self.num_steps_per_epoch * self.cfg.trainer.epochs
if self.cfg.trainer.max_training_steps is not None:
self.total_training_steps = min(self.total_training_steps, self.cfg.trainer.max_training_steps)
logger.info(f"Length of train_dataloader: {len(self.train_dataloader)}")
logger.info(f"Number of steps per epoch: {self.num_steps_per_epoch}")
logger.info(f"Total training steps: {self.total_training_steps}")
Expand Down Expand Up @@ -368,6 +370,7 @@ async def train(self):
pbar = tqdm(total=self.total_training_steps, initial=self.global_step, desc="Training Step Progress")
start_epoch = self.global_step // self.num_steps_per_epoch
self.global_step += 1 # start training at global_step 1
stop_training = False
for epoch in range(start_epoch, self.cfg.trainer.epochs):
# 0. Per-epoch prologue. Note that we do not do any cross-epoch asynchrony here.

Expand Down Expand Up @@ -457,6 +460,18 @@ async def train(self):
self.all_timings = {}
self.global_step += 1

if (
self.cfg.trainer.max_training_steps is not None
and self.global_step > self.cfg.trainer.max_training_steps
):
logger.info(f"Reached max_training_steps={self.cfg.trainer.max_training_steps}, stopping early.")
self.global_step = self.cfg.trainer.max_training_steps
for t in generator_tasks:
t.cancel()
await asyncio.gather(*generator_tasks, return_exceptions=True)
stop_training = True
break
Comment thread
dinhxuanvu marked this conversation as resolved.

# 8. Notify generation workers that the capacity has increased, unblocking them.
await self._staleness_manager.notify_capacity_change(self.global_step)
steps_completed_in_epoch = (self.global_step - 1) % self.num_steps_per_epoch
Expand All @@ -470,6 +485,9 @@ async def train(self):
f"{actual_consumed_in_epoch} != {expected_consumed_in_epoch}"
)

if stop_training:
break

# 9. Per-epoch epilogue.
if self.cfg.trainer.update_ref_every_epoch and self.ref_model is not None:
with Timer("update_ref_with_policy", self.all_timings):
Expand Down
14 changes: 12 additions & 2 deletions skyrl/train/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,14 @@ def _init_workers(self):
num_training_steps = (
self.sft_cfg.dummy_run_max_steps if self.sft_cfg.dummy_run_full_ctx else self.sft_cfg.num_steps
)
# num_steps may be None when num_epochs is used; the worker will use its
# default (large value) for the LR scheduler in that case.
if self.sft_cfg.max_training_steps is not None:
num_training_steps = (
self.sft_cfg.max_training_steps
if num_training_steps is None
else min(num_training_steps, self.sft_cfg.max_training_steps)
)
# num_steps may be None when num_epochs is used; without an explicit cap,
# the worker will use its default large value for the LR scheduler.
ray.get(
actor_group.async_init_model(
self.sft_cfg.model.path,
Expand Down Expand Up @@ -991,6 +997,10 @@ def train(self):
f"ceil({len(tokenized)} / {batch_size}) * {self.sft_cfg.num_epochs} = {num_steps} steps"
)

if self.sft_cfg.max_training_steps is not None:
num_steps = min(num_steps, self.sft_cfg.max_training_steps)
logger.info(f"Capping training at max_training_steps={self.sft_cfg.max_training_steps}")

# Early validation: dataset must have at least batch_size examples
if len(tokenized) < batch_size:
raise ValueError(
Expand Down
15 changes: 15 additions & 0 deletions skyrl/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def _build_train_dataloader_and_compute_training_steps(self):
if self.train_dataset is not None:
self.train_dataloader = build_dataloader(self.cfg, self.train_dataset, is_train=True)
self.total_training_steps = len(self.train_dataloader) * self.cfg.trainer.epochs
if self.cfg.trainer.max_training_steps is not None:
self.total_training_steps = min(self.total_training_steps, self.cfg.trainer.max_training_steps)

@torch.no_grad()
async def eval(self) -> Dict[str, float]:
Expand Down Expand Up @@ -215,6 +217,7 @@ async def train(self):
pbar = tqdm(total=self.total_training_steps, initial=self.global_step, desc="Training Batches Processed")
start_epoch = self.global_step // len(self.train_dataloader)
self.global_step += 1 # start training at global_step 1
stop_training = False
for epoch in range(start_epoch, self.cfg.trainer.epochs):
for _, rand_prompts in enumerate(self.train_dataloader):
with Timer("step", self.all_timings):
Expand Down Expand Up @@ -353,8 +356,20 @@ async def train(self):

self.global_step += 1

if (
self.cfg.trainer.max_training_steps is not None
and self.global_step > self.cfg.trainer.max_training_steps
):
logger.info(f"Reached max_training_steps={self.cfg.trainer.max_training_steps}, stopping early.")
self.global_step = self.cfg.trainer.max_training_steps
stop_training = True
break
Comment thread
dinhxuanvu marked this conversation as resolved.

del training_input, generator_output

if stop_training:
break

pbar.close()
if self.colocate_all:
await self.inference_engine_client.sleep()
Expand Down
4 changes: 4 additions & 0 deletions skyrl/train/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ def validate_cfg(cfg: SkyRLTrainConfig):
)
cfg.trainer.strategy = "fsdp"

if cfg.trainer.max_training_steps is not None:
if cfg.trainer.max_training_steps <= 0:
raise ValueError(f"max_training_steps must be > 0, got {cfg.trainer.max_training_steps}")

# Validate generation config separately
validate_generator_cfg(cfg)
from skyrl.backends.skyrl_train.utils.ppo_utils import (
Expand Down
Loading
Loading