From 7185ef156b5e483c8a3f5707c8a4e0d6fd428021 Mon Sep 17 00:00:00 2001 From: rishithayenumula Date: Wed, 29 Apr 2026 20:24:14 +0530 Subject: [PATCH 1/5] Fix eval crash when val set size is not divisible by train_batch_size - Add is_training flag to compute_prompt_mini_batch_boundaries() - Allow partial batches during evaluation - Use num_prompts instead of train_batch_size in boundary calculations - Keep strict validation during training for distributed correctness - Add 4 comprehensive tests for eval partial batch scenarios - Backward compatible: default is_training=True preserves training behavior --- skyrl/train/dataset/preprocess.py | 50 +++++++++++------ skyrl/train/trainer.py | 12 +++- tests/train/test_prompt_mini_batch.py | 79 +++++++++++++++++++++++++++ 3 files changed, 120 insertions(+), 21 deletions(-) diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index aea16711ff..dc15bf54c7 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -2,7 +2,6 @@ from typing import List, Optional, Tuple import torch -from jaxtyping import Float, Integer from transformers import AutoTokenizer logger = logging.getLogger(__name__) @@ -39,13 +38,13 @@ def convert_prompts_responses_to_batch_tensors( rollout_expert_indices: Optional[List[List[List[List[int]]]]] = None, max_seq_len: Optional[int] = None, ) -> Tuple[ - Float[torch.Tensor, "batch seq_len"], - Float[torch.Tensor, "batch seq_len"], - Float[torch.Tensor, "batch response_len"], - Float[torch.Tensor, "batch response_len"], - Float[torch.Tensor, "batch response_len"], - Optional[Float[torch.Tensor, "batch response_len"]], - Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]], + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], ]: """ Convert prompts and responses to batch tensors for training. @@ -196,6 +195,7 @@ def compute_prompt_mini_batch_boundaries( train_batch_size: int, is_stepwise: bool, n_samples_per_prompt: int, + is_training: bool = True, ) -> List[Tuple[int, int]]: """Compute mini-batch ``(start, end)`` slices from a flat ``uids`` list. @@ -206,10 +206,12 @@ def compute_prompt_mini_batch_boundaries( train_batch_size: Number of prompts in a training batch. For sanity check. is_stepwise: Whether the training is step-wise. For sanity check. n_samples_per_prompt: how many samples per prompt. For sanity check. + is_training: Whether this is a training batch (strict validation) or eval batch (allows partial batches). + Defaults to True for backward compatibility. Returns: List of (start, end) indices of the mini-batches. The length of the list is the number of - mini-batches, guaranteed to be `train_batch_size // mini_batch_size` regardless of whether - the training is step-wise or not. + mini-batches, guaranteed to be `train_batch_size // mini_batch_size` during training, but may differ + during evaluation if the final batch is partial. Consecutive equal entries in ``uids`` belong to the same prompt. Each mini batch spans exactly ``mini_batch_size`` prompts (the last may be smaller if the total prompt count is not divisible @@ -244,23 +246,35 @@ def compute_prompt_mini_batch_boundaries( prompt_end_indices.append(i) prompt_end_indices.append(len(uids)) - # seen_uids should equal to the number of prompts and equal to `train_batch_size` + # Check that num_prompts matches expected batch size num_prompts = len(prompt_end_indices) - assert num_prompts == train_batch_size and len(seen_uids) == train_batch_size - assert train_batch_size % mini_batch_size == 0 - - # Compute boundaries. + if is_training: + assert num_prompts == train_batch_size and len(seen_uids) == train_batch_size, ( + f"Expected {train_batch_size} prompts in training batch, got {num_prompts}." + ) + assert train_batch_size % mini_batch_size == 0 + else: + if num_prompts != train_batch_size: + logger.warning( + f"Partial batch detected during eval: got {num_prompts} prompts but " + f"train_batch_size={train_batch_size}. Using actual batch size for mini-batch boundaries." + ) + + # Compute boundaries. Handle partial batches during eval. boundaries: List[Tuple[int, int]] = [] start_seq = 0 + for i in range(0, num_prompts, mini_batch_size): - end_prompt_idx = i + mini_batch_size - 1 # i + mini_batch_size is next mini-batch's first prompt's end index + end_prompt_idx = min(i + mini_batch_size - 1, num_prompts - 1) end_seq = prompt_end_indices[end_prompt_idx] boundaries.append((start_seq, end_seq)) start_seq = end_seq - assert len(boundaries) == train_batch_size // mini_batch_size + + if is_training: + assert len(boundaries) == train_batch_size // mini_batch_size # Assert that the mini-batch boundaries are uniform for non-step-wise training. - if not is_stepwise: + if not is_stepwise and is_training: expected_num_seq_in_mini_batch = n_samples_per_prompt * mini_batch_size for i, (start, end) in enumerate(boundaries): assert start == i * expected_num_seq_in_mini_batch diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 61655f0f65..1999476ba9 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -592,13 +592,17 @@ def init_weight_sync_state(self): self.dispatch.init_weight_sync_state(self.inference_engine_client) logger.info("Initialized weight sync state for policy model and inference engines.") - def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch: + def convert_to_training_input( + self, generator_output: GeneratorOutput, uids: List[str], is_training: bool = True + ) -> TrainingInputBatch: """Converts lists to a padded batch of tensors for training Args: generator_output (GeneratorOutput): Generated rollouts and associated data. uids (List[str]): List of prompt-unique identifiers for each generator ouput in the same order as `generator_output`. Used to identify which prompt each generated rollout belongs to. + is_training (bool): Whether this batch is for training (strict batch size) or evaluation + (allows partial batches). Defaults to True for backward compatibility. Returns: training_input (TrainingInputBatch): Padded batch of tensors for training. It preserves the order of `generator_output` and hence `uids`. @@ -680,11 +684,13 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt is_stepwise = self.cfg.generator.step_wise_trajectories training_input.metadata["policy_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries( - uids, self.cfg.trainer.policy_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt + uids, self.cfg.trainer.policy_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt, + is_training=is_training ) if self.cfg.trainer.critic.model.path is not None: training_input.metadata["critic_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries( - uids, self.cfg.trainer.critic_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt + uids, self.cfg.trainer.critic_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt, + is_training=is_training ) # 5. Record metadata and metrics. diff --git a/tests/train/test_prompt_mini_batch.py b/tests/train/test_prompt_mini_batch.py index 0ba6c2e684..4da7288177 100644 --- a/tests/train/test_prompt_mini_batch.py +++ b/tests/train/test_prompt_mini_batch.py @@ -199,6 +199,85 @@ def test_same_step_count_as_non_stepwise(self): assert len(stepwise_bounds) == len(non_stepwise_bounds) == 2 + def test_eval_partial_batch_nonstepwise(self): + """Test eval mode with partial batches during non-stepwise training. + + This addresses the issue where evaluation crashes when val set size is + not divisible by train_batch_size. With is_training=False, partial + batches should be allowed. + """ + train_batch_size = 4 + spp = 2 + is_stepwise = False + mini_batch_size = 2 + + # Only 3 prompts instead of 4 (partial batch) + uids = ["p0", "p0", "p1", "p1", "p2", "p2"] + + # Should work fine with is_training=False + boundaries = compute_prompt_mini_batch_boundaries( + uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False + ) + # With 3 prompts and mini_batch_size=2, we get 2 mini-batches: + # First mini-batch: prompts 0-1 (sequences 0-4) + # Second mini-batch: prompt 2 (sequences 4-6) + assert boundaries == [(0, 4), (4, 6)] + + def test_eval_partial_batch_single_minibatch(self): + """Test eval mode with partial batch that fits in single mini-batch.""" + train_batch_size = 4 + spp = 2 + is_stepwise = False + mini_batch_size = 2 + + # Only 1 prompt instead of 4 (very partial batch) + uids = ["p0", "p0"] + + boundaries = compute_prompt_mini_batch_boundaries( + uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False + ) + # With 1 prompt and mini_batch_size=2, we get 1 mini-batch + assert boundaries == [(0, 2)] + + def test_eval_rejects_noncontiguous_uids(self): + """Test that eval mode still enforces contiguous uids.""" + train_batch_size = 4 + spp = 2 + is_stepwise = False + mini_batch_size = 2 + # Non-contiguous uids: p0 appears at index 0-1 and 4-5 + uids = ["p0", "p0", "p1", "p1", "p0", "p0"] + + with pytest.raises(AssertionError, match="uid 'p0' appears in non-contiguous positions"): + compute_prompt_mini_batch_boundaries( + uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False + ) + + def test_eval_stepwise_partial_batch(self): + """Test eval mode with stepwise training and partial batch.""" + mini_batch_size = 2 + train_batch_size = 4 + spp = 2 + is_stepwise = True + + # Only 3 prompts instead of 4 + uids = _make_uids_stepwise( + [ + ("p0", 2, [3, 2]), # 5 seqs + ("p1", 2, [1, 4]), # 5 seqs + ("p2", 2, [2, 1]), # 3 seqs + ] + ) + + # Should work fine with is_training=False + boundaries = compute_prompt_mini_batch_boundaries( + uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False + ) + # With 3 prompts and mini_batch_size=2, we get 2 mini-batches: + # First: prompts 0-1 (sequences 0-10) + # Second: prompt 2 (sequences 10-13) + assert boundaries == [(0, 10), (10, 13)] + # Non-step-wise boundaries should be uniform assert non_stepwise_bounds == [(0, 640), (640, 1280)] From 4adaadeb1f89b3d739dc8a72d02737a673eaa1fc Mon Sep 17 00:00:00 2001 From: rishithayenumula Date: Wed, 29 Apr 2026 20:39:25 +0530 Subject: [PATCH 2/5] Fix test assertion placement and reduce eval partial-batch log noise --- skyrl/train/dataset/preprocess.py | 2 +- tests/train/test_prompt_mini_batch.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index dc15bf54c7..b180798930 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -255,7 +255,7 @@ def compute_prompt_mini_batch_boundaries( assert train_batch_size % mini_batch_size == 0 else: if num_prompts != train_batch_size: - logger.warning( + logger.info( f"Partial batch detected during eval: got {num_prompts} prompts but " f"train_batch_size={train_batch_size}. Using actual batch size for mini-batch boundaries." ) diff --git a/tests/train/test_prompt_mini_batch.py b/tests/train/test_prompt_mini_batch.py index 4da7288177..e55bfc8be0 100644 --- a/tests/train/test_prompt_mini_batch.py +++ b/tests/train/test_prompt_mini_batch.py @@ -198,6 +198,8 @@ def test_same_step_count_as_non_stepwise(self): ) assert len(stepwise_bounds) == len(non_stepwise_bounds) == 2 + # Non-step-wise boundaries should be uniform + assert non_stepwise_bounds == [(0, 640), (640, 1280)] def test_eval_partial_batch_nonstepwise(self): """Test eval mode with partial batches during non-stepwise training. @@ -278,9 +280,6 @@ def test_eval_stepwise_partial_batch(self): # Second: prompt 2 (sequences 10-13) assert boundaries == [(0, 10), (10, 13)] - # Non-step-wise boundaries should be uniform - assert non_stepwise_bounds == [(0, 640), (640, 1280)] - # --------------------------------------------------------------------------- # Tests for MeshDispatch.stage_chunks From 6d7f74c21c80cff70b7cbbefbc3288bdcf23f9c1 Mon Sep 17 00:00:00 2001 From: rishithayenumula Date: Wed, 29 Apr 2026 21:05:27 +0530 Subject: [PATCH 3/5] Trigger CI rerun after review updates From dbb781dbb078185333275571ef97d3c8f4fff104 Mon Sep 17 00:00:00 2001 From: rishithayenumula Date: Wed, 29 Apr 2026 21:07:00 +0530 Subject: [PATCH 4/5] Restore jaxtyping tensor annotations in preprocess --- skyrl/train/dataset/preprocess.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index b180798930..4286f7daf1 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -2,6 +2,7 @@ from typing import List, Optional, Tuple import torch +from jaxtyping import Float, Integer from transformers import AutoTokenizer logger = logging.getLogger(__name__) @@ -38,13 +39,13 @@ def convert_prompts_responses_to_batch_tensors( rollout_expert_indices: Optional[List[List[List[List[int]]]]] = None, max_seq_len: Optional[int] = None, ) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - Optional[torch.Tensor], - Optional[torch.Tensor], + Float[torch.Tensor, "batch seq_len"], + Float[torch.Tensor, "batch seq_len"], + Float[torch.Tensor, "batch response_len"], + Float[torch.Tensor, "batch response_len"], + Float[torch.Tensor, "batch response_len"], + Optional[Float[torch.Tensor, "batch response_len"]], + Optional[Integer[torch.Tensor, "batch seq_len layer_num topk"]], ]: """ Convert prompts and responses to batch tensors for training. From bde201e84de50d5146e68d3ed5b13185ddd79c15 Mon Sep 17 00:00:00 2001 From: rishithayenumula Date: Wed, 29 Apr 2026 21:25:21 +0530 Subject: [PATCH 5/5] fix formatting via pre-commit (black) --- skyrl/train/dataset/preprocess.py | 6 +++--- skyrl/train/trainer.py | 16 ++++++++++++---- tests/train/test_prompt_mini_batch.py | 16 ++++++++-------- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index 4286f7daf1..ece77a3b84 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -250,9 +250,9 @@ def compute_prompt_mini_batch_boundaries( # Check that num_prompts matches expected batch size num_prompts = len(prompt_end_indices) if is_training: - assert num_prompts == train_batch_size and len(seen_uids) == train_batch_size, ( - f"Expected {train_batch_size} prompts in training batch, got {num_prompts}." - ) + assert ( + num_prompts == train_batch_size and len(seen_uids) == train_batch_size + ), f"Expected {train_batch_size} prompts in training batch, got {num_prompts}." assert train_batch_size % mini_batch_size == 0 else: if num_prompts != train_batch_size: diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 1999476ba9..b78c28645b 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -684,13 +684,21 @@ def convert_to_training_input( n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt is_stepwise = self.cfg.generator.step_wise_trajectories training_input.metadata["policy_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries( - uids, self.cfg.trainer.policy_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt, - is_training=is_training + uids, + self.cfg.trainer.policy_mini_batch_size, + train_batch_size, + is_stepwise, + n_samples_per_prompt, + is_training=is_training, ) if self.cfg.trainer.critic.model.path is not None: training_input.metadata["critic_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries( - uids, self.cfg.trainer.critic_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt, - is_training=is_training + uids, + self.cfg.trainer.critic_mini_batch_size, + train_batch_size, + is_stepwise, + n_samples_per_prompt, + is_training=is_training, ) # 5. Record metadata and metrics. diff --git a/tests/train/test_prompt_mini_batch.py b/tests/train/test_prompt_mini_batch.py index e55bfc8be0..4d4967dd61 100644 --- a/tests/train/test_prompt_mini_batch.py +++ b/tests/train/test_prompt_mini_batch.py @@ -203,7 +203,7 @@ def test_same_step_count_as_non_stepwise(self): def test_eval_partial_batch_nonstepwise(self): """Test eval mode with partial batches during non-stepwise training. - + This addresses the issue where evaluation crashes when val set size is not divisible by train_batch_size. With is_training=False, partial batches should be allowed. @@ -212,10 +212,10 @@ def test_eval_partial_batch_nonstepwise(self): spp = 2 is_stepwise = False mini_batch_size = 2 - + # Only 3 prompts instead of 4 (partial batch) uids = ["p0", "p0", "p1", "p1", "p2", "p2"] - + # Should work fine with is_training=False boundaries = compute_prompt_mini_batch_boundaries( uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False @@ -231,10 +231,10 @@ def test_eval_partial_batch_single_minibatch(self): spp = 2 is_stepwise = False mini_batch_size = 2 - + # Only 1 prompt instead of 4 (very partial batch) uids = ["p0", "p0"] - + boundaries = compute_prompt_mini_batch_boundaries( uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False ) @@ -249,7 +249,7 @@ def test_eval_rejects_noncontiguous_uids(self): mini_batch_size = 2 # Non-contiguous uids: p0 appears at index 0-1 and 4-5 uids = ["p0", "p0", "p1", "p1", "p0", "p0"] - + with pytest.raises(AssertionError, match="uid 'p0' appears in non-contiguous positions"): compute_prompt_mini_batch_boundaries( uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False @@ -261,7 +261,7 @@ def test_eval_stepwise_partial_batch(self): train_batch_size = 4 spp = 2 is_stepwise = True - + # Only 3 prompts instead of 4 uids = _make_uids_stepwise( [ @@ -270,7 +270,7 @@ def test_eval_stepwise_partial_batch(self): ("p2", 2, [2, 1]), # 3 seqs ] ) - + # Should work fine with is_training=False boundaries = compute_prompt_mini_batch_boundaries( uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False