-
Notifications
You must be signed in to change notification settings - Fork 331
[tx] Add initial implementation of RayJaxBackend #1418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b6d2214
8880aa3
0865d83
5103752
f719ce5
2d755ff
097cc04
e82a309
1abea0b
0a85a44
e5b3778
6bdb799
2748609
7a06bba
c494706
25db033
febcb9c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,6 +44,10 @@ tinker = [ | |
| "psycopg2-binary", | ||
| ] | ||
|
|
||
| ray = [ | ||
| "ray[default]==2.51.1", | ||
| ] | ||
|
|
||
| aws = [ | ||
| "cloudpathlib[s3]", | ||
| ] | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,192 @@ | ||||||||||||||||||
| import socket | ||||||||||||||||||
|
|
||||||||||||||||||
| import ray | ||||||||||||||||||
| from cloudpathlib import AnyPath | ||||||||||||||||||
|
|
||||||||||||||||||
| from skyrl.backends.backend import AbstractBackend | ||||||||||||||||||
| from skyrl.backends.jax import JaxBackendConfig, JaxBackendImpl | ||||||||||||||||||
| from skyrl.tinker import types | ||||||||||||||||||
| from skyrl.utils.log import logger | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _get_random_port() -> int: | ||||||||||||||||||
| with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: | ||||||||||||||||||
| s.bind(("", 0)) | ||||||||||||||||||
| return s.getsockname()[1] | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| @ray.remote | ||||||||||||||||||
| class RayJaxBackendImpl: | ||||||||||||||||||
| """RayJaxBackendImpl is a Ray wrapper for JaxBackendImpl. | ||||||||||||||||||
|
|
||||||||||||||||||
| Each actor calls jax.distributed.initialize() and holds an instance of JaxBackendImpl. | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| def __init__(self, base_model: str, config: JaxBackendConfig, process_id: int): | ||||||||||||||||||
| self.base_model = base_model | ||||||||||||||||||
| self.config = config.model_copy() | ||||||||||||||||||
| self.process_id = process_id | ||||||||||||||||||
| self.backend = None | ||||||||||||||||||
|
|
||||||||||||||||||
| if process_id == 0: | ||||||||||||||||||
| self.node_ip = ray.util.get_node_ip_address() | ||||||||||||||||||
| self.port = _get_random_port() | ||||||||||||||||||
| self.coordinator_address = f"{self.node_ip}:{self.port}" | ||||||||||||||||||
| else: | ||||||||||||||||||
| self.coordinator_address = None | ||||||||||||||||||
|
|
||||||||||||||||||
| def get_coordinator_address(self) -> str: | ||||||||||||||||||
| return self.coordinator_address | ||||||||||||||||||
|
|
||||||||||||||||||
| def setup(self, coordinator_address: str | None = None): | ||||||||||||||||||
| """Initializes JAX distributed and creates JaxBackendImpl.""" | ||||||||||||||||||
| import jax | ||||||||||||||||||
|
|
||||||||||||||||||
| addr = coordinator_address or self.coordinator_address | ||||||||||||||||||
| logger.info(f"Worker {self.process_id} initializing JAX distributed with coordinator {addr}") | ||||||||||||||||||
|
|
||||||||||||||||||
| jax.distributed.initialize( | ||||||||||||||||||
| coordinator_address=addr, | ||||||||||||||||||
| num_processes=self.config.num_processes, | ||||||||||||||||||
| process_id=self.process_id, | ||||||||||||||||||
| ) | ||||||||||||||||||
| self.backend = JaxBackendImpl(self.base_model, self.config, self.process_id) | ||||||||||||||||||
| logger.info(f"Worker {self.process_id} JaxBackendImpl initialized.") | ||||||||||||||||||
|
|
||||||||||||||||||
| def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: | ||||||||||||||||||
| self.backend.create_model(model_id, lora_config) | ||||||||||||||||||
|
|
||||||||||||||||||
| def forward_backward(self, prepared_batch: types.PreparedModelPassBatch): | ||||||||||||||||||
| return self.backend.forward_backward(prepared_batch) | ||||||||||||||||||
|
|
||||||||||||||||||
| def forward(self, prepared_batch: types.PreparedModelPassBatch): | ||||||||||||||||||
| return self.backend.forward(prepared_batch) | ||||||||||||||||||
|
|
||||||||||||||||||
| def optim_step(self, model_id: str, request_data: types.OptimStepInput): | ||||||||||||||||||
| return self.backend.optim_step(model_id, request_data) | ||||||||||||||||||
|
|
||||||||||||||||||
| def sample(self, prepared_batch: types.PreparedSampleBatch): | ||||||||||||||||||
| return self.backend.sample(prepared_batch) | ||||||||||||||||||
|
|
||||||||||||||||||
| def save_checkpoint(self, output_path: AnyPath, model_id: str) -> None: | ||||||||||||||||||
| self.backend.save_checkpoint(output_path, model_id) | ||||||||||||||||||
|
|
||||||||||||||||||
| def load_checkpoint(self, checkpoint_path: AnyPath, model_id: str) -> None: | ||||||||||||||||||
| self.backend.load_checkpoint(checkpoint_path, model_id) | ||||||||||||||||||
|
|
||||||||||||||||||
| def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str, persist: bool = True) -> None: | ||||||||||||||||||
| self.backend.save_sampler_checkpoint(output_path, model_id, persist) | ||||||||||||||||||
|
|
||||||||||||||||||
| def has_model(self, model_id: str) -> bool: | ||||||||||||||||||
| return self.backend.has_model(model_id) | ||||||||||||||||||
|
|
||||||||||||||||||
| def delete_model(self, model_id: str) -> None: | ||||||||||||||||||
| self.backend.delete_model(model_id) | ||||||||||||||||||
|
|
||||||||||||||||||
| def get_metrics(self) -> types.EngineMetrics: | ||||||||||||||||||
| return self.backend.metrics | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class RayJaxBackend(AbstractBackend): | ||||||||||||||||||
| """RayJaxBackend is a proxy Backend that orchestrates Ray actors for multi-node JAX execution. | ||||||||||||||||||
|
|
||||||||||||||||||
| This class runs in the driver program (along with Tinker API / Engine) and proxies | ||||||||||||||||||
| commands to all JAX workers running as Ray actors. | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| def __init__(self, base_model: str, config: JaxBackendConfig): | ||||||||||||||||||
| self.base_model = base_model | ||||||||||||||||||
| self.config = config.model_copy() | ||||||||||||||||||
|
|
||||||||||||||||||
| if not self.config.num_processes: | ||||||||||||||||||
| raise ValueError("num_processes must be specified and > 0 when using Ray JAX backend") | ||||||||||||||||||
| num_processes = self.config.num_processes | ||||||||||||||||||
|
|
||||||||||||||||||
| logger.info(f"Initializing RayJaxBackend with num_processes={num_processes}") | ||||||||||||||||||
|
|
||||||||||||||||||
| # Initialize a Ray placement group based on ray_pg_bundles in JaxBackendConfig | ||||||||||||||||||
| from ray.util.placement_group import placement_group | ||||||||||||||||||
|
|
||||||||||||||||||
| logger.info("Creating Ray placement group for JAX backend") | ||||||||||||||||||
| bundles = self.config.ray_pg_bundles | ||||||||||||||||||
| if not bundles: | ||||||||||||||||||
| bundles = [{"CPU": 1}] * num_processes | ||||||||||||||||||
|
Comment on lines
+111
to
+113
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When
Suggested change
|
||||||||||||||||||
| self.pg = placement_group(bundles, strategy="SPREAD") | ||||||||||||||||||
| ray.get(self.pg.ready()) | ||||||||||||||||||
|
|
||||||||||||||||||
| self.workers = [] | ||||||||||||||||||
|
|
||||||||||||||||||
| # node0 (coordinator) | ||||||||||||||||||
| logger.info("Scheduling Ray actor for node0 (JAX coordinator)") | ||||||||||||||||||
| w0_options = self.config.ray_actor_options.copy() | ||||||||||||||||||
| w0_options.update( | ||||||||||||||||||
| { | ||||||||||||||||||
| "placement_group": self.pg, | ||||||||||||||||||
| "placement_group_bundle_index": 0, | ||||||||||||||||||
| } | ||||||||||||||||||
| ) | ||||||||||||||||||
| w0 = RayJaxBackendImpl.options(**w0_options).remote(self.base_model, self.config, 0) | ||||||||||||||||||
| self.workers.append(w0) | ||||||||||||||||||
|
|
||||||||||||||||||
| coordinator_address = ray.get(w0.get_coordinator_address.remote()) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Create remaining node1 - nodeN for multi-node training. | ||||||||||||||||||
| logger.info("Scheduling remaining Ray actors (JAX workers)") | ||||||||||||||||||
| for i in range(1, num_processes): | ||||||||||||||||||
| wi_options = self.config.ray_actor_options.copy() | ||||||||||||||||||
| wi_options.update( | ||||||||||||||||||
| { | ||||||||||||||||||
| "placement_group": self.pg, | ||||||||||||||||||
| "placement_group_bundle_index": i, | ||||||||||||||||||
| } | ||||||||||||||||||
| ) | ||||||||||||||||||
| w = RayJaxBackendImpl.options(**wi_options).remote(self.base_model, self.config, i) | ||||||||||||||||||
| self.workers.append(w) | ||||||||||||||||||
|
|
||||||||||||||||||
| # This will block until JAX distributed is initialized on all workers. | ||||||||||||||||||
| setup_refs = [w0.setup.remote()] | ||||||||||||||||||
| for w in self.workers[1:]: | ||||||||||||||||||
| setup_refs.append(w.setup.remote(coordinator_address)) | ||||||||||||||||||
|
|
||||||||||||||||||
| ray.get(setup_refs) | ||||||||||||||||||
| logger.info("RayJaxBackend is fully initialized and distributed JAX cluster is ready.") | ||||||||||||||||||
|
|
||||||||||||||||||
| @property | ||||||||||||||||||
| def metrics(self) -> types.EngineMetrics: | ||||||||||||||||||
| return ray.get(self.workers[0].get_metrics.remote()) | ||||||||||||||||||
|
|
||||||||||||||||||
| def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: | ||||||||||||||||||
| ray.get([w.create_model.remote(model_id, lora_config) for w in self.workers]) | ||||||||||||||||||
|
|
||||||||||||||||||
| def delete_model(self, model_id: str) -> None: | ||||||||||||||||||
| ray.get([w.delete_model.remote(model_id) for w in self.workers]) | ||||||||||||||||||
|
|
||||||||||||||||||
| def has_model(self, model_id: str) -> bool: | ||||||||||||||||||
| return ray.get(self.workers[0].has_model.remote(model_id)) | ||||||||||||||||||
|
|
||||||||||||||||||
| def forward_backward(self, prepared_batch: types.PreparedModelPassBatch): | ||||||||||||||||||
| results = ray.get([w.forward_backward.remote(prepared_batch) for w in self.workers]) | ||||||||||||||||||
|
andrewsykim marked this conversation as resolved.
|
||||||||||||||||||
| return results[0] | ||||||||||||||||||
|
|
||||||||||||||||||
| def forward(self, prepared_batch: types.PreparedModelPassBatch): | ||||||||||||||||||
| results = ray.get([w.forward.remote(prepared_batch) for w in self.workers]) | ||||||||||||||||||
| return results[0] | ||||||||||||||||||
|
|
||||||||||||||||||
| def optim_step( | ||||||||||||||||||
| self, model_id: str, request_data: types.OptimStepInput | ||||||||||||||||||
| ) -> types.OptimStepOutput | types.ErrorResponse: | ||||||||||||||||||
| results = ray.get([w.optim_step.remote(model_id, request_data) for w in self.workers]) | ||||||||||||||||||
| return results[0] | ||||||||||||||||||
|
|
||||||||||||||||||
| def sample(self, prepared_batch: types.PreparedSampleBatch) -> dict[str, types.SampleOutput | types.ErrorResponse]: | ||||||||||||||||||
| results = ray.get([w.sample.remote(prepared_batch) for w in self.workers]) | ||||||||||||||||||
| return results[0] | ||||||||||||||||||
|
|
||||||||||||||||||
| def load_checkpoint(self, checkpoint_path: AnyPath, model_id: str) -> None: | ||||||||||||||||||
| ray.get([w.load_checkpoint.remote(checkpoint_path, model_id) for w in self.workers]) | ||||||||||||||||||
|
|
||||||||||||||||||
| def save_checkpoint(self, output_path: AnyPath, model_id: str) -> None: | ||||||||||||||||||
| ray.get([w.save_checkpoint.remote(output_path, model_id) for w in self.workers]) | ||||||||||||||||||
|
|
||||||||||||||||||
| def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str, persist: bool = True) -> None: | ||||||||||||||||||
| ray.get([w.save_sampler_checkpoint.remote(output_path, model_id, persist) for w in self.workers]) | ||||||||||||||||||
|
Comment on lines
+191
to
+192
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔴 Missing probe file write in RayJaxBackend.save_sampler_checkpoint causes concurrent writes on shared filesystems
Prompt for agentsWas this helpful? React with 👍 or 👎 to provide feedback. |
||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -155,12 +155,18 @@ def prepare_model_pass_batch( | |||||||||||
| ) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def get_backend_classes(backend_name: str): | ||||||||||||
| def get_backend_classes(backend_name: str, use_ray: bool = False): | ||||||||||||
| """Lazy import backends to avoid importing unused backend dependencies (e.g., JAX, Ray).""" | ||||||||||||
| if backend_name == "jax": | ||||||||||||
| from skyrl.backends.jax import JaxBackend, JaxBackendConfig | ||||||||||||
| if use_ray: | ||||||||||||
|
Comment on lines
160
to
+161
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The backend name
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated the error message |
||||||||||||
| from skyrl.backends.jax import JaxBackendConfig | ||||||||||||
| from skyrl.backends.ray_jax import RayJaxBackend | ||||||||||||
|
|
||||||||||||
| return JaxBackend, JaxBackendConfig | ||||||||||||
| return RayJaxBackend, JaxBackendConfig | ||||||||||||
| else: | ||||||||||||
| from skyrl.backends.jax import JaxBackend, JaxBackendConfig | ||||||||||||
|
|
||||||||||||
| return JaxBackend, JaxBackendConfig | ||||||||||||
| elif backend_name == "fsdp": | ||||||||||||
| from skyrl.backends.skyrl_train_backend import ( | ||||||||||||
| FSDPBackendOverrides, | ||||||||||||
|
|
@@ -236,7 +242,8 @@ def __init__( | |||||||||||
| enable_sqlite_wal(self.db_engine) | ||||||||||||
|
|
||||||||||||
| # Initialize the backend (handles model state, computation, and adapter management) | ||||||||||||
| backend_class, backend_config_class = get_backend_classes(config.backend) | ||||||||||||
| use_ray = config.backend_config.get("use_ray", False) | ||||||||||||
| backend_class, backend_config_class = get_backend_classes(config.backend, use_ray=use_ray) | ||||||||||||
| backend_config = backend_config_class(**config.backend_config) | ||||||||||||
| self.backend = backend_class(config.base_model, backend_config) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might cause a race condition if multiple processes do something like this at the same time (get a port, then release the port and re-use the port number). It might be more robust to either try a few random ports from a range and return a free one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think the race condition will happen because
_get_random_portis only called by index 0 JAX worker during initialization.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, happy to merge it! If there is any other process on the machine that allocates ports, it can happen, but we can also fix it if it actually happens :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm good point, I didn't consider that. I can open a follow-up PR for this :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Opened #1652