Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ tinker = [
"psycopg2-binary",
]

ray = [
"ray[default]==2.51.1",
]

aws = [
"cloudpathlib[s3]",
]
Expand Down
17 changes: 17 additions & 0 deletions skyrl/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
uv run -m skyrl.backends.jax --coordinator-address localhost:7777 --num-processes 2 --process-id 1
"""

import json
import time
from contextlib import contextmanager
from dataclasses import dataclass
Expand Down Expand Up @@ -112,6 +113,22 @@ class JaxBackendConfig(BaseModel, extra="forbid"):
default=None,
description="Total number of processes in the multi-node cluster",
)
# RayJaxBackend configuration
use_ray: bool = Field(
default=False,
description="Use Ray to schedule JAX workers",
)

ray_actor_options: dict = Field(
default_factory=dict,
description="Options to pass to Ray actors (e.g., resources, num_cpus)",
json_schema_extra={"argparse_type": json.loads},
)
ray_pg_bundles: list = Field(
default_factory=list,
description="Bundles for the Ray placement group (e.g., [{'CPU': 1}] * num_processes)",
json_schema_extra={"argparse_type": json.loads},
)


@jax.tree_util.register_dataclass
Expand Down
192 changes: 192 additions & 0 deletions skyrl/backends/ray_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import socket

import ray
from cloudpathlib import AnyPath

from skyrl.backends.backend import AbstractBackend
from skyrl.backends.jax import JaxBackendConfig, JaxBackendImpl
from skyrl.tinker import types
from skyrl.utils.log import logger


def _get_random_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might cause a race condition if multiple processes do something like this at the same time (get a port, then release the port and re-use the port number). It might be more robust to either try a few random ports from a range and return a free one.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the race condition will happen because _get_random_port is only called by index 0 JAX worker during initialization.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, happy to merge it! If there is any other process on the machine that allocates ports, it can happen, but we can also fix it if it actually happens :)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm good point, I didn't consider that. I can open a follow-up PR for this :)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened #1652

return s.getsockname()[1]


@ray.remote
class RayJaxBackendImpl:
    """Ray actor that wraps a single ``JaxBackendImpl``.

    Each actor calls ``jax.distributed.initialize()`` in :meth:`setup` and then holds an
    instance of ``JaxBackendImpl``; all remaining methods forward to that instance.
    """

    def __init__(self, base_model: str, config: JaxBackendConfig, process_id: int):
        """Record the arguments; the process-0 actor also chooses the coordinator address.

        Args:
            base_model: Passed through to ``JaxBackendImpl`` in :meth:`setup`.
            config: Backend configuration; a copy is stored on the actor.
            process_id: Index of this worker. Worker 0 acts as the JAX coordinator.
        """
        self.base_model = base_model
        self.config = config.model_copy()
        self.process_id = process_id
        # Created in setup(), after jax.distributed.initialize() has run.
        self.backend = None

        if process_id == 0:
            # Only the coordinator picks an address; the driver reads it via
            # get_coordinator_address() and hands it to the other workers.
            self.node_ip = ray.util.get_node_ip_address()
            self.port = _get_random_port()
            self.coordinator_address = f"{self.node_ip}:{self.port}"
        else:
            self.coordinator_address = None

    def _require_backend(self) -> JaxBackendImpl:
        """Return the wrapped backend, raising a clear error if setup() has not run.

        Raises:
            RuntimeError: If :meth:`setup` has not created the backend yet.
        """
        if self.backend is None:
            raise RuntimeError(
                f"Worker {self.process_id}: setup() must be called before using the backend"
            )
        return self.backend

    def get_coordinator_address(self) -> str | None:
        """Return ``"<ip>:<port>"`` on worker 0 and ``None`` on every other worker."""
        return self.coordinator_address

    def setup(self, coordinator_address: str | None = None):
        """Initializes JAX distributed and creates JaxBackendImpl.

        Args:
            coordinator_address: Address of the coordinator. When omitted (or empty), the
                address chosen by this actor in ``__init__`` is used, which is only set on
                worker 0.
        """
        import jax

        addr = coordinator_address or self.coordinator_address
        logger.info(f"Worker {self.process_id} initializing JAX distributed with coordinator {addr}")

        jax.distributed.initialize(
            coordinator_address=addr,
            num_processes=self.config.num_processes,
            process_id=self.process_id,
        )
        self.backend = JaxBackendImpl(self.base_model, self.config, self.process_id)
        logger.info(f"Worker {self.process_id} JaxBackendImpl initialized.")

    def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
        """Forward ``create_model`` to the wrapped backend."""
        self._require_backend().create_model(model_id, lora_config)

    def forward_backward(self, prepared_batch: types.PreparedModelPassBatch):
        """Forward ``forward_backward`` to the wrapped backend and return its result."""
        return self._require_backend().forward_backward(prepared_batch)

    def forward(self, prepared_batch: types.PreparedModelPassBatch):
        """Forward ``forward`` to the wrapped backend and return its result."""
        return self._require_backend().forward(prepared_batch)

    def optim_step(self, model_id: str, request_data: types.OptimStepInput):
        """Forward ``optim_step`` to the wrapped backend and return its result."""
        return self._require_backend().optim_step(model_id, request_data)

    def sample(self, prepared_batch: types.PreparedSampleBatch):
        """Forward ``sample`` to the wrapped backend and return its result."""
        return self._require_backend().sample(prepared_batch)

    def save_checkpoint(self, output_path: AnyPath, model_id: str) -> None:
        """Forward ``save_checkpoint`` to the wrapped backend."""
        self._require_backend().save_checkpoint(output_path, model_id)

    def load_checkpoint(self, checkpoint_path: AnyPath, model_id: str) -> None:
        """Forward ``load_checkpoint`` to the wrapped backend."""
        self._require_backend().load_checkpoint(checkpoint_path, model_id)

    def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str, persist: bool = True) -> None:
        """Forward ``save_sampler_checkpoint`` to the wrapped backend."""
        self._require_backend().save_sampler_checkpoint(output_path, model_id, persist)

    def has_model(self, model_id: str) -> bool:
        """Return the wrapped backend's ``has_model(model_id)``."""
        return self._require_backend().has_model(model_id)

    def delete_model(self, model_id: str) -> None:
        """Forward ``delete_model`` to the wrapped backend."""
        self._require_backend().delete_model(model_id)

    def get_metrics(self) -> types.EngineMetrics:
        """Return the wrapped backend's ``metrics`` attribute."""
        return self._require_backend().metrics


class RayJaxBackend(AbstractBackend):
    """RayJaxBackend is a proxy Backend that orchestrates Ray actors for multi-node JAX execution.

    This class runs in the driver program (along with Tinker API / Engine) and proxies
    commands to all JAX workers running as Ray actors.
    """

    def __init__(self, base_model: str, config: JaxBackendConfig):
        """Reserve a placement group, start one actor per process and initialize JAX on all of them.

        Args:
            base_model: Passed through to every ``RayJaxBackendImpl`` actor.
            config: Backend configuration; a copy is stored. ``num_processes`` is required,
                ``ray_pg_bundles`` and ``ray_actor_options`` are optional.

        Raises:
            ValueError: If ``num_processes`` is missing or not positive, or if
                ``ray_pg_bundles`` is given but has fewer entries than ``num_processes``.
        """
        self.base_model = base_model
        self.config = config.model_copy()

        num_processes = self.config.num_processes
        if not num_processes or num_processes < 1:
            raise ValueError("num_processes must be specified and > 0 when using Ray JAX backend")

        logger.info(f"Initializing RayJaxBackend with num_processes={num_processes}")

        # Initialize a Ray placement group based on ray_pg_bundles in JaxBackendConfig
        from ray.util.placement_group import placement_group

        logger.info("Creating Ray placement group for JAX backend")
        bundles = self.config.ray_pg_bundles
        if bundles and len(bundles) < num_processes:
            # Worker i is pinned to bundle index i below, so every worker needs its own
            # bundle. Fail here, before any cluster resources are reserved.
            raise ValueError(
                f"Number of bundles in ray_pg_bundles ({len(bundles)}) "
                f"must be at least num_processes ({num_processes})"
            )
        if not bundles:
            # Default: one CPU per worker. Build distinct dicts rather than aliasing one.
            bundles = [{"CPU": 1} for _ in range(num_processes)]

        self.pg = placement_group(bundles, strategy="SPREAD")
        ray.get(self.pg.ready())

        self.workers = []

        # node0 (coordinator)
        logger.info("Scheduling Ray actor for node0 (JAX coordinator)")
        w0 = self._spawn_worker(0)
        self.workers.append(w0)

        # The coordinator chooses its own address in its constructor; fetch it so it can
        # be handed to the remaining workers.
        coordinator_address = ray.get(w0.get_coordinator_address.remote())

        # Create remaining node1 - nodeN for multi-node training.
        logger.info("Scheduling remaining Ray actors (JAX workers)")
        self.workers.extend(self._spawn_worker(i) for i in range(1, num_processes))

        # This will block until JAX distributed is initialized on all workers.
        setup_refs = [w0.setup.remote()]
        setup_refs.extend(w.setup.remote(coordinator_address) for w in self.workers[1:])

        ray.get(setup_refs)
        logger.info("RayJaxBackend is fully initialized and distributed JAX cluster is ready.")

    def _spawn_worker(self, process_id: int):
        """Start the actor for ``process_id`` in bundle ``process_id`` of the placement group."""
        options = {
            **self.config.ray_actor_options,
            # Placement settings always override user-supplied actor options.
            "placement_group": self.pg,
            "placement_group_bundle_index": process_id,
        }
        return RayJaxBackendImpl.options(**options).remote(self.base_model, self.config, process_id)

    @property
    def metrics(self) -> types.EngineMetrics:
        """Metrics reported by worker 0."""
        return ray.get(self.workers[0].get_metrics.remote())

    def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
        """Call ``create_model`` on every worker and wait for all of them."""
        ray.get([w.create_model.remote(model_id, lora_config) for w in self.workers])

    def delete_model(self, model_id: str) -> None:
        """Call ``delete_model`` on every worker and wait for all of them."""
        ray.get([w.delete_model.remote(model_id) for w in self.workers])

    def has_model(self, model_id: str) -> bool:
        """Return worker 0's answer to ``has_model(model_id)``."""
        return ray.get(self.workers[0].has_model.remote(model_id))

    def forward_backward(self, prepared_batch: types.PreparedModelPassBatch):
        """Run ``forward_backward`` on every worker and return worker 0's result."""
        # Every worker receives the call; only the first worker's result is returned.
        # Presumably the step is collective across JAX processes — confirm against JaxBackendImpl.
        results = ray.get([w.forward_backward.remote(prepared_batch) for w in self.workers])
        return results[0]

    def forward(self, prepared_batch: types.PreparedModelPassBatch):
        """Run ``forward`` on every worker and return worker 0's result."""
        results = ray.get([w.forward.remote(prepared_batch) for w in self.workers])
        return results[0]

    def optim_step(
        self, model_id: str, request_data: types.OptimStepInput
    ) -> types.OptimStepOutput | types.ErrorResponse:
        """Run ``optim_step`` on every worker and return worker 0's result."""
        results = ray.get([w.optim_step.remote(model_id, request_data) for w in self.workers])
        return results[0]

    def sample(self, prepared_batch: types.PreparedSampleBatch) -> dict[str, types.SampleOutput | types.ErrorResponse]:
        """Run ``sample`` on every worker and return worker 0's result."""
        results = ray.get([w.sample.remote(prepared_batch) for w in self.workers])
        return results[0]

    def load_checkpoint(self, checkpoint_path: AnyPath, model_id: str) -> None:
        """Call ``load_checkpoint`` on every worker and wait for all of them."""
        ray.get([w.load_checkpoint.remote(checkpoint_path, model_id) for w in self.workers])

    def save_checkpoint(self, output_path: AnyPath, model_id: str) -> None:
        """Call ``save_checkpoint`` on every worker and wait for all of them."""
        ray.get([w.save_checkpoint.remote(output_path, model_id) for w in self.workers])

    def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str, persist: bool = True) -> None:
        """Call ``save_sampler_checkpoint`` on every worker and wait for all of them."""
        # NOTE(review): a reviewer reported that JaxBackend.save_sampler_checkpoint writes a
        # "<output_path>.probe" file first so non-coordinator workers can detect a shared
        # filesystem and skip redundant writes. No probe is written here — confirm whether
        # concurrent writes to output_path are safe, and where a probe should be written
        # given that the driver may not share a filesystem with the actors.
        ray.get([w.save_sampler_checkpoint.remote(output_path, model_id, persist) for w in self.workers])
Comment on lines +191 to +192
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Missing probe file write in RayJaxBackend.save_sampler_checkpoint causes concurrent writes on shared filesystems

RayJaxBackend.save_sampler_checkpoint dispatches save_sampler_checkpoint to all Ray actors without first writing a probe file. The existing JaxBackend.save_sampler_checkpoint (skyrl/backends/jax.py:1146-1150) writes a .probe file so that non-coordinator workers can detect a shared filesystem and skip redundant writes (see skyrl/utils/storage.py:29). Without this probe, all Ray actors will concurrently write to the same output_path via pack_and_upload, which performs non-atomic file I/O (skyrl/utils/storage.py:35-39), leading to file corruption on shared filesystems.

Prompt for agents
In RayJaxBackend.save_sampler_checkpoint (skyrl/backends/ray_jax.py:191-192), a probe file is not written before dispatching the save to all Ray actors. The existing JaxBackend (skyrl/backends/jax.py:1146-1150) writes a probe file at output_path.with_name(output_path.name + ".probe") so workers can detect shared filesystems and skip redundant writes (see skyrl/utils/storage.py:28-31 and the pack_and_upload context manager).

The fix should mirror what JaxBackend.save_sampler_checkpoint does: before dispatching to workers, the driver should create the parent directory and write the probe file. Something like:
  output_path.parent.mkdir(parents=True, exist_ok=True)
  output_path.with_name(output_path.name + ".probe").write_text("write_probe")

Note that the RayJaxBackend driver may not have access to the same filesystem as the actors (since it's the Ray driver process), so the correct approach may need to be adapted for the Ray execution model. Consider having the rank-0 actor write the probe before the other actors proceed.
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

15 changes: 11 additions & 4 deletions skyrl/tinker/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,18 @@ def prepare_model_pass_batch(
)


def get_backend_classes(backend_name: str):
def get_backend_classes(backend_name: str, use_ray: bool = False):
"""Lazy import backends to avoid importing unused backend dependencies (e.g., JAX, Ray)."""
if backend_name == "jax":
from skyrl.backends.jax import JaxBackend, JaxBackendConfig
if use_ray:
Comment on lines 160 to +161
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The backend name ray_jax (or ray-jax) should be explicitly supported in the if condition to match the updated error message and the usage shown in the PR description. Currently, the logic only triggers the Ray backend if the name is exactly "jax" and the use_ray flag is set in the configuration, which contradicts the suggested command-line usage.

Suggested change
if backend_name == "jax":
from skyrl.backends.jax import JaxBackend, JaxBackendConfig
if use_ray:
if backend_name in ["jax", "ray_jax", "ray-jax"]:
if use_ray or backend_name != "jax":

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated the error message

from skyrl.backends.jax import JaxBackendConfig
from skyrl.backends.ray_jax import RayJaxBackend

return JaxBackend, JaxBackendConfig
return RayJaxBackend, JaxBackendConfig
else:
from skyrl.backends.jax import JaxBackend, JaxBackendConfig

return JaxBackend, JaxBackendConfig
elif backend_name == "fsdp":
from skyrl.backends.skyrl_train_backend import (
FSDPBackendOverrides,
Expand Down Expand Up @@ -236,7 +242,8 @@ def __init__(
enable_sqlite_wal(self.db_engine)

# Initialize the backend (handles model state, computation, and adapter management)
backend_class, backend_config_class = get_backend_classes(config.backend)
use_ray = config.backend_config.get("use_ray", False)
backend_class, backend_config_class = get_backend_classes(config.backend, use_ray=use_ray)
backend_config = backend_config_class(**config.backend_config)
self.backend = backend_class(config.base_model, backend_config)

Expand Down
Loading