[examples][infra] Anyscale 2-node Qwen3 launch flow #1689

Draft
xyuzh wants to merge 3 commits into NovaSky-AI:main from xyuzh:xinyu/anyscale-launch

xyuzh commented May 18, 2026

⚠️ Stacked on top of #1688 ([workers] Clean teardown on SIGTERM). Merge that one first; the diff here will then be just the new files under examples/train/megatron/.

Why

Running SkyRL Megatron training on Anyscale (or any Ray-on-k8s cluster) with the default flow has a hidden cost:

  • The driver launches via uv run --extra megatron ….
  • Ray's uv runtime-env hook auto-detects that the driver was started by uv run and propagates the same uv run command as py_executable for every Ray worker actor.
  • Each actor — not each node, each actor — independently runs uv install of skyrl[megatron], building transformer-engine-torch from source (~5–10 min per actor).
  • With TP=4 × DP=4, that's 16 actors all doing the same 5–10 min build. The training process group then waits for init_process_group on all 16, hits the 600 s NCCL rendezvous timeout, and dies with:
torch.distributed.DistStoreError: Timed out after 601 seconds waiting for clients. 1/16 clients joined.

What this PR adds

A complete Anyscale launch flow for Qwen3-30B-A3B and Qwen3-235B-A22B on 2 H100 nodes, built around a single shared venv on NFS:

  1. build_shared_venv.yaml — one-shot job that runs uv sync --extra megatron --no-editable into /mnt/cluster_storage/.skyrl-venv. Idempotent; subsequent invocations no-op.
  2. anyscale_qwen3_30b_2nodes.yaml / anyscale_qwen3_235b_2nodes.yaml — training-job YAMLs that set py_executable: /mnt/cluster_storage/.skyrl-venv/bin/python so every Ray actor reuses the prebuilt venv directly.
  3. run_megatron_qwen3_30b_2nodes.sh / run_megatron_qwen3_235b_2nodes.sh: the training shell scripts. They unset RAY_RUNTIME_ENV_HOOK before invoking Python so that Ray's automatic uv run propagation is not re-enabled, and they forward SIGTERM to the Python child so that the in-process handler from #1688 ([workers] Clean teardown on SIGTERM: drain CUDA + destroy process group) actually fires.
  4. Operational helpers: clear_uv_cache.yaml, clear_venv.yaml, dump_diag.yaml, inspect_cudnn.yaml, download_model.yaml. Useful when iterating on the cluster (rebuilding the venv, inspecting cuDNN paths, prefetching the HF model into /mnt/cluster_storage/hf_cache).
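
For readers who prefer Python to shell, the wrapper behaviour from item 3 can be sketched as below. This is a minimal illustration, not the script in this PR: `child_env` and `run_with_sigterm_forwarding` are hypothetical names, and only the `RAY_RUNTIME_ENV_HOOK` variable and the idea of forwarding SIGTERM come from the description above.

```python
import os
import signal
import subprocess
import sys


def child_env(base_env):
    """Return a copy of base_env without RAY_RUNTIME_ENV_HOOK."""
    env = dict(base_env)
    env.pop("RAY_RUNTIME_ENV_HOOK", None)
    return env


def run_with_sigterm_forwarding(cmd, env=None):
    """Run cmd as a child process and relay SIGTERM to it.

    Returns the child's exit code. Call this from the main thread,
    because Python only allows signal handlers to be installed there.
    """
    base = os.environ if env is None else env
    proc = subprocess.Popen(cmd, env=child_env(base))

    def _forward(signum, frame):
        # Pass the signal on so the child's own handler gets a chance to run.
        if proc.poll() is None:
            proc.send_signal(signum)

    previous = signal.signal(signal.SIGTERM, _forward)
    try:
        return proc.wait()
    finally:
        # Restore whatever handler was active before (None means it was
        # not installed from Python, so there is nothing to restore).
        if previous is not None:
            signal.signal(signal.SIGTERM, previous)
```

A launcher would call `run_with_sigterm_forwarding([...])` with the training command and exit with the returned code. Without the forwarding step, a SIGTERM delivered to the wrapper would typically terminate it while the child never sees the signal, so the child's teardown handler would not run.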

Effect

Validated end-to-end on Qwen3-30B-A3B (TP=4, EP=8, DP=4 — 16 actors, 2 nodes × 8 H100s):

  • Cluster startup → Megatron init → vLLM init → first training step: ~15 min (vs. ~45 min cold without the shared venv, when actors were stuck rebuilding TE).
  • Multiple successful training steps with sensible metrics — policy_loss 0.168 → 0.151 over 5 steps, reward/avg_pass_at_16 trending up from 0.447 to 0.490, full eval-before-train cycle completed.
  • The 235B YAML uses the same pattern; only the model path, parallelism, and FP8 quant flags differ.

Test plan

  • anyscale job submit -f examples/train/megatron/build_shared_venv.yaml --wait once per cluster lifetime — confirm IMPORTS OK in the final log line.
  • anyscale job submit -f examples/train/megatron/anyscale_qwen3_30b_2nodes.yaml — confirm the worker raylet logs show no Building transformer-engine-torch and the job reaches Initialized process group for RayActorGroup within minutes.
  • Run for ≥1 training step, confirm metrics log (policy_loss, reward/avg_pass_at_16, timing/step).
  • (Optional) repeat with the 235B YAML once the shared venv is built — same flow, no rebuild.
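
The two log checks in the test plan could be automated roughly as follows. This is a sketch under assumptions: the marker strings are the ones quoted above, while `check_worker_log` and `launch_looks_healthy` are hypothetical helpers, and the caller is assumed to have already fetched the relevant log text (the rebuild message appears in worker raylet logs, the process-group message in the job log).

```python
REBUILD_MARKER = "Building transformer-engine-torch"
READY_MARKER = "Initialized process group for RayActorGroup"


def check_worker_log(log_text):
    """Report whether the log shows a TE rebuild and a completed init."""
    lines = log_text.splitlines()
    return {
        "rebuilt_te": any(REBUILD_MARKER in line for line in lines),
        "process_group_ready": any(READY_MARKER in line for line in lines),
    }


def launch_looks_healthy(log_text):
    """True when init completed and no actor rebuilt transformer-engine."""
    result = check_worker_log(log_text)
    return result["process_group_ready"] and not result["rebuilt_te"]
```

If the two messages live in separate log files, run `check_worker_log` on each and combine the results instead of concatenating the files.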

xyuzh added 3 commits May 18, 2026 12:08
…cess group

When a k8s pod is evicted (preemption, scale-down, node drain) the container
gets SIGTERM with a 25s grace period before SIGKILL. Without a handler, in-flight
NCCL collectives leak communicators and the next run may hit stale process group
state.

Add a SIGTERM handler inside DistributedTorchRayActor.init_worker_process_group()
that:
  - calls torch.cuda.synchronize() to drain any in-flight CUDA work
  - calls torch.distributed.destroy_process_group() to release NCCL
  - exits cleanly with sys.exit(0)

Both calls are wrapped in try/except so a partial-state worker still tears down
the half that's healthy. The whole sequence is well under the 25s grace window.

Each call is guarded (`torch.distributed.is_available() and is_initialized()`)
so it does nothing when distributed isn't set up yet.
Skip the call entirely on CPU-only environments so we don't generate a noisy
warning every time a non-CUDA worker is terminated. Only emit a warning if
synchronize() actually fails on a CUDA-capable system.

Adds an Anyscale-friendly launch path for Qwen3-30B-A3B and Qwen3-235B-A22B
on 2 H100 nodes. The key piece is a one-shot venv build on shared NFS that
all Ray actors share via py_executable — without it every actor independently
ran uv install of skyrl[megatron], causing NCCL rendezvous to time out before
all 16 actors registered.

Built on top of the SIGTERM teardown PR (NovaSky-AI#1688) so preemption / scale-down
doesn't leak NCCL communicators.

gemini-code-assist (bot) left a comment

Code Review

This pull request introduces configuration and scripts for training and validating large-scale Qwen3 models (235B and 30B) using Megatron and Ray on Anyscale. Key additions include shared virtual environment management, model prefetching, and optimized training scripts utilizing FP8 and optimizer offloading. Feedback highlights several critical issues: the SIGTERM cleanup handler in the worker base class is bypassed by Megatron worker overrides, model-copying logic for local NVMe caching is misplaced on the head node instead of the workers, and there are potential race conditions when building the shared virtual environment or installing NVIDIA packages on shared NFS without proper locking or atomicity.

Comment on lines +126 to +147
# Clean teardown on k8s SIGTERM: drain CUDA streams + release NCCL
# communicators before the 25s grace period elapses.
rank = self._rank

def _sigterm_cleanup(signum, frame):
    logger.warning(f"SIGTERM received in worker rank={rank}, cleaning up...")

    if torch.cuda.is_available():
        try:
            torch.cuda.synchronize()
        except Exception as e:
            logger.warning(f"cuda.synchronize() failed: {e}")

    try:
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
    except Exception as e:
        logger.warning(f"destroy_process_group() failed: {e}")

    sys.exit(0)

signal.signal(signal.SIGTERM, _sigterm_cleanup)

high

The SIGTERM cleanup handler is registered within init_worker_process_group. However, MegatronPolicyWorkerBase (and other Megatron-specific workers in megatron_worker.py) overrides this method without calling super().init_worker_process_group(). Consequently, this cleanup logic will not be active for Megatron training runs, which are the primary focus of this launch flow.

To ensure the handler is registered, you should either:

  1. Call super().init_worker_process_group() in the Megatron worker subclasses.
  2. Move the signal registration to a method that is guaranteed to run, such as the __init__ method of DistributedTorchRayActor.
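
The override problem can be reproduced with toy classes. The sketch below is illustrative only: `BaseActor`, `OverridingWorker`, and `FixedWorker` are stand-ins rather than the real SkyRL classes, and a boolean flag replaces the actual `signal.signal` call.

```python
class BaseActor:
    """Stand-in for the base actor that registers the SIGTERM handler."""

    def __init__(self):
        self.sigterm_handler_installed = False

    def init_worker_process_group(self):
        # In the real class this is where signal.signal(SIGTERM, ...) is called.
        self.sigterm_handler_installed = True


class OverridingWorker(BaseActor):
    """Mirrors a subclass that overrides the method without calling super()."""

    def init_worker_process_group(self):
        # Custom setup only; the base implementation never runs.
        self.custom_setup_done = True


class FixedWorker(BaseActor):
    """Option 1 from the review: call the base implementation first."""

    def init_worker_process_group(self):
        super().init_worker_process_group()
        self.custom_setup_done = True


def handler_installed_after_init(worker_cls):
    """Instantiate worker_cls, run its init, and report the handler flag."""
    worker = worker_cls()
    worker.init_worker_process_group()
    return worker.sigterm_handler_installed
```

Here `handler_installed_after_init(OverridingWorker)` is false while the same call on `FixedWorker` is true, which is the behaviour difference the review describes.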

Comment on lines +23 to +26
if [ -d "$NFS_HF/hub/models--Qwen--Qwen3-30B-A3B-Base" ] && [ ! -d "$LOCAL_HF/hub/models--Qwen--Qwen3-30B-A3B-Base" ]; then
  echo "[script] copying Qwen3-30B-A3B-Base from NFS to local"
  cp -r "$NFS_HF/hub/models--Qwen--Qwen3-30B-A3B-Base" "$LOCAL_HF/hub/"
fi

medium

This model-copying logic executes on the head node as part of the job's entrypoint. However, in the provided Anyscale configuration, the head node does not have GPUs and does not run training workers. The actual Ray workers on the worker nodes will not execute this shell script; they are started directly via py_executable.

As a result, the workers will continue to use the NFS path specified in the YAML's HF_HOME, and the intended optimization of using local NVMe for faster loading will not be realized on the nodes where it matters most. Consider implementing the local caching logic within the worker's initialization or using a Ray runtime environment setup hook that runs on all nodes.
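
One way to move the copy into worker initialization is sketched below. It is an assumption-laden illustration, not code from this PR: `ensure_local_copy` is a hypothetical helper that each worker would call before loading the model, and it assumes the local cache directory sits on a single filesystem so that the final rename is atomic.

```python
import os
import shutil


def ensure_local_copy(nfs_hub, local_hub, model_dir):
    """Copy model_dir from nfs_hub to local_hub once; return the path to load.

    Safe to call from several actors on the same node: each copies into its
    own temporary directory and only one rename wins.
    """
    src = os.path.join(nfs_hub, model_dir)
    dst = os.path.join(local_hub, model_dir)
    if os.path.isdir(dst):
        return dst  # already cached on this node
    if not os.path.isdir(src):
        return src  # nothing to copy; the caller handles a missing model
    os.makedirs(local_hub, exist_ok=True)
    tmp = f"{dst}.tmp-{os.getpid()}"
    # Keep symlinks as symlinks, as the HF hub cache layout relies on them.
    shutil.copytree(src, tmp, symlinks=True)
    try:
        os.rename(tmp, dst)  # atomic when tmp and dst share a filesystem
    except OSError:
        # Another actor on this node finished first; discard our copy.
        shutil.rmtree(tmp, ignore_errors=True)
    return dst
```

With several actors per node this still performs redundant copies before one rename wins; a per-node lock would avoid that at the cost of extra complexity.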

Comment on lines +12 to +13
echo "[setup] building venv at $VENV ($(date -u))"
rm -rf "$VENV"

medium

There is a potential race condition if multiple build jobs are triggered simultaneously or if a job is retried. The rm -rf "$VENV" followed by a non-atomic uv sync can leave the shared storage in a corrupted state for other processes.

It is safer to build the environment in a temporary directory and then use an atomic mv to place it at the final destination once the build is successful (and the stamp is created).
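
The build-then-rename pattern suggested here might look like the following. This is a sketch, not the PR's build script: `build_then_publish` and the `.build-ok` stamp name are hypothetical, and `build_fn` stands in for whatever actually populates the environment (for example the uv sync step).

```python
import os
import shutil
import tempfile


def build_then_publish(final_dir, build_fn, stamp_name=".build-ok"):
    """Build into a sibling temp dir, stamp it, then rename it to final_dir.

    Returns True if this call published the directory, False if a finished
    build was already present (or another builder won the race).
    """
    stamp = os.path.join(final_dir, stamp_name)
    if os.path.isfile(stamp):
        return False  # idempotent: a previous build completed
    parent = os.path.dirname(os.path.abspath(final_dir))
    os.makedirs(parent, exist_ok=True)
    # Same parent directory, so the final rename stays on one filesystem.
    tmp_dir = tempfile.mkdtemp(prefix=".venv-build-", dir=parent)
    os.chmod(tmp_dir, 0o755)  # mkdtemp creates the directory as 0700
    try:
        build_fn(tmp_dir)
        with open(os.path.join(tmp_dir, stamp_name), "w") as f:
            f.write("ok\n")
        if os.path.isdir(final_dir) and not os.path.isfile(stamp):
            shutil.rmtree(final_dir)  # stale, unstamped leftover
        try:
            os.rename(tmp_dir, final_dir)
        except OSError:
            if os.path.isfile(stamp):
                shutil.rmtree(tmp_dir, ignore_errors=True)
                return False  # a concurrent builder published first
            raise
        return True
    except BaseException:
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise
```

One caveat when applying this to a virtual environment: venvs usually embed absolute paths in entry-point scripts and activation files, so a venv built at a temporary path and then moved may need those paths to match the final location. Whether that matters for this flow, which only points py_executable at the interpreter, is not established by the PR text.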

Comment on lines +50 to +57
if [ ! -f "$NV/nvidia/cudnn/include/cudnn.h" ]; then
  echo "[bootstrap] installing nvidia headers to $NV" | tee -a "$DIAG_LOG"
  /home/ray/anaconda3/bin/python3 -m pip install -q --target "$NV" \
    nvidia-cudnn-cu12 nvidia-nccl-cu12 nvidia-cublas-cu12 nvidia-cusparse-cu12 \
    nvidia-cusolver-cu12 nvidia-curand-cu12 nvidia-cufft-cu12 nvidia-cuda-runtime-cu12 \
    nvidia-cuda-nvrtc-cu12 nvidia-cuda-cupti-cu12 nvidia-nvjitlink-cu12 nvidia-nvtx-cu12 \
    nvidia-cudnn-frontend
fi

medium

Installing nvidia packages directly to a shared NFS path without locking is prone to race conditions if multiple jobs are launched. A partial installation from one job could cause others to fail or use corrupted headers/libraries.

Consider using a more robust synchronization method or ensuring that this bootstrap step is performed by a single, idempotent setup job (similar to the venv build).
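
A lock-plus-stamp approach to the bootstrap step could be sketched as follows. The names `run_once_locked`, `.install.lock`, and `.install-ok` are hypothetical, and `install_fn` stands in for the pip install call. The stamp is written only after the install returns, so a partially populated directory is never mistaken for a finished one. Whether `flock` is honoured across nodes depends on the NFS version and mount options, which this sketch assumes are suitable.

```python
import fcntl
import os


def run_once_locked(target_dir, install_fn, stamp_name=".install-ok"):
    """Run install_fn(target_dir) at most once, serialised by an exclusive lock.

    Returns True if this call ran the install, False if the stamp was
    already present once the lock was acquired.
    """
    os.makedirs(target_dir, exist_ok=True)
    stamp = os.path.join(target_dir, stamp_name)
    lock_path = os.path.join(target_dir, ".install.lock")
    with open(lock_path, "w") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)  # blocks until other holders release
        try:
            if os.path.isfile(stamp):
                return False  # someone else completed the install
            install_fn(target_dir)
            with open(stamp, "w") as f:
                f.write("ok\n")
            return True
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)
```

Checking for the stamp after acquiring the lock is what makes concurrent launches safe: a second job blocks until the first finishes, then sees the stamp and skips the install.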

@xyuzh xyuzh marked this pull request as draft May 19, 2026 15:25