Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions examples/train/models/gemma_e2b.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
set -x

# Colocated GRPO training+generation for Gemma-4-E2B on GSM8K.
#
# Usage:
#   uv run examples/train/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
#   export WANDB_API_KEY=<your_key_here>
#   bash examples/train/models/gemma_e2b.sh
#
# NOTE (sumanthrh): `micro_train_batch_size_per_gpu` and `micro_forward_batch_size_per_gpu` can be tuned.
#
# The defaults below can be overridden from the environment, e.g.:
#   NUM_GPUS=1 bash examples/train/models/gemma_e2b.sh
# Any extra command-line arguments are forwarded verbatim to the entrypoint.

: "${DATA_DIR:="$HOME/data/gsm8k"}"
: "${NUM_GPUS:=4}"
: "${LOGGER:=wandb}" # change to "console" to print to stdout

: "${INFERENCE_BACKEND:=vllm}"

# NOTE(review): the default directory name says "1.5B" although this script
# trains Gemma-4-E2B -- confirm whether the default should be renamed.
: "${CKPT_PATH:="$HOME/ckpts/gsm8k_1.5B_ckpt"}"

# Gemma 4 E2B is multimodal; use language_model_only for text-only tasks
MODEL_NAME="google/gemma-4-E2B"
LANGUAGE_MODEL_ONLY=true

# "$@" is quoted so that caller-supplied overrides containing spaces or glob
# characters reach the entrypoint as the exact arguments that were passed in.
uv run --isolated --extra fsdp --with transformers==5.5.4 -m skyrl.train.entrypoints.main_base \
  data.train_data="['$DATA_DIR/train.parquet']" \
  data.val_data="['$DATA_DIR/validation.parquet']" \
  trainer.algorithm.advantage_estimator="grpo" \
  trainer.policy.model.path="$MODEL_NAME" \
  trainer.policy.language_model_only="$LANGUAGE_MODEL_ONLY" \
  trainer.ref.language_model_only="$LANGUAGE_MODEL_ONLY" \
  generator.inference_engine.language_model_only="$LANGUAGE_MODEL_ONLY" \
  trainer.placement.colocate_all=true \
  trainer.strategy=fsdp2 \
  trainer.placement.policy_num_gpus_per_node="$NUM_GPUS" \
  trainer.placement.critic_num_gpus_per_node="$NUM_GPUS" \
  trainer.placement.ref_num_gpus_per_node="$NUM_GPUS" \
  generator.inference_engine.num_engines="$NUM_GPUS" \
  generator.inference_engine.tensor_parallel_size=1 \
  trainer.epochs=20 \
  trainer.eval_batch_size=1024 \
  trainer.eval_before_train=true \
  trainer.eval_interval=5 \
  trainer.update_epochs_per_batch=1 \
  trainer.train_batch_size=1024 \
  trainer.policy_mini_batch_size=256 \
  trainer.micro_forward_batch_size_per_gpu=64 \
  trainer.micro_train_batch_size_per_gpu=64 \
  trainer.ckpt_interval=10 \
  trainer.max_prompt_length=512 \
  generator.sampling_params.max_generate_length=1024 \
  trainer.policy.optimizer_config.lr=1.0e-6 \
  trainer.algorithm.use_kl_loss=true \
  generator.inference_engine.backend="$INFERENCE_BACKEND" \
  generator.inference_engine.run_engines_locally=true \
  generator.inference_engine.weight_sync_backend=nccl \
  generator.inference_engine.async_engine=true \
  generator.batched=true \
  environment.env_class=gsm8k \
  generator.n_samples_per_prompt=5 \
  generator.inference_engine.gpu_memory_utilization=0.8 \
  trainer.logger="$LOGGER" \
  trainer.project_name="gsm8k" \
  trainer.run_name="gsm8k_test" \
  trainer.resume_mode=null \
  trainer.log_path="/tmp/skyrl-logs" \
  trainer.ckpt_path="$CKPT_PATH" \
  "$@"
16 changes: 12 additions & 4 deletions skyrl/backends/skyrl_train/distributed/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,13 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
if dist.get_rank() == 0:
for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
full_param = full_param.detach().cuda()
mesh = sharded_param.device_mesh
dist.broadcast(full_param, src=0)
sharded_tensor = distribute_tensor(full_param, mesh, sharded_param.placements)

if hasattr(sharded_param, "device_mesh"):
sharded_tensor = distribute_tensor(full_param, sharded_param.device_mesh, sharded_param.placements)
else:
sharded_tensor = full_param

to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
Expand All @@ -342,9 +346,13 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
else:
for param_name, sharded_param in meta_sharded_sd.items():
full_tensor = torch.empty(sharded_param.size(), device="cuda", dtype=sharded_param.dtype)
mesh = sharded_param.device_mesh
dist.broadcast(full_tensor, src=0)
sharded_tensor = distribute_tensor(full_tensor, mesh, sharded_param.placements)

if hasattr(sharded_param, "device_mesh"):
sharded_tensor = distribute_tensor(full_tensor, sharded_param.device_mesh, sharded_param.placements)
else:
sharded_tensor = full_tensor

to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
Expand Down
14 changes: 14 additions & 0 deletions skyrl/backends/skyrl_train/inference_servers/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,20 @@ def test_rpc(self, *args, **kwargs):
"""Test RPC call to worker."""
return args, kwargs

def debug_weight_checksum(self, max_params: int = 3) -> dict:
    """Summarize the first few model parameters for debugging weight sync.

    Only a prefix of the parameters is inspected (in ``named_parameters()``
    order), so the result is a cheap spot check rather than a checksum of
    the whole model.

    Args:
        max_params: Maximum number of parameters to summarize. Defaults to 3.

    Returns:
        A dict mapping each inspected parameter name to a dict with keys
        ``"sum"`` (sum of the values after casting to float32),
        ``"absmax"`` (largest absolute value) and ``"shape"`` (list of ints).
    """
    checksums = {}
    for name, param in self.model_runner.model.named_parameters():
        if len(checksums) >= max_params:
            break
        data = param.data.float()
        checksums[name] = {
            "sum": data.sum().item(),
            # max() without a dim raises on zero-element tensors; report 0.0
            # for those instead of failing the whole debug call.
            "absmax": data.abs().max().item() if data.numel() else 0.0,
            "shape": list(data.shape),
        }
    return checksums

def init_weight_update_communicator(self, init_info: bytes):
"""
Initialize weight update communicator from init info.
Expand Down
3 changes: 2 additions & 1 deletion skyrl/backends/skyrl_train/weight_sync/broadcast_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,8 @@ def receive_weights(self, request: BroadcastWeightUpdateRequest) -> Iterator[Tup

offset = 0
for name, shape, size in zip(request.names, request.shapes, request.sizes):
yield name, packed[offset : offset + size].view(*shape)
chunk = packed[offset : offset + size]
yield name, chunk.view(*shape) if shape else chunk.squeeze()
offset += size
else:
for name, dtype_str, shape in zip(request.names, request.dtypes, request.shapes):
Expand Down
3 changes: 2 additions & 1 deletion skyrl/backends/skyrl_train/weight_sync/cuda_ipc_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,8 @@ def receive_weights(self, request: CudaIpcWeightUpdateRequest) -> Iterator[Tuple

offset = 0
for name, shape, size in zip(request.names, request.shapes, request.sizes):
yield name, packed_tensor[offset : offset + size].view(*shape)
chunk = packed_tensor[offset : offset + size]
yield name, chunk.view(*shape) if shape else chunk.squeeze()
offset += size

def teardown(self) -> None:
Expand Down
40 changes: 40 additions & 0 deletions skyrl/utils/tok.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,37 @@
"""Tokenization related utilities"""

from loguru import logger
from transformers import AutoTokenizer


def _try_load_chat_template_jinja(model_name_or_path: str) -> str | None:
    """Best-effort lookup of a standalone ``chat_template.jinja`` file.

    Gemma 4 and some newer models ship the chat template as a separate
    ``.jinja`` file rather than embedding it in ``tokenizer_config.json``.
    This helper is used when the tokenizer did not pick a template up on its
    own. Candidates are tried in order:

    1. ``chat_template.jinja`` inside ``model_name_or_path`` when that is a
       local directory containing the file.
    2. ``chat_template.jinja`` downloaded from the ``model_name_or_path`` repo.
    3. The same file from the ``<model_name_or_path>-it`` (instruction-tuned)
       repo, for base-model repos that do not ship a template themselves.

    Args:
        model_name_or_path: Hugging Face repo id or local model directory.

    Returns:
        The template text of the first candidate that yields a non-empty
        file, or ``None`` if no candidate does (including when
        ``huggingface_hub`` is not installed and no local file exists).
    """
    from pathlib import Path

    # A local checkpoint directory may carry its own template file.
    local_template = Path(model_name_or_path) / "chat_template.jinja"
    try:
        template = local_template.read_text(encoding="utf-8") if local_template.is_file() else ""
    except OSError as exc:
        logger.debug(f"Could not read {local_template}: {exc}")
        template = ""
    if template:
        logger.info(f"Loaded chat_template.jinja from {local_template}")
        return template

    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        return None

    candidates = [model_name_or_path]
    # Avoid a pointless "<name>-it-it" lookup for repos that are already "-it".
    if not model_name_or_path.endswith("-it"):
        candidates.append(f"{model_name_or_path}-it")

    for repo_id in candidates:
        try:
            path = hf_hub_download(repo_id, "chat_template.jinja")
            # Read explicitly as UTF-8 instead of the locale's default encoding.
            with open(path, encoding="utf-8") as f:
                template = f.read()
        except Exception as exc:
            # Deliberately best-effort: a missing repo/file, an invalid repo id
            # or a network failure just means "try the next candidate".
            logger.debug(f"Could not load chat_template.jinja from {repo_id}: {exc}")
            continue
        if template:
            logger.info(f"Loaded chat_template.jinja from {repo_id}")
            return template
    return None


def get_tokenizer(model_name_or_path, **tokenizer_kwargs) -> AutoTokenizer:
"""Gets tokenizer for the given base model with the given parameters

Expand All @@ -12,4 +41,15 @@ def get_tokenizer(model_name_or_path, **tokenizer_kwargs) -> AutoTokenizer:
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

if tokenizer.chat_template is None:
template = _try_load_chat_template_jinja(model_name_or_path)
if template:
tokenizer.chat_template = template
else:
logger.warning(
f"No chat_template found for {model_name_or_path}. "
f"tokenizer.apply_chat_template() will fail unless a template is provided."
)

return tokenizer
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
run_inference,
)

MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
MODEL = "google/gemma-4-E2B"
# MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
MOE_MODEL = "hf-internal-testing/tiny-qwen3-moe"


Expand Down Expand Up @@ -69,11 +70,11 @@ def get_test_actor_config(model: str) -> SkyRLTrainConfig:
],
ids=[
"no_colocate_nccl_fsdp_vllm",
"colocate_nccl_fsdp_vllm",
"x_colocate_nccl_fsdp_vllm",
"no_colocate_gloo_fsdp_vllm",
"colocate_gloo_fsdp_vllm",
"no_colocate_nccl_fsdp2_vllm",
"colocate_nccl_fsdp2_vllm",
"x_colocate_nccl_fsdp2_vllm",
"colocate_nccl_fsdp2_vllm_mp",
"non_colocated_nccl_fsdp2_vllm_mp",
"colocate_nccl_fsdp2_vllm_dp",
Expand Down Expand Up @@ -116,6 +117,8 @@ async def test_policy_local_engines_e2e(
) as engines:
client, pg = engines.client, engines.pg

await client.sleep()

policy = init_worker_with_type(
"policy",
shared_pg=pg,
Expand All @@ -131,6 +134,11 @@ async def test_policy_local_engines_e2e(
"pass_through", "init_weight_sync_state", client, cfg.generator.inference_engine
)
)
# if colocate_all:
# await client.wake_up(tags=["weights", "kv_cache"])
# else:
# await client.wake_up(tags=["weights"])
await client.wake_up(tags=["weights"])
await client.reset_prefix_cache()
ray.get(
policy.async_run_ray_method(
Expand All @@ -139,7 +147,7 @@ async def test_policy_local_engines_e2e(
)
if colocate_all:
policy.offload_to_cpu()
await client.wake_up()
await client.wake_up(tags=["kv_cache"])

sampling_params = get_sampling_params_for_backend(
cfg.generator.inference_engine.backend, cfg.generator.sampling_params
Expand Down