diff --git a/examples/train/models/gemma_e2b.sh b/examples/train/models/gemma_e2b.sh
new file mode 100644
index 0000000000..238bbfdc1f
--- /dev/null
+++ b/examples/train/models/gemma_e2b.sh
@@ -0,0 +1,66 @@
+set -x
+
+# Colocated GRPO training+generation for Gemma-4-E2B on GSM8K.
+
+# uv run examples/train/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
+# export WANDB_API_KEY=
+# bash examples/train/models/gemma_e2b.sh
+
+# NOTE (sumanthrh): `micro_train_batch_size_per_gpu` and `micro_forward_batch_size_per_gpu` can be tuned
+
+# You can override the default values with e.g.: `NUM_GPUS=1 bash examples/train/models/gemma_e2b.sh`.
+
+: "${DATA_DIR:="$HOME/data/gsm8k"}"
+: "${NUM_GPUS:=4}"
+: "${LOGGER:=wandb}"  # change to "console" to print to stdout
+
+: "${INFERENCE_BACKEND:=vllm}"
+
+# Gemma 4 E2B is multimodal; use language_model_only for text-only tasks
+MODEL_NAME="google/gemma-4-E2B"
+LANGUAGE_MODEL_ONLY=true
+
+uv run --isolated --extra fsdp --with transformers==5.5.4 -m skyrl.train.entrypoints.main_base \
+  data.train_data="['$DATA_DIR/train.parquet']" \
+  data.val_data="['$DATA_DIR/validation.parquet']" \
+  trainer.algorithm.advantage_estimator="grpo" \
+  trainer.policy.model.path=$MODEL_NAME \
+  trainer.policy.language_model_only=$LANGUAGE_MODEL_ONLY \
+  trainer.ref.language_model_only=$LANGUAGE_MODEL_ONLY \
+  generator.inference_engine.language_model_only=$LANGUAGE_MODEL_ONLY \
+  trainer.placement.colocate_all=true \
+  trainer.strategy=fsdp2 \
+  trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
+  trainer.placement.critic_num_gpus_per_node=$NUM_GPUS \
+  trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
+  generator.inference_engine.num_engines=$NUM_GPUS \
+  generator.inference_engine.tensor_parallel_size=1 \
+  trainer.epochs=20 \
+  trainer.eval_batch_size=1024 \
+  trainer.eval_before_train=true \
+  trainer.eval_interval=5 \
+  trainer.update_epochs_per_batch=1 \
+  trainer.train_batch_size=1024 \
+  trainer.policy_mini_batch_size=256 \
+  trainer.micro_forward_batch_size_per_gpu=64 \
+  trainer.micro_train_batch_size_per_gpu=64 \
+  trainer.ckpt_interval=10 \
+  trainer.max_prompt_length=512 \
+  generator.sampling_params.max_generate_length=1024 \
+  trainer.policy.optimizer_config.lr=1.0e-6 \
+  trainer.algorithm.use_kl_loss=true \
+  generator.inference_engine.backend=$INFERENCE_BACKEND \
+  generator.inference_engine.run_engines_locally=true \
+  generator.inference_engine.weight_sync_backend=nccl \
+  generator.inference_engine.async_engine=true \
+  generator.batched=true \
+  environment.env_class=gsm8k \
+  generator.n_samples_per_prompt=5 \
+  generator.inference_engine.gpu_memory_utilization=0.8 \
+  trainer.logger="$LOGGER" \
+  trainer.project_name="gsm8k" \
+  trainer.run_name="gsm8k_test" \
+  trainer.resume_mode=null \
+  trainer.log_path="/tmp/skyrl-logs" \
+  trainer.ckpt_path="$HOME/ckpts/gsm8k_gemma_e2b_ckpt" \
+  "$@"
diff --git a/examples/train/models/run_qwen3.5_0.8b.sh b/examples/train/models/qwen3.5_0.8b.sh
similarity index 100%
rename from examples/train/models/run_qwen3.5_0.8b.sh
rename to examples/train/models/qwen3.5_0.8b.sh
diff --git a/skyrl/backends/skyrl_train/distributed/fsdp_utils.py b/skyrl/backends/skyrl_train/distributed/fsdp_utils.py
index 637f6440fb..e3a10f22cd 100644
--- a/skyrl/backends/skyrl_train/distributed/fsdp_utils.py
+++ b/skyrl/backends/skyrl_train/distributed/fsdp_utils.py
@@ -328,9 +328,13 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
     if dist.get_rank() == 0:
         for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
             full_param = full_param.detach().cuda()
-            mesh = sharded_param.device_mesh
             dist.broadcast(full_param, src=0)
-            sharded_tensor = distribute_tensor(full_param, mesh, sharded_param.placements)
+
+            if hasattr(sharded_param, "device_mesh"):
+                sharded_tensor = distribute_tensor(full_param, sharded_param.device_mesh, sharded_param.placements)
+            else:
+                sharded_tensor = full_param
+
             to_contiguous, casting_dtype = _infer_parameter_dtype(
                 model,
                 param_name,
@@ -342,9 +346,13 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
     else:
         for param_name, sharded_param in meta_sharded_sd.items():
             full_tensor = torch.empty(sharded_param.size(), device="cuda", dtype=sharded_param.dtype)
-            mesh = sharded_param.device_mesh
             dist.broadcast(full_tensor, src=0)
-            sharded_tensor = distribute_tensor(full_tensor, mesh, sharded_param.placements)
+
+            if hasattr(sharded_param, "device_mesh"):
+                sharded_tensor = distribute_tensor(full_tensor, sharded_param.device_mesh, sharded_param.placements)
+            else:
+                sharded_tensor = full_tensor
+
             to_contiguous, casting_dtype = _infer_parameter_dtype(
                 model,
                 param_name,
diff --git a/skyrl/backends/skyrl_train/inference_servers/vllm_worker.py b/skyrl/backends/skyrl_train/inference_servers/vllm_worker.py
index 8249b30a7f..61bcf0154f 100644
--- a/skyrl/backends/skyrl_train/inference_servers/vllm_worker.py
+++ b/skyrl/backends/skyrl_train/inference_servers/vllm_worker.py
@@ -40,6 +40,20 @@ def test_rpc(self, *args, **kwargs):
         """Test RPC call to worker."""
         return args, kwargs
 
+    def debug_weight_checksum(self) -> dict:
+        """Return sum/absmax/shape for the first 3 model parameters (debug aid for weight sync)."""
+        checksums = {}
+        for name, param in self.model_runner.model.named_parameters():
+            data = param.data.float()
+            checksums[name] = {
+                "sum": data.sum().item(),
+                "absmax": data.abs().max().item(),
+                "shape": list(data.shape),
+            }
+            if len(checksums) >= 3:
+                break
+        return checksums
+
     def init_weight_update_communicator(self, init_info: bytes):
         """
         Initialize weight update communicator from init info.
diff --git a/skyrl/backends/skyrl_train/weight_sync/broadcast_strategy.py b/skyrl/backends/skyrl_train/weight_sync/broadcast_strategy.py
index 6c5e7357c8..e36e56825f 100644
--- a/skyrl/backends/skyrl_train/weight_sync/broadcast_strategy.py
+++ b/skyrl/backends/skyrl_train/weight_sync/broadcast_strategy.py
@@ -302,7 +302,8 @@ def receive_weights(self, request: BroadcastWeightUpdateRequest) -> Iterator[Tup
 
             offset = 0
             for name, shape, size in zip(request.names, request.shapes, request.sizes):
-                yield name, packed[offset : offset + size].view(*shape)
+                chunk = packed[offset : offset + size]
+                yield name, chunk.view(*shape) if shape else chunk.squeeze()
                 offset += size
         else:
             for name, dtype_str, shape in zip(request.names, request.dtypes, request.shapes):
diff --git a/skyrl/backends/skyrl_train/weight_sync/cuda_ipc_strategy.py b/skyrl/backends/skyrl_train/weight_sync/cuda_ipc_strategy.py
index c828c2172d..37dd2940a3 100644
--- a/skyrl/backends/skyrl_train/weight_sync/cuda_ipc_strategy.py
+++ b/skyrl/backends/skyrl_train/weight_sync/cuda_ipc_strategy.py
@@ -341,7 +341,8 @@ def receive_weights(self, request: CudaIpcWeightUpdateRequest) -> Iterator[Tuple
 
         offset = 0
         for name, shape, size in zip(request.names, request.shapes, request.sizes):
-            yield name, packed_tensor[offset : offset + size].view(*shape)
+            chunk = packed_tensor[offset : offset + size]
+            yield name, chunk.view(*shape) if shape else chunk.squeeze()
             offset += size
 
     def teardown(self) -> None:
diff --git a/skyrl/utils/tok.py b/skyrl/utils/tok.py
index b328ee8334..abb7ca5a6c 100644
--- a/skyrl/utils/tok.py
+++ b/skyrl/utils/tok.py
@@ -1,8 +1,37 @@
 """Tokenization related utilities"""
 
+from loguru import logger
 from transformers import AutoTokenizer
 
 
+def _try_load_chat_template_jinja(model_name_or_path: str) -> str | None:
+    """Attempt to download and read a chat_template.jinja from the model repo.
+
+    Gemma 4 and some newer models ship the chat template as a separate .jinja
+    file rather than embedding it in tokenizer_config.json. When the tokenizer
+    doesn't pick it up automatically (e.g. base-model variants that lack the
+    file, or older transformers versions), we fall back to fetching it from the
+    corresponding ``-it`` (instruction-tuned) repo if the base repo doesn't
+    have one.
+    """
+    try:
+        from huggingface_hub import hf_hub_download
+    except ImportError:
+        return None
+
+    for repo_id in (model_name_or_path, f"{model_name_or_path}-it"):
+        try:
+            path = hf_hub_download(repo_id, "chat_template.jinja")
+            with open(path, encoding="utf-8") as f:
+                template = f.read()
+            if template:
+                logger.info(f"Loaded chat_template.jinja from {repo_id}")
+                return template
+        except Exception as e:
+            logger.debug(f"Could not load chat_template.jinja from {repo_id}: {e}")
+    return None
+
+
 def get_tokenizer(model_name_or_path, **tokenizer_kwargs) -> AutoTokenizer:
     """Gets tokenizer for the given base model with the given parameters
 
@@ -12,4 +41,15 @@ def get_tokenizer(model_name_or_path, **tokenizer_kwargs) -> AutoTokenizer:
     if tokenizer.pad_token_id is None:
         tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenizer.pad_token = tokenizer.eos_token
+
+    if tokenizer.chat_template is None:
+        template = _try_load_chat_template_jinja(model_name_or_path)
+        if template:
+            tokenizer.chat_template = template
+        else:
+            logger.warning(
+                f"No chat_template found for {model_name_or_path}. "
+                f"tokenizer.apply_chat_template() will fail unless a template is provided."
+            )
+
     return tokenizer
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_policy_local_engines_e2e.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_policy_local_engines_e2e.py
index 58c48edc42..e46fea06d5 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/test_policy_local_engines_e2e.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_policy_local_engines_e2e.py
@@ -19,7 +19,8 @@
     run_inference,
 )
 
-MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
+MODEL = "google/gemma-4-E2B"
+# MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
 MOE_MODEL = "hf-internal-testing/tiny-qwen3-moe"
 
 
@@ -69,11 +70,11 @@ def get_test_actor_config(model: str) -> SkyRLTrainConfig:
     ],
     ids=[
         "no_colocate_nccl_fsdp_vllm",
-        "colocate_nccl_fsdp_vllm",
+        "x_colocate_nccl_fsdp_vllm",
         "no_colocate_gloo_fsdp_vllm",
         "colocate_gloo_fsdp_vllm",
         "no_colocate_nccl_fsdp2_vllm",
-        "colocate_nccl_fsdp2_vllm",
+        "x_colocate_nccl_fsdp2_vllm",
         "colocate_nccl_fsdp2_vllm_mp",
         "non_colocated_nccl_fsdp2_vllm_mp",
         "colocate_nccl_fsdp2_vllm_dp",
@@ -116,6 +117,8 @@ async def test_policy_local_engines_e2e(
     ) as engines:
         client, pg = engines.client, engines.pg
 
+        await client.sleep()
+
         policy = init_worker_with_type(
             "policy",
             shared_pg=pg,
@@ -131,6 +134,11 @@ async def test_policy_local_engines_e2e(
                 "pass_through", "init_weight_sync_state", client, cfg.generator.inference_engine
             )
         )
+        # if colocate_all:
+        #     await client.wake_up(tags=["weights", "kv_cache"])
+        # else:
+        #     await client.wake_up(tags=["weights"])
+        await client.wake_up(tags=["weights"])
         await client.reset_prefix_cache()
         ray.get(
             policy.async_run_ray_method(
@@ -139,7 +147,7 @@ async def test_policy_local_engines_e2e(
         )
         if colocate_all:
             policy.offload_to_cpu()
-            await client.wake_up()
+            await client.wake_up(tags=["kv_cache"])
         sampling_params = get_sampling_params_for_backend(
             cfg.generator.inference_engine.backend,
             cfg.generator.sampling_params