diff --git a/.gitignore b/.gitignore index 86e37afa78..3012903854 100644 --- a/.gitignore +++ b/.gitignore @@ -1,125 +1,5 @@ +.venv/ +.venv311/ __pycache__/ -/wandb/ -**/*.egg-info/ -# hydra logs -/outputs/ -/data/lcb - -# MkDocs build output (generated during build) -docs/public/api-ref/ - -# Documentation cache -.doctrees/ -.cache/ -.pytest_cache/ - -# NOTE (sumanthrh): Don't add .env to gitignore. .env file when passed to uv is used to set env vars for each ray worker process. -# If it's in .gitignore then it won't be a part of the working directory shipped by uv and your env vars will not be set. -# This will just appear as a warning (silent failure) and you're gonna have a bad time. -# .env - -# .env files inside directories can be ignored -/skyrl-gym/.env - -/skyrl-gym/.venv - -# build -/skyrl-gym/build -/skyrl-gym/dist - -*.log -nohup.out -tensorboard_log/ - -# SQLite database files -*.db - -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -!docs/lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ - -# Jupyter Notebook -.ipynb_checkpoints - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# MkDocs build output -site/ - -# IDEs and editors -.idea/ -.vscode/ - -# OS generated files -.DS_Store -Thumbs.db - -# Hydra outputs -outputs/ - -# Local artifacts -tinker.db - -# Alembic - don't track pycache -tx/tinker/alembic/__pycache__/ - -# SQLite databases (tracked in git by default, but ignore if created locally) -*.db -*.db-journal -*.db-wal -*.db-shm +*.pyc +*.egg-info/ \ No newline at end of file diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index aea16711ff..ece77a3b84 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -196,6 +196,7 @@ def compute_prompt_mini_batch_boundaries( train_batch_size: int, is_stepwise: bool, n_samples_per_prompt: int, + is_training: bool = True, ) -> List[Tuple[int, int]]: """Compute mini-batch ``(start, end)`` slices from a flat ``uids`` list. @@ -206,10 +207,12 @@ def compute_prompt_mini_batch_boundaries( train_batch_size: Number of prompts in a training batch. For sanity check. is_stepwise: Whether the training is step-wise. For sanity check. n_samples_per_prompt: how many samples per prompt. For sanity check. + is_training: Whether this is a training batch (strict validation) or eval batch (allows partial batches). + Defaults to True for backward compatibility. Returns: List of (start, end) indices of the mini-batches. The length of the list is the number of - mini-batches, guaranteed to be `train_batch_size // mini_batch_size` regardless of whether - the training is step-wise or not. + mini-batches, guaranteed to be `train_batch_size // mini_batch_size` during training, but may differ + during evaluation if the final batch is partial. Consecutive equal entries in ``uids`` belong to the same prompt. Each mini batch spans exactly ``mini_batch_size`` prompts (the last may be smaller if the total prompt count is not divisible @@ -244,23 +247,35 @@ def compute_prompt_mini_batch_boundaries( prompt_end_indices.append(i) prompt_end_indices.append(len(uids)) - # seen_uids should equal to the number of prompts and equal to `train_batch_size` + # Check that num_prompts matches expected batch size num_prompts = len(prompt_end_indices) - assert num_prompts == train_batch_size and len(seen_uids) == train_batch_size - assert train_batch_size % mini_batch_size == 0 - - # Compute boundaries. + if is_training: + assert ( + num_prompts == train_batch_size and len(seen_uids) == train_batch_size + ), f"Expected {train_batch_size} prompts in training batch, got {num_prompts}." + assert train_batch_size % mini_batch_size == 0 + else: + if num_prompts != train_batch_size: + logger.info( + f"Partial batch detected during eval: got {num_prompts} prompts but " + f"train_batch_size={train_batch_size}. Using actual batch size for mini-batch boundaries." + ) + + # Compute boundaries. Handle partial batches during eval. boundaries: List[Tuple[int, int]] = [] start_seq = 0 + for i in range(0, num_prompts, mini_batch_size): - end_prompt_idx = i + mini_batch_size - 1 # i + mini_batch_size is next mini-batch's first prompt's end index + end_prompt_idx = min(i + mini_batch_size - 1, num_prompts - 1) end_seq = prompt_end_indices[end_prompt_idx] boundaries.append((start_seq, end_seq)) start_seq = end_seq - assert len(boundaries) == train_batch_size // mini_batch_size + + if is_training: + assert len(boundaries) == train_batch_size // mini_batch_size # Assert that the mini-batch boundaries are uniform for non-step-wise training. - if not is_stepwise: + if not is_stepwise and is_training: expected_num_seq_in_mini_batch = n_samples_per_prompt * mini_batch_size for i, (start, end) in enumerate(boundaries): assert start == i * expected_num_seq_in_mini_batch diff --git a/skyrl/train/fully_async_trainer.py b/skyrl/train/fully_async_trainer.py index 2286e57f8d..16073845bf 100644 --- a/skyrl/train/fully_async_trainer.py +++ b/skyrl/train/fully_async_trainer.py @@ -384,26 +384,21 @@ async def train(self): for step_idx in range(self.global_step, (1 + epoch) * self.num_steps_per_epoch + 1): with Timer("step", self.all_timings): - # 1. Wait until we have enough groups buffered. + # 1. Non-blocking streaming training: process mini-batch when buffer has enough data. cur_generation_group_mini_batch: List[GeneratedOutputGroup] = [] with Timer("wait_for_generation_buffer", self.all_timings): - buffer_pbar = tqdm( - total=self.mini_batch_size, - initial=0, - desc="Generation Buffer Progress", - position=1, - ) - # NOTE(Charlie): we currently trim the train_dataloader to make it perfectly divisible by - # self.mini_batch_size, and assume that all trajectories succeed (just like sync training), - # so we always get a full mini-batch. Otherwise (e.g. want to drop stale trajectories), we - # should handle the case where the dataloader is exhausted and the buffer is empty, or - # else this loop will never exit. - while len(cur_generation_group_mini_batch) < self.mini_batch_size: + while generation_output_group_buffer.qsize() < self.mini_batch_size: + # Sleep briefly to avoid busy waiting while generation workers keep running. + await asyncio.sleep(0.01) + logger.info(f"Buffer size: {generation_output_group_buffer.qsize()}") + for _ in range(self.mini_batch_size): # We do finish-time FIFO here (not schedule-time FIFO) - cur_generation_group_mini_batch.append(await generation_output_group_buffer.get()) - buffer_pbar.update(1) - buffer_pbar.set_postfix({"buffer qsize": generation_output_group_buffer.qsize()}) - buffer_pbar.close() + try: + cur_generation_group_mini_batch.append(generation_output_group_buffer.get_nowait()) + except asyncio.QueueEmpty as e: + raise AssertionError( + "Generation buffer unexpectedly drained while collecting a mini-batch." + ) from e # 2. Post-process the generated groups, aggregating to a single GeneratorOutput, and convert to training format. with Timer("convert_to_training_input", self.all_timings): @@ -593,11 +588,9 @@ async def _run_generate_for_a_group_loop(self, generation_output_group_buffer: a await self._staleness_manager.on_rollout_accepted() except asyncio.CancelledError: # If a slot was acquired but we exit early, release running count - try: - if "slot_acquired" in locals() and slot_acquired: - raise RuntimeError("Generation workers should only be cancelled when they finish running.") - finally: - return + if "slot_acquired" in locals() and slot_acquired: + raise RuntimeError("Generation workers should only be cancelled when they finish running.") + return except Exception as e: logger.error(f"Generator worker errored out with exception: {e}") logger.error(f"Traceback: \n{traceback.format_exc()}") diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 61655f0f65..b78c28645b 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -592,13 +592,17 @@ def init_weight_sync_state(self): self.dispatch.init_weight_sync_state(self.inference_engine_client) logger.info("Initialized weight sync state for policy model and inference engines.") - def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch: + def convert_to_training_input( + self, generator_output: GeneratorOutput, uids: List[str], is_training: bool = True + ) -> TrainingInputBatch: """Converts lists to a padded batch of tensors for training Args: generator_output (GeneratorOutput): Generated rollouts and associated data. uids (List[str]): List of prompt-unique identifiers for each generator ouput in the same order as `generator_output`. Used to identify which prompt each generated rollout belongs to. + is_training (bool): Whether this batch is for training (strict batch size) or evaluation + (allows partial batches). Defaults to True for backward compatibility. Returns: training_input (TrainingInputBatch): Padded batch of tensors for training. It preserves the order of `generator_output` and hence `uids`. @@ -680,11 +684,21 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt is_stepwise = self.cfg.generator.step_wise_trajectories training_input.metadata["policy_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries( - uids, self.cfg.trainer.policy_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt + uids, + self.cfg.trainer.policy_mini_batch_size, + train_batch_size, + is_stepwise, + n_samples_per_prompt, + is_training=is_training, ) if self.cfg.trainer.critic.model.path is not None: training_input.metadata["critic_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries( - uids, self.cfg.trainer.critic_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt + uids, + self.cfg.trainer.critic_mini_batch_size, + train_batch_size, + is_stepwise, + n_samples_per_prompt, + is_training=is_training, ) # 5. Record metadata and metrics. diff --git a/tests/train/test_prompt_mini_batch.py b/tests/train/test_prompt_mini_batch.py index 0ba6c2e684..4d4967dd61 100644 --- a/tests/train/test_prompt_mini_batch.py +++ b/tests/train/test_prompt_mini_batch.py @@ -198,10 +198,88 @@ def test_same_step_count_as_non_stepwise(self): ) assert len(stepwise_bounds) == len(non_stepwise_bounds) == 2 - # Non-step-wise boundaries should be uniform assert non_stepwise_bounds == [(0, 640), (640, 1280)] + def test_eval_partial_batch_nonstepwise(self): + """Test eval mode with partial batches during non-stepwise training. + + This addresses the issue where evaluation crashes when val set size is + not divisible by train_batch_size. With is_training=False, partial + batches should be allowed. + """ + train_batch_size = 4 + spp = 2 + is_stepwise = False + mini_batch_size = 2 + + # Only 3 prompts instead of 4 (partial batch) + uids = ["p0", "p0", "p1", "p1", "p2", "p2"] + + # Should work fine with is_training=False + boundaries = compute_prompt_mini_batch_boundaries( + uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False + ) + # With 3 prompts and mini_batch_size=2, we get 2 mini-batches: + # First mini-batch: prompts 0-1 (sequences 0-4) + # Second mini-batch: prompt 2 (sequences 4-6) + assert boundaries == [(0, 4), (4, 6)] + + def test_eval_partial_batch_single_minibatch(self): + """Test eval mode with partial batch that fits in single mini-batch.""" + train_batch_size = 4 + spp = 2 + is_stepwise = False + mini_batch_size = 2 + + # Only 1 prompt instead of 4 (very partial batch) + uids = ["p0", "p0"] + + boundaries = compute_prompt_mini_batch_boundaries( + uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False + ) + # With 1 prompt and mini_batch_size=2, we get 1 mini-batch + assert boundaries == [(0, 2)] + + def test_eval_rejects_noncontiguous_uids(self): + """Test that eval mode still enforces contiguous uids.""" + train_batch_size = 4 + spp = 2 + is_stepwise = False + mini_batch_size = 2 + # Non-contiguous uids: p0 appears at index 0-1 and 4-5 + uids = ["p0", "p0", "p1", "p1", "p0", "p0"] + + with pytest.raises(AssertionError, match="uid 'p0' appears in non-contiguous positions"): + compute_prompt_mini_batch_boundaries( + uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False + ) + + def test_eval_stepwise_partial_batch(self): + """Test eval mode with stepwise training and partial batch.""" + mini_batch_size = 2 + train_batch_size = 4 + spp = 2 + is_stepwise = True + + # Only 3 prompts instead of 4 + uids = _make_uids_stepwise( + [ + ("p0", 2, [3, 2]), # 5 seqs + ("p1", 2, [1, 4]), # 5 seqs + ("p2", 2, [2, 1]), # 3 seqs + ] + ) + + # Should work fine with is_training=False + boundaries = compute_prompt_mini_batch_boundaries( + uids, mini_batch_size, train_batch_size, is_stepwise, spp, is_training=False + ) + # With 3 prompts and mini_batch_size=2, we get 2 mini-batches: + # First: prompts 0-1 (sequences 0-10) + # Second: prompt 2 (sequences 10-13) + assert boundaries == [(0, 10), (10, 13)] + # --------------------------------------------------------------------------- # Tests for MeshDispatch.stage_chunks