From b6d22146f8a3c58ecb02e2ddb588c753e68891ce Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Fri, 27 Mar 2026 02:14:21 +0000 Subject: [PATCH 01/17] Add initial actor classes and entrypoint script Signed-off-by: Andrew Sy Kim --- skyrl/backends/ray_jax.py | 161 +++++++++++++++++++++++++++++++ skyrl/tinker/api.py | 96 +++++++++--------- skyrl/tinker/config.py | 9 ++ skyrl/tinker/engine.py | 7 +- skyrl/tinker/entrypoints/main.py | 42 ++++++++ skyrl/tinker/ray_actors.py | 28 ++++++ 6 files changed, 298 insertions(+), 45 deletions(-) create mode 100644 skyrl/backends/ray_jax.py create mode 100644 skyrl/tinker/entrypoints/main.py create mode 100644 skyrl/tinker/ray_actors.py diff --git a/skyrl/backends/ray_jax.py b/skyrl/backends/ray_jax.py new file mode 100644 index 0000000000..70fb98205b --- /dev/null +++ b/skyrl/backends/ray_jax.py @@ -0,0 +1,161 @@ +import socket +import ray +from cloudpathlib import AnyPath + +from skyrl.backends.backend import AbstractBackend +from skyrl.backends.jax import JaxBackend, JaxBackendConfig, run_worker +from skyrl.tinker import types +from skyrl.utils.log import logger + +def get_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +@ray.remote +class RayJaxCoordinatorActor: + """Ray Actor wrapper for the JaxBackend coordinator (process_id = 0). + + This actor dynamically allocates a port, provides its coordinator_address + to workers, and then blocks initializing jax.distributed until workers join. + """ + def __init__(self, base_model: str, config: JaxBackendConfig): + self.base_model = base_model + # Use model_copy so we don't accidentally mutate unintended shared state + self.config = config.model_copy() + + # Determine coordinator address + self.node_ip = ray.util.get_node_ip_address() + self.port = get_free_port() + self.coordinator_address = f"{self.node_ip}:{self.port}" + + # Update config with dynamically found address + self.config.coordinator_address = self.coordinator_address + self.backend = None + + def get_coordinator_address(self) -> str: + return self.coordinator_address + + def setup(self): + """Initializes the backend. Blocks until all workers connect to the coordinator.""" + self.backend = JaxBackend(self.base_model, self.config) + + # ========================================================================= + # Proxied Backend Methods + # ========================================================================= + + def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: + self.backend.create_model(model_id, lora_config) + + def forward_backward(self, prepared_batch: types.PreparedModelPassBatch): + return self.backend.forward_backward(prepared_batch) + + def forward(self, prepared_batch: types.PreparedModelPassBatch): + return self.backend.forward(prepared_batch) + + def optim_step(self, model_id: str, request_data: types.OptimStepInput): + return self.backend.optim_step(model_id, request_data) + + def sample(self, prepared_batch: types.PreparedSampleBatch): + return self.backend.sample(prepared_batch) + + def save_checkpoint(self, output_path: AnyPath, model_id: str) -> None: + self.backend.save_checkpoint(output_path, model_id) + + def load_checkpoint(self, checkpoint_path: AnyPath, model_id: str) -> None: + self.backend.load_checkpoint(checkpoint_path, model_id) + + def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str, persist: bool = True) -> None: + self.backend.save_sampler_checkpoint(output_path, model_id, persist) + + def has_model(self, model_id: str) -> bool: + return self.backend.has_model(model_id) + + def delete_model(self, model_id: str) -> None: + self.backend.delete_model(model_id) + + def get_metrics(self) -> types.EngineMetrics: + return self.backend.metrics + + +@ray.remote +class RayJaxWorkerActor: + """Ray Actor wrapper for JaxBackend workers (process_id > 0).""" + def __init__(self, coordinator_address: str, num_processes: int, process_id: int): + self.coordinator_address = coordinator_address + self.num_processes = num_processes + self.process_id = process_id + + def run(self): + """Run the worker loop infinitely.""" + run_worker(self.coordinator_address, self.num_processes, self.process_id) + + +class RayJaxBackend(AbstractBackend): + """Proxy Backend that orchestrates Ray actors for multi-node JAX execution. + + Locally, this class acts like a normal AbstractBackend. Internally, it creates + a RayJaxCoordinatorActor (which internally wraps JaxBackend) to execute work, + and dynamically provisions RayJaxWorkerActors matching `num_processes`. + """ + def __init__(self, base_model: str, config: JaxBackendConfig): + self.base_model = base_model + self.config = config.model_copy() + + num_processes = self.config.num_processes or 1 + self.config.num_processes = num_processes + + logger.info(f"Initializing RayJaxBackend with num_processes={num_processes}") + + # Instantiate the coordinator but do not run setup yet (to avoid blocking) + self.actor = RayJaxCoordinatorActor.remote(self.base_model, self.config) + + # Retrieve dynamically allocated coordinator address from actor + coordinator_address = ray.get(self.actor.get_coordinator_address.remote()) + + self.worker_tasks = [] + if num_processes > 1: + for i in range(1, num_processes): + worker = RayJaxWorkerActor.remote(coordinator_address, num_processes, i) + self.worker_tasks.append(worker.run.remote()) + + # Trigger the coordinator setup, initializing JAX distributed. + # This will block until the workers connect successfully. + ray.get(self.actor.setup.remote()) + + logger.info("RayJaxBackend is fully initialized and distributed cluster is ready.") + + @property + def metrics(self) -> types.EngineMetrics: + return ray.get(self.actor.get_metrics.remote()) + + def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: + ray.get(self.actor.create_model.remote(model_id, lora_config)) + + def delete_model(self, model_id: str) -> None: + ray.get(self.actor.delete_model.remote(model_id)) + + def has_model(self, model_id: str) -> bool: + return ray.get(self.actor.has_model.remote(model_id)) + + def forward_backward(self, prepared_batch: types.PreparedModelPassBatch): + return ray.get(self.actor.forward_backward.remote(prepared_batch)) + + def forward(self, prepared_batch: types.PreparedModelPassBatch): + return ray.get(self.actor.forward.remote(prepared_batch)) + + def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput | types.ErrorResponse: + return ray.get(self.actor.optim_step.remote(model_id, request_data)) + + def sample(self, prepared_batch: types.PreparedSampleBatch) -> dict[str, types.SampleOutput | types.ErrorResponse]: + return ray.get(self.actor.sample.remote(prepared_batch)) + + def load_checkpoint(self, checkpoint_path: AnyPath, model_id: str) -> None: + ray.get(self.actor.load_checkpoint.remote(checkpoint_path, model_id)) + + def save_checkpoint(self, output_path: AnyPath, model_id: str) -> None: + ray.get(self.actor.save_checkpoint.remote(output_path, model_id)) + + def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str, persist: bool = True) -> None: + ray.get(self.actor.save_sampler_checkpoint.remote(output_path, model_id, persist)) diff --git a/skyrl/tinker/api.py b/skyrl/tinker/api.py index f505be77a4..ecd5695236 100644 --- a/skyrl/tinker/api.py +++ b/skyrl/tinker/api.py @@ -120,54 +120,62 @@ async def lifespan(app: FastAPI): logger.info("Using internal engine for inference") # Build subprocess command with engine config parameters. - parent_cmd = psutil.Process(os.getppid()).cmdline() - cmd = _build_uv_run_cmd_engine(parent_cmd, app.state.engine_config) - - background_engine = await asyncio.create_subprocess_exec(*cmd) - app.state.background_engine = background_engine - logger.info(f"Started background engine with PID {background_engine.pid}: {' '.join(cmd)}") - - shutting_down = False - - async def monitor_engine(): - """Monitor engine process and exit API server if it crashes.""" - exit_code = await background_engine.wait() - if not shutting_down: - logger.error(f"Background engine crashed with exit code {exit_code}, exiting API server") - - # Start a background timer that force-exits after timeout. - # Using a thread instead of asyncio task because SIGTERM handling - # may wait for pending asyncio tasks to complete before exiting. - def force_exit(): - logger.warning("Graceful shutdown timed out, forcing exit") - os._exit(1) - - timer = threading.Timer(SHUTDOWN_TIMEOUT_SECONDS, force_exit) - timer.daemon = True - timer.start() - - # Request graceful shutdown. Uvicorn will stop accepting new - # connections and wait for active requests to complete. - # If shutdown doesn't complete in time, force_exit() will terminate. - os.kill(os.getpid(), signal.SIGTERM) - - monitor_task = asyncio.create_task(monitor_engine()) + if not app.state.engine_config.use_ray: + parent_cmd = psutil.Process(os.getppid()).cmdline() + cmd = _build_uv_run_cmd_engine(parent_cmd, app.state.engine_config) + + background_engine = await asyncio.create_subprocess_exec(*cmd) + app.state.background_engine = background_engine + logger.info(f"Started background engine with PID {background_engine.pid}: {' '.join(cmd)}") + + shutting_down = False + + async def monitor_engine(): + """Monitor engine process and exit API server if it crashes.""" + exit_code = await background_engine.wait() + if not shutting_down: + logger.error(f"Background engine crashed with exit code {exit_code}, exiting API server") + + # Start a background timer that force-exits after timeout. + # Using a thread instead of asyncio task because SIGTERM handling + # may wait for pending asyncio tasks to complete before exiting. + def force_exit(): + logger.warning("Graceful shutdown timed out, forcing exit") + os._exit(1) + + timer = threading.Timer(SHUTDOWN_TIMEOUT_SECONDS, force_exit) + timer.daemon = True + timer.start() + + # Request graceful shutdown. Uvicorn will stop accepting new + # connections and wait for active requests to complete. + # If shutdown doesn't complete in time, force_exit() will terminate. + os.kill(os.getpid(), signal.SIGTERM) + + monitor_task = asyncio.create_task(monitor_engine()) + else: + logger.info("Running in Ray orchestrated mode. Background engine will not be started here.") + shutting_down = False + monitor_task = None yield shutting_down = True - monitor_task.cancel() - - logger.info(f"Stopping background engine (PID {app.state.background_engine.pid})") - with suppress(ProcessLookupError): - background_engine.terminate() - try: - await asyncio.wait_for(background_engine.wait(), timeout=5) - except asyncio.TimeoutError: - logger.warning(f"Background engine (PID {background_engine.pid}) did not terminate gracefully, killing") - background_engine.kill() - await background_engine.wait() - logger.info("Background engine stopped") + if monitor_task: + monitor_task.cancel() + + if getattr(app.state, "background_engine", None): + logger.info(f"Stopping background engine (PID {app.state.background_engine.pid})") + with suppress(ProcessLookupError): + background_engine = app.state.background_engine + background_engine.terminate() + try: + await asyncio.wait_for(background_engine.wait(), timeout=5) + except asyncio.TimeoutError: + logger.warning(f"Background engine (PID {background_engine.pid}) did not terminate gracefully, killing") + background_engine.kill() + await background_engine.wait() + logger.info("Background engine stopped") app = FastAPI(title="Tinker API Mock", version="0.0.1", lifespan=lifespan) diff --git a/skyrl/tinker/config.py b/skyrl/tinker/config.py index fff7eab256..4b1168ee71 100644 --- a/skyrl/tinker/config.py +++ b/skyrl/tinker/config.py @@ -14,6 +14,15 @@ class EngineConfig(BaseModel): model_config = ConfigDict(extra="forbid") + use_ray: bool = Field( + default=False, + description="Whether to use Ray for orchestration (Ray Actors for components)", + ) + ray_address: str | None = Field( + default=None, + description="Address of an existing Ray cluster to connect to. If not set, Ray will be initialized locally.", + ) + base_model: str = Field(..., description="Base model name (e.g., Qwen/Qwen3-0.6B)") backend: str = Field(default="jax", description="Backend to use for training and inference") backend_config: dict = Field( diff --git a/skyrl/tinker/engine.py b/skyrl/tinker/engine.py index 6c449eb9f7..b0c4ffb777 100644 --- a/skyrl/tinker/engine.py +++ b/skyrl/tinker/engine.py @@ -175,9 +175,14 @@ def get_backend_classes(backend_name: str): ) return SkyRLTrainBackend, MegatronBackendOverrides + elif backend_name == "ray_jax": + from skyrl.backends.jax import JaxBackendConfig + from skyrl.backends.ray_jax import RayJaxBackend + + return RayJaxBackend, JaxBackendConfig else: raise ValueError( - f"Unknown backend: {backend_name}. Available backends: jax, fsdp, megatron. " + f"Unknown backend: {backend_name}. Available backends: jax, ray_jax, fsdp, megatron. " f"Make sure the backend's dependencies are installed (e.g., pip install skyrl[jax])" ) diff --git a/skyrl/tinker/entrypoints/main.py b/skyrl/tinker/entrypoints/main.py new file mode 100644 index 0000000000..3fd94a1917 --- /dev/null +++ b/skyrl/tinker/entrypoints/main.py @@ -0,0 +1,42 @@ +import argparse +import ray + +from skyrl.tinker.config import EngineConfig, add_model +from skyrl.tinker.ray_actors import TinkerAPIActor, TinkerEngineActor +from skyrl.utils.log import logger + +def main(): + parser = argparse.ArgumentParser(description="SkyRL Tinker Ray Orchestrator") + add_model(parser, EngineConfig) + parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to for API Server") + parser.add_argument("--port", type=int, default=8000, help="Port to bind to for API Server") + args = parser.parse_args() + + # Create EngineConfig from parsed arguments + config = EngineConfig.model_validate({k: v for k, v in vars(args).items() if k in EngineConfig.model_fields}) + + # Force Ray orchestrated mode and ray_jax backend + config.use_ray = True + config.backend = "ray_jax" + + logger.info(f"Initializing Ray with address: {config.ray_address or 'local'}") + ray.init(address=config.ray_address) + + logger.info(f"Starting Tinker API Actor on {args.host}:{args.port}") + api_actor = TinkerAPIActor.remote(config) + api_task = api_actor.run.remote(args.host, args.port) + + logger.info("Starting Tinker Engine Actor") + engine_actor = TinkerEngineActor.remote(config) + engine_task = engine_actor.run.remote() + + logger.info("Ray Orchestrator running. Waiting for actors to complete.") + try: + ray.get([api_task, engine_task]) + except KeyboardInterrupt: + logger.info("Interrupted. Shutting down Ray...") + ray.shutdown() + + +if __name__ == "__main__": + main() diff --git a/skyrl/tinker/ray_actors.py b/skyrl/tinker/ray_actors.py new file mode 100644 index 0000000000..ead654b9df --- /dev/null +++ b/skyrl/tinker/ray_actors.py @@ -0,0 +1,28 @@ +import ray +import uvicorn +from skyrl.tinker.config import EngineConfig +from skyrl.tinker.engine import TinkerEngine +from skyrl.tinker.api import app + +@ray.remote +class TinkerAPIActor: + """Ray Actor wrapper for the Tinker API server (FastAPI + Uvicorn).""" + def __init__(self, config: EngineConfig): + self.config = config + + def run(self, host: str, port: int): + app.state.engine_config = self.config + # Logging config can be customized if needed + from skyrl.utils.log import get_uvicorn_log_config + uvicorn.run(app, host=host, port=port, log_config=get_uvicorn_log_config()) + + +@ray.remote +class TinkerEngineActor: + """Ray Actor wrapper for the Tinker background engine loop.""" + def __init__(self, config: EngineConfig): + self.config = config + + def run(self): + engine = TinkerEngine(self.config) + engine.run() From 8880aa318aea27823d441dbbbca32c219ec0e353 Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Fri, 27 Mar 2026 18:36:51 +0000 Subject: [PATCH 02/17] add consolidated config classes Signed-off-by: Andrew Sy Kim --- pyproject.toml | 1 + skyrl/tinker/config.py | 32 ++++++++++++++++++++++++++++--- skyrl/tinker/entrypoints/main.py | 33 ++++++++++++++++++-------------- 3 files changed, 49 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 543668c926..a2b3f15f93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ tinker = [ "aiosqlite", "asyncpg", "psycopg2-binary", + "ray[default]==2.51.1", ] aws = [ diff --git a/skyrl/tinker/config.py b/skyrl/tinker/config.py index 4b1168ee71..0261e268b9 100644 --- a/skyrl/tinker/config.py +++ b/skyrl/tinker/config.py @@ -8,6 +8,7 @@ from cloudpathlib import AnyPath from pydantic import BaseModel, ConfigDict, Field +from skyrl.backends.jax import JaxBackendConfig class EngineConfig(BaseModel): """Configuration for the Tinker engine.""" @@ -18,8 +19,8 @@ class EngineConfig(BaseModel): default=False, description="Whether to use Ray for orchestration (Ray Actors for components)", ) - ray_address: str | None = Field( - default=None, + ray_address: str = Field( + default="auto", description="Address of an existing Ray cluster to connect to. If not set, Ray will be initialized locally.", ) @@ -87,6 +88,10 @@ def add_model(parser: argparse.ArgumentParser, model: type[BaseModel]) -> None: model: The Pydantic model class """ for name, field in model.model_fields.items(): + if isinstance(field.annotation, type) and issubclass(field.annotation, BaseModel): + add_model(parser, field.annotation) + continue + arg_name = name.replace("_", "-") kwargs = { "help": field.description, @@ -125,6 +130,13 @@ def config_to_argv(cfg: BaseModel) -> list[str]: argv = [] for field_name, value in cfg.model_dump().items(): field = cfg.model_fields[field_name] + + if isinstance(field.annotation, type) and issubclass(field.annotation, BaseModel): + nested_cfg = getattr(cfg, field_name) + if nested_cfg is not None: + argv.extend(config_to_argv(nested_cfg)) + continue + arg_name = field_name.replace("_", "-") if field.annotation is bool: @@ -136,7 +148,21 @@ def config_to_argv(cfg: BaseModel) -> list[str]: argv.append(json.dumps(value)) else: # Skip None values - let them use defaults or environment variables - if value is not None: argv.append(f"--{arg_name}") argv.append(str(value)) return argv + + +class APIConfig(BaseModel): + """Configuration for the Tinker API server.""" + + host: str = Field(default="0.0.0.0", description="Host to bind to for API Server") + port: int = Field(default=8000, description="Port to bind to for API Server") + + +class SkyRLTxConfig(BaseModel): + """Top-level configuration for the SkyRL Tinker orchestration.""" + + api: APIConfig = Field(default_factory=APIConfig, description="API server configuration") + engine: EngineConfig = Field(description="Engine configuration") + backend: JaxBackendConfig = Field(default_factory=JaxBackendConfig, description="JAX backend configuration") diff --git a/skyrl/tinker/entrypoints/main.py b/skyrl/tinker/entrypoints/main.py index 3fd94a1917..cb97293fd5 100644 --- a/skyrl/tinker/entrypoints/main.py +++ b/skyrl/tinker/entrypoints/main.py @@ -1,33 +1,38 @@ import argparse import ray -from skyrl.tinker.config import EngineConfig, add_model +from skyrl.tinker.config import APIConfig, EngineConfig, SkyRLTxConfig, add_model +from skyrl.backends.jax import JaxBackendConfig from skyrl.tinker.ray_actors import TinkerAPIActor, TinkerEngineActor from skyrl.utils.log import logger def main(): + logger.info("Starting entrypoint...") parser = argparse.ArgumentParser(description="SkyRL Tinker Ray Orchestrator") - add_model(parser, EngineConfig) - parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to for API Server") - parser.add_argument("--port", type=int, default=8000, help="Port to bind to for API Server") + add_model(parser, SkyRLTxConfig) args = parser.parse_args() - # Create EngineConfig from parsed arguments - config = EngineConfig.model_validate({k: v for k, v in vars(args).items() if k in EngineConfig.model_fields}) + # Create config from parsed arguments + api_config = APIConfig.model_validate({k: v for k, v in vars(args).items() if k in APIConfig.model_fields}) + engine_config = EngineConfig.model_validate({k: v for k, v in vars(args).items() if k in EngineConfig.model_fields}) + jax_backend_config = JaxBackendConfig.model_validate({k: v for k, v in vars(args).items() if k in JaxBackendConfig.model_fields}) + + engine_config.backend_config = jax_backend_config.model_dump() + config = SkyRLTxConfig(api=api_config, engine=engine_config, jax_backend=jax_backend_config) # Force Ray orchestrated mode and ray_jax backend - config.use_ray = True - config.backend = "ray_jax" + config.engine.use_ray = True + config.engine.backend = "ray_jax" - logger.info(f"Initializing Ray with address: {config.ray_address or 'local'}") - ray.init(address=config.ray_address) + logger.info(f"Initializing Ray with address: {config.engine.ray_address or 'local'}") + ray.init(address=config.engine.ray_address) - logger.info(f"Starting Tinker API Actor on {args.host}:{args.port}") - api_actor = TinkerAPIActor.remote(config) - api_task = api_actor.run.remote(args.host, args.port) + logger.info(f"Starting Tinker API Actor on {config.api.host}:{config.api.port}") + api_actor = TinkerAPIActor.remote(config.engine) + api_task = api_actor.run.remote(config.api.host, config.api.port) logger.info("Starting Tinker Engine Actor") - engine_actor = TinkerEngineActor.remote(config) + engine_actor = TinkerEngineActor.remote(config.engine) engine_task = engine_actor.run.remote() logger.info("Ray Orchestrator running. Waiting for actors to complete.") From 0865d835ae4c5ba8ee9521725b9cf5bc74f0a395 Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Fri, 27 Mar 2026 19:07:52 +0000 Subject: [PATCH 03/17] extract config base types for argparse Signed-off-by: Andrew Sy Kim --- skyrl/tinker/config.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/skyrl/tinker/config.py b/skyrl/tinker/config.py index 0261e268b9..1c77312f32 100644 --- a/skyrl/tinker/config.py +++ b/skyrl/tinker/config.py @@ -3,6 +3,8 @@ import argparse import json import os +import types +import typing from pathlib import Path from cloudpathlib import AnyPath @@ -113,7 +115,13 @@ def add_model(parser: argparse.ArgumentParser, model: type[BaseModel]) -> None: if argparse_type is not None: kwargs["type"] = argparse_type elif field.annotation is not None: - kwargs["type"] = field.annotation + # Extract base type for argparse (unwrap from Optional / Union) + origin = typing.get_origin(field.annotation) + if origin is typing.Union or (hasattr(types, "UnionType") and origin is types.UnionType): + base_type = next((arg for arg in typing.get_args(field.annotation) if arg is not type(None)), str) + kwargs["type"] = base_type + else: + kwargs["type"] = field.annotation if field.is_required(): # Mark as required in argparse if no default is provided From 51037525a19fe7c1eb3dddc1c6681d2d037b0401 Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Fri, 27 Mar 2026 21:39:20 +0000 Subject: [PATCH 04/17] fix cloudpickle and FastAPI error Signed-off-by: Andrew Sy Kim --- skyrl/tinker/ray_actors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl/tinker/ray_actors.py b/skyrl/tinker/ray_actors.py index ead654b9df..90eda38982 100644 --- a/skyrl/tinker/ray_actors.py +++ b/skyrl/tinker/ray_actors.py @@ -1,8 +1,6 @@ import ray import uvicorn from skyrl.tinker.config import EngineConfig -from skyrl.tinker.engine import TinkerEngine -from skyrl.tinker.api import app @ray.remote class TinkerAPIActor: @@ -11,6 +9,7 @@ def __init__(self, config: EngineConfig): self.config = config def run(self, host: str, port: int): + from skyrl.tinker.api import app app.state.engine_config = self.config # Logging config can be customized if needed from skyrl.utils.log import get_uvicorn_log_config @@ -24,5 +23,6 @@ def __init__(self, config: EngineConfig): self.config = config def run(self): + from skyrl.tinker.engine import TinkerEngine engine = TinkerEngine(self.config) engine.run() From f719ce5f5a3d867d80ad565c4e6e75e57f18315d Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Fri, 27 Mar 2026 21:51:47 +0000 Subject: [PATCH 05/17] co-locate engine and API server with STRICT_PACK placement group Signed-off-by: Andrew Sy Kim --- skyrl/tinker/entrypoints/main.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/skyrl/tinker/entrypoints/main.py b/skyrl/tinker/entrypoints/main.py index cb97293fd5..bf00bb4b86 100644 --- a/skyrl/tinker/entrypoints/main.py +++ b/skyrl/tinker/entrypoints/main.py @@ -27,12 +27,23 @@ def main(): logger.info(f"Initializing Ray with address: {config.engine.ray_address or 'local'}") ray.init(address=config.engine.ray_address) + logger.info("Creating STRICT_PACK placement group to colocate API and Engine actors...") + from ray.util.placement_group import placement_group + pg = placement_group([{"CPU": 1}, {"CPU": 1}], strategy="STRICT_PACK") + ray.get(pg.ready()) + logger.info(f"Starting Tinker API Actor on {config.api.host}:{config.api.port}") - api_actor = TinkerAPIActor.remote(config.engine) + api_actor = TinkerAPIActor.options( + placement_group=pg, + placement_group_bundle_index=0 + ).remote(config.engine) api_task = api_actor.run.remote(config.api.host, config.api.port) logger.info("Starting Tinker Engine Actor") - engine_actor = TinkerEngineActor.remote(config.engine) + engine_actor = TinkerEngineActor.options( + placement_group=pg, + placement_group_bundle_index=1 + ).remote(config.engine) engine_task = engine_actor.run.remote() logger.info("Ray Orchestrator running. Waiting for actors to complete.") From 2d755ff0b912cac847eb16ff32c83dc1431af060 Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Mon, 30 Mar 2026 16:57:37 +0000 Subject: [PATCH 06/17] detached actors Signed-off-by: Andrew Sy Kim --- skyrl/tinker/entrypoints/main.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/skyrl/tinker/entrypoints/main.py b/skyrl/tinker/entrypoints/main.py index bf00bb4b86..b90835816d 100644 --- a/skyrl/tinker/entrypoints/main.py +++ b/skyrl/tinker/entrypoints/main.py @@ -32,26 +32,25 @@ def main(): pg = placement_group([{"CPU": 1}, {"CPU": 1}], strategy="STRICT_PACK") ray.get(pg.ready()) - logger.info(f"Starting Tinker API Actor on {config.api.host}:{config.api.port}") + logger.info(f"Starting Tinker API Actor on {config.api.host}:{config.api.port} (Detached)...") api_actor = TinkerAPIActor.options( placement_group=pg, - placement_group_bundle_index=0 + placement_group_bundle_index=0, + name="tinker_api", + lifetime="detached" ).remote(config.engine) - api_task = api_actor.run.remote(config.api.host, config.api.port) + api_actor.run.remote(config.api.host, config.api.port) - logger.info("Starting Tinker Engine Actor") + logger.info("Starting Tinker Engine Actor (Detached)...") engine_actor = TinkerEngineActor.options( placement_group=pg, - placement_group_bundle_index=1 + placement_group_bundle_index=1, + name="tinker_engine", + lifetime="detached" ).remote(config.engine) - engine_task = engine_actor.run.remote() - - logger.info("Ray Orchestrator running. Waiting for actors to complete.") - try: - ray.get([api_task, engine_task]) - except KeyboardInterrupt: - logger.info("Interrupted. Shutting down Ray...") - ray.shutdown() + engine_actor.run.remote() + + logger.info("Ray actors started in detached mode. They will keep running. You can now run your training script.") if __name__ == "__main__": From 097cc04214b416bb8b951d540b9788cc8634de6e Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Mon, 30 Mar 2026 17:24:57 +0000 Subject: [PATCH 07/17] add run_ray_detached_actors function Signed-off-by: Andrew Sy Kim --- skyrl/tinker/entrypoints/main.py | 69 ++++++++++++++++++++------------ skyrl/tinker/ray.py | 61 ++++++++++++++++++++++++++++ skyrl/tinker/ray_actors.py | 28 ------------- 3 files changed, 105 insertions(+), 53 deletions(-) create mode 100644 skyrl/tinker/ray.py delete mode 100644 skyrl/tinker/ray_actors.py diff --git a/skyrl/tinker/entrypoints/main.py b/skyrl/tinker/entrypoints/main.py index b90835816d..1de694b5e3 100644 --- a/skyrl/tinker/entrypoints/main.py +++ b/skyrl/tinker/entrypoints/main.py @@ -1,11 +1,30 @@ +import time import argparse import ray from skyrl.tinker.config import APIConfig, EngineConfig, SkyRLTxConfig, add_model from skyrl.backends.jax import JaxBackendConfig -from skyrl.tinker.ray_actors import TinkerAPIActor, TinkerEngineActor +from skyrl.tinker.ray import run_ray_detached_actors from skyrl.utils.log import logger +import tinker +import numpy as np +from tinker import types + + +def process_example(example, tokenizer): + prompt = f"English: {example['input']}\nPig Latin:" + prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True) + completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False) + + tokens = prompt_tokens + completion_tokens + weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens) + + return types.Datum( + model_input=types.ModelInput.from_ints(tokens=tokens[:-1]), + loss_fn_inputs=dict(weights=weights[1:], target_tokens=tokens[1:]) + ) + def main(): logger.info("Starting entrypoint...") parser = argparse.ArgumentParser(description="SkyRL Tinker Ray Orchestrator") @@ -26,31 +45,31 @@ def main(): logger.info(f"Initializing Ray with address: {config.engine.ray_address or 'local'}") ray.init(address=config.engine.ray_address) + run_ray_detached_actors(config) + + time.sleep(100) + + service_client = tinker.ServiceClient(base_url="http://localhost:8000", api_key="tml-dummy") + training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-0.6B") + tokenizer = training_client.get_tokenizer() + + # Training examples + examples = [ + {"input": "banana split", "output": "anana-bay plit-say"}, + {"input": "quantum physics", "output": "uantum-qay ysics-phay"}, + {"input": "coding wizard", "output": "oding-cay izard-way"}, + ] + + processed = [process_example(ex, tokenizer) for ex in examples] + + # Training loop + for _ in range(6): + fwdbwd = training_client.forward_backward(processed, "cross_entropy").result() + training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result() - logger.info("Creating STRICT_PACK placement group to colocate API and Engine actors...") - from ray.util.placement_group import placement_group - pg = placement_group([{"CPU": 1}, {"CPU": 1}], strategy="STRICT_PACK") - ray.get(pg.ready()) - - logger.info(f"Starting Tinker API Actor on {config.api.host}:{config.api.port} (Detached)...") - api_actor = TinkerAPIActor.options( - placement_group=pg, - placement_group_bundle_index=0, - name="tinker_api", - lifetime="detached" - ).remote(config.engine) - api_actor.run.remote(config.api.host, config.api.port) - - logger.info("Starting Tinker Engine Actor (Detached)...") - engine_actor = TinkerEngineActor.options( - placement_group=pg, - placement_group_bundle_index=1, - name="tinker_engine", - lifetime="detached" - ).remote(config.engine) - engine_actor.run.remote() - - logger.info("Ray actors started in detached mode. They will keep running. You can now run your training script.") + logprobs = np.concatenate([o['logprobs'].tolist() for o in fwdbwd.loss_fn_outputs]) + weights = np.concatenate([e.loss_fn_inputs['weights'].tolist() for e in processed]) + print(f"Loss: {-np.dot(logprobs, weights) / weights.sum():.4f}") if __name__ == "__main__": diff --git a/skyrl/tinker/ray.py b/skyrl/tinker/ray.py new file mode 100644 index 0000000000..e21ea4992b --- /dev/null +++ b/skyrl/tinker/ray.py @@ -0,0 +1,61 @@ +import ray +import uvicorn + +from skyrl.tinker.config import EngineConfig, SkyRLTxConfig +from ray.util.placement_group import placement_group +from skyrl.utils.log import logger + + +def run_ray_detached_actors(config: SkyRLTxConfig): + logger.info("Creating STRICT_PACK placement group to colocate API and Engine actors...") + + pg = placement_group([{"CPU": 1}, {"CPU": 1}], strategy="STRICT_PACK") + ray.get(pg.ready()) + + logger.info(f"Starting Tinker API Actor on {config.api.host}:{config.api.port} (Detached)...") + api_actor = TinkerAPIActor.options( + placement_group=pg, + placement_group_bundle_index=0, + name="tinker_api", + lifetime="detached" + ).remote(config.engine) + api_actor.run.remote(config.api.host, config.api.port) + + logger.info("Starting Tinker Engine Actor (Detached)...") + engine_actor = TinkerEngineActor.options( + placement_group=pg, + placement_group_bundle_index=1, + name="tinker_engine", + lifetime="detached" + ).remote(config.engine) + engine_actor.run.remote() + + logger.info("Ray actors started in detached mode. They will keep running. You can now run your training script.") + + +@ray.remote +class TinkerAPIActor: + """Ray Actor wrapper for the Tinker API server (FastAPI + Uvicorn).""" + def __init__(self, config: EngineConfig): + self.config = config + + def run(self, host: str, port: int): + from skyrl.tinker.api import app + app.state.engine_config = self.config + # Logging config can be customized if needed + from skyrl.utils.log import get_uvicorn_log_config + uvicorn.run(app, host=host, port=port, log_config=get_uvicorn_log_config()) + + + +@ray.remote +class TinkerEngineActor: + """Ray Actor wrapper for the Tinker background engine loop.""" + def __init__(self, config: EngineConfig): + self.config = config + + def run(self): + from skyrl.tinker.engine import TinkerEngine + engine = TinkerEngine(self.config) + engine.run() + diff --git a/skyrl/tinker/ray_actors.py b/skyrl/tinker/ray_actors.py deleted file mode 100644 index 90eda38982..0000000000 --- a/skyrl/tinker/ray_actors.py +++ /dev/null @@ -1,28 +0,0 @@ -import ray -import uvicorn -from skyrl.tinker.config import EngineConfig - -@ray.remote -class TinkerAPIActor: - """Ray Actor wrapper for the Tinker API server (FastAPI + Uvicorn).""" - def __init__(self, config: EngineConfig): - self.config = config - - def run(self, host: str, port: int): - from skyrl.tinker.api import app - app.state.engine_config = self.config - # Logging config can be customized if needed - from skyrl.utils.log import get_uvicorn_log_config - uvicorn.run(app, host=host, port=port, log_config=get_uvicorn_log_config()) - - -@ray.remote -class TinkerEngineActor: - """Ray Actor wrapper for the Tinker background engine loop.""" - def __init__(self, config: EngineConfig): - self.config = config - - def run(self): - from skyrl.tinker.engine import TinkerEngine - engine = TinkerEngine(self.config) - engine.run() From e82a309be5d7e92d22dd0cf508e09c3943a761e6 Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Mon, 30 Mar 2026 17:54:22 +0000 Subject: [PATCH 08/17] get remote tinker address Signed-off-by: Andrew Sy Kim --- skyrl/tinker/entrypoints/main.py | 10 +++++----- skyrl/tinker/ray.py | 11 ++++++++--- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/skyrl/tinker/entrypoints/main.py b/skyrl/tinker/entrypoints/main.py index 1de694b5e3..be18d7628e 100644 --- a/skyrl/tinker/entrypoints/main.py +++ b/skyrl/tinker/entrypoints/main.py @@ -37,6 +37,7 @@ def main(): jax_backend_config = JaxBackendConfig.model_validate({k: v for k, v in vars(args).items() if k in JaxBackendConfig.model_fields}) engine_config.backend_config = jax_backend_config.model_dump() + engine_config.database_url = "sqlite:////tmp/tinker.db" config = SkyRLTxConfig(api=api_config, engine=engine_config, jax_backend=jax_backend_config) # Force Ray orchestrated mode and ray_jax backend @@ -45,11 +46,10 @@ def main(): logger.info(f"Initializing Ray with address: {config.engine.ray_address or 'local'}") ray.init(address=config.engine.ray_address) - run_ray_detached_actors(config) - - time.sleep(100) - - service_client = tinker.ServiceClient(base_url="http://localhost:8000", api_key="tml-dummy") + tinker_address = run_ray_detached_actors(config) + logger.info(f"Tinker address: {tinker_address}") + time.sleep(120) + service_client = tinker.ServiceClient(base_url=f"http://{tinker_address}:8000", api_key="tml-dummy") training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-0.6B") tokenizer = training_client.get_tokenizer() diff --git a/skyrl/tinker/ray.py b/skyrl/tinker/ray.py index e21ea4992b..1934b7118b 100644 --- a/skyrl/tinker/ray.py +++ b/skyrl/tinker/ray.py @@ -9,28 +9,30 @@ def run_ray_detached_actors(config: SkyRLTxConfig): logger.info("Creating STRICT_PACK placement group to colocate API and Engine actors...") - pg = placement_group([{"CPU": 1}, {"CPU": 1}], strategy="STRICT_PACK") + pg = placement_group([{"CPU": 4, "TPU": 8}], strategy="STRICT_PACK") ray.get(pg.ready()) logger.info(f"Starting Tinker API Actor on {config.api.host}:{config.api.port} (Detached)...") api_actor = TinkerAPIActor.options( + num_cpus=1, placement_group=pg, - placement_group_bundle_index=0, name="tinker_api", lifetime="detached" ).remote(config.engine) + address = ray.get(api_actor.get_ip_address.remote()) api_actor.run.remote(config.api.host, config.api.port) logger.info("Starting Tinker Engine Actor (Detached)...") engine_actor = TinkerEngineActor.options( + num_cpus=1, placement_group=pg, - placement_group_bundle_index=1, name="tinker_engine", lifetime="detached" ).remote(config.engine) engine_actor.run.remote() logger.info("Ray actors started in detached mode. They will keep running. You can now run your training script.") + return address @ray.remote @@ -46,6 +48,9 @@ def run(self, host: str, port: int): from skyrl.utils.log import get_uvicorn_log_config uvicorn.run(app, host=host, port=port, log_config=get_uvicorn_log_config()) + def get_ip_address(self): + return ray.util.get_node_ip_address() + @ray.remote From 1abea0b77d8a14f219e8f15730fb3257a7e27dcb Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Mon, 30 Mar 2026 20:32:18 +0000 Subject: [PATCH 09/17] co-locate RayJaxCoordinator with engine for now Signed-off-by: Andrew Sy Kim --- skyrl/backends/ray_jax.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/skyrl/backends/ray_jax.py b/skyrl/backends/ray_jax.py index 70fb98205b..2791d5ea24 100644 --- a/skyrl/backends/ray_jax.py +++ b/skyrl/backends/ray_jax.py @@ -109,7 +109,16 @@ def __init__(self, base_model: str, config: JaxBackendConfig): logger.info(f"Initializing RayJaxBackend with num_processes={num_processes}") # Instantiate the coordinator but do not run setup yet (to avoid blocking) - self.actor = RayJaxCoordinatorActor.remote(self.base_model, self.config) + from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + + current_node_id = ray.get_runtime_context().get_node_id() + logger.info(f"Scheduling RayJaxCoordinatorActor on node {current_node_id} to co-locate with TinkerEngine") + self.actor = RayJaxCoordinatorActor.options( + scheduling_strategy=NodeAffinitySchedulingStrategy( + node_id=current_node_id, + soft=False, + ) + ).remote(self.base_model, self.config) # Retrieve dynamically allocated coordinator address from actor coordinator_address = ray.get(self.actor.get_coordinator_address.remote()) From 0a85a44942a0b942cdef5f77f8f7a1a6b07cce8d Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Tue, 31 Mar 2026 02:04:23 +0000 Subject: [PATCH 10/17] introduce RayJaxBackend to proxy all JAX commands to Ray actors, tinker API and engine now run on driver Signed-off-by: Andrew Sy Kim --- skyrl/backends/jax.py | 11 +++ skyrl/backends/ray_jax.py | 151 +++++++++++++++++-------------- skyrl/tinker/engine.py | 2 +- skyrl/tinker/entrypoints/main.py | 36 ++++++-- 4 files changed, 125 insertions(+), 75 deletions(-) diff --git a/skyrl/backends/jax.py b/skyrl/backends/jax.py index bb9a9de69f..ec1131de16 100644 --- a/skyrl/backends/jax.py +++ b/skyrl/backends/jax.py @@ -20,6 +20,7 @@ uv run -m skyrl.backends.jax --coordinator-address localhost:7777 --num-processes 2 --process-id 1 """ +import json import time from contextlib import contextmanager from dataclasses import dataclass @@ -112,6 +113,16 @@ class JaxBackendConfig(BaseModel, extra="forbid"): default=None, description="Total number of processes in the multi-node cluster", ) + ray_actor_options: dict = Field( + default_factory=dict, + description="Options to pass to Ray actors (e.g., resources, num_cpus)", + json_schema_extra={"argparse_type": json.loads}, + ) + ray_placement_group_bundles: list = Field( + default_factory=list, + description="Bundles for the Ray placement group (e.g., [{'CPU': 1}] * num_processes)", + json_schema_extra={"argparse_type": json.loads}, + ) @jax.tree_util.register_dataclass diff --git a/skyrl/backends/ray_jax.py b/skyrl/backends/ray_jax.py index 2791d5ea24..f062146d6d 100644 --- a/skyrl/backends/ray_jax.py +++ b/skyrl/backends/ray_jax.py @@ -3,7 +3,7 @@ from cloudpathlib import AnyPath from skyrl.backends.backend import AbstractBackend -from skyrl.backends.jax import JaxBackend, JaxBackendConfig, run_worker +from skyrl.backends.jax import JaxBackendConfig, JaxBackendImpl from skyrl.tinker import types from skyrl.utils.log import logger @@ -14,32 +14,42 @@ def get_free_port() -> int: @ray.remote -class RayJaxCoordinatorActor: - """Ray Actor wrapper for the JaxBackend coordinator (process_id = 0). +class RayJaxActor: + """Ray Actor wrapper for JaxBackendImpl. - This actor dynamically allocates a port, provides its coordinator_address - to workers, and then blocks initializing jax.distributed until workers join. + Each actor runs JaxBackendImpl and communicates with other actors + via JAX distributed (NCCL) for data parallel operations. """ - def __init__(self, base_model: str, config: JaxBackendConfig): + def __init__(self, base_model: str, config: JaxBackendConfig, process_id: int): self.base_model = base_model - # Use model_copy so we don't accidentally mutate unintended shared state self.config = config.model_copy() - - # Determine coordinator address - self.node_ip = ray.util.get_node_ip_address() - self.port = get_free_port() - self.coordinator_address = f"{self.node_ip}:{self.port}" - - # Update config with dynamically found address - self.config.coordinator_address = self.coordinator_address + self.process_id = process_id self.backend = None + if process_id == 0: + self.node_ip = ray.util.get_node_ip_address() + self.port = get_free_port() + self.coordinator_address = f"{self.node_ip}:{self.port}" + else: + self.coordinator_address = None + def get_coordinator_address(self) -> str: return self.coordinator_address - def setup(self): - """Initializes the backend. Blocks until all workers connect to the coordinator.""" - self.backend = JaxBackend(self.base_model, self.config) + def setup(self, coordinator_address: str | None = None): + """Initializes JAX distributed and creates JaxBackendImpl.""" + import jax # Import here to avoid issues or ensure it's loaded in the actor + + addr = coordinator_address or self.coordinator_address + logger.info(f"Worker {self.process_id} initializing JAX distributed with coordinator {addr}") + + jax.distributed.initialize( + coordinator_address=addr, + num_processes=self.config.num_processes, + process_id=self.process_id, + ) + self.backend = JaxBackendImpl(self.base_model, self.config, self.process_id) + logger.info(f"Worker {self.process_id} JaxBackendImpl initialized.") # ========================================================================= # Proxied Backend Methods @@ -79,25 +89,11 @@ def get_metrics(self) -> types.EngineMetrics: return self.backend.metrics -@ray.remote -class RayJaxWorkerActor: - """Ray Actor wrapper for JaxBackend workers (process_id > 0).""" - def __init__(self, coordinator_address: str, num_processes: int, process_id: int): - self.coordinator_address = coordinator_address - self.num_processes = num_processes - self.process_id = process_id - - def run(self): - """Run the worker loop infinitely.""" - run_worker(self.coordinator_address, self.num_processes, self.process_id) - - class RayJaxBackend(AbstractBackend): """Proxy Backend that orchestrates Ray actors for multi-node JAX execution. - Locally, this class acts like a normal AbstractBackend. Internally, it creates - a RayJaxCoordinatorActor (which internally wraps JaxBackend) to execute work, - and dynamically provisions RayJaxWorkerActors matching `num_processes`. + This class runs in the driver program (Tinker Engine process) and proxies + commands to all JAX workers running as Ray actors. """ def __init__(self, base_model: str, config: JaxBackendConfig): self.base_model = base_model @@ -108,63 +104,82 @@ def __init__(self, base_model: str, config: JaxBackendConfig): logger.info(f"Initializing RayJaxBackend with num_processes={num_processes}") - # Instantiate the coordinator but do not run setup yet (to avoid blocking) - from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy - - current_node_id = ray.get_runtime_context().get_node_id() - logger.info(f"Scheduling RayJaxCoordinatorActor on node {current_node_id} to co-locate with TinkerEngine") - self.actor = RayJaxCoordinatorActor.options( - scheduling_strategy=NodeAffinitySchedulingStrategy( - node_id=current_node_id, - soft=False, - ) - ).remote(self.base_model, self.config) - - # Retrieve dynamically allocated coordinator address from actor - coordinator_address = ray.get(self.actor.get_coordinator_address.remote()) + # Instantiate a Ray placement group + from ray.util.placement_group import placement_group + logger.info("Instantiating Ray placement group for JAX workers...") + bundles = self.config.ray_placement_group_bundles + if not bundles: + bundles = [{"CPU": 1}] * num_processes + self.pg = placement_group(bundles, strategy="SPREAD") + ray.get(self.pg.ready()) + + self.workers = [] - self.worker_tasks = [] - if num_processes > 1: - for i in range(1, num_processes): - worker = RayJaxWorkerActor.remote(coordinator_address, num_processes, i) - self.worker_tasks.append(worker.run.remote()) - - # Trigger the coordinator setup, initializing JAX distributed. - # This will block until the workers connect successfully. - ray.get(self.actor.setup.remote()) + # Create worker 0 (coordinator) + w0_options = self.config.ray_actor_options.copy() + w0_options.update({ + "placement_group": self.pg, + "placement_group_bundle_index": 0, + }) + w0 = RayJaxActor.options(**w0_options).remote(self.base_model, self.config, 0) + self.workers.append(w0) + + # Retrieve dynamically allocated coordinator address from actor 0 + coordinator_address = ray.get(w0.get_coordinator_address.remote()) + # Create other workers + for i in range(1, num_processes): + wi_options = self.config.ray_actor_options.copy() + wi_options.update({ + "placement_group": self.pg, + "placement_group_bundle_index": i, + }) + w = RayJaxActor.options(**wi_options).remote(self.base_model, self.config, i) + self.workers.append(w) + + # Trigger setup on all workers + # This will block until JAX distributed is initialized on all workers + setup_refs = [w0.setup.remote()] + for w in self.workers[1:]: + setup_refs.append(w.setup.remote(coordinator_address)) + + ray.get(setup_refs) logger.info("RayJaxBackend is fully initialized and distributed cluster is ready.") @property def metrics(self) -> types.EngineMetrics: - return ray.get(self.actor.get_metrics.remote()) + return ray.get(self.workers[0].get_metrics.remote()) def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: - ray.get(self.actor.create_model.remote(model_id, lora_config)) + ray.get([w.create_model.remote(model_id, lora_config) for w in self.workers]) def delete_model(self, model_id: str) -> None: - ray.get(self.actor.delete_model.remote(model_id)) + ray.get([w.delete_model.remote(model_id) for w in self.workers]) def has_model(self, model_id: str) -> bool: - return ray.get(self.actor.has_model.remote(model_id)) + return ray.get(self.workers[0].has_model.remote(model_id)) def forward_backward(self, prepared_batch: types.PreparedModelPassBatch): - return ray.get(self.actor.forward_backward.remote(prepared_batch)) + results = ray.get([w.forward_backward.remote(prepared_batch) for w in self.workers]) + return results[0] def forward(self, prepared_batch: types.PreparedModelPassBatch): - return ray.get(self.actor.forward.remote(prepared_batch)) + results = ray.get([w.forward.remote(prepared_batch) for w in self.workers]) + return results[0] def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput | types.ErrorResponse: - return ray.get(self.actor.optim_step.remote(model_id, request_data)) + results = ray.get([w.optim_step.remote(model_id, request_data) for w in self.workers]) + return results[0] def sample(self, prepared_batch: types.PreparedSampleBatch) -> dict[str, types.SampleOutput | types.ErrorResponse]: - return ray.get(self.actor.sample.remote(prepared_batch)) + results = ray.get([w.sample.remote(prepared_batch) for w in self.workers]) + return results[0] def load_checkpoint(self, checkpoint_path: AnyPath, model_id: str) -> None: - ray.get(self.actor.load_checkpoint.remote(checkpoint_path, model_id)) + ray.get([w.load_checkpoint.remote(checkpoint_path, model_id) for w in self.workers]) def save_checkpoint(self, output_path: AnyPath, model_id: str) -> None: - ray.get(self.actor.save_checkpoint.remote(output_path, model_id)) + ray.get([w.save_checkpoint.remote(output_path, model_id) for w in self.workers]) def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str, persist: bool = True) -> None: - ray.get(self.actor.save_sampler_checkpoint.remote(output_path, model_id, persist)) + ray.get([w.save_sampler_checkpoint.remote(output_path, model_id, persist) for w in self.workers]) diff --git a/skyrl/tinker/engine.py b/skyrl/tinker/engine.py index b0c4ffb777..9514d4cb95 100644 --- a/skyrl/tinker/engine.py +++ b/skyrl/tinker/engine.py @@ -175,7 +175,7 @@ def get_backend_classes(backend_name: str): ) return SkyRLTrainBackend, MegatronBackendOverrides - elif backend_name == "ray_jax": + elif backend_name == "ray-jax": from skyrl.backends.jax import JaxBackendConfig from skyrl.backends.ray_jax import RayJaxBackend diff --git a/skyrl/tinker/entrypoints/main.py b/skyrl/tinker/entrypoints/main.py index be18d7628e..63122ee7c4 100644 --- a/skyrl/tinker/entrypoints/main.py +++ b/skyrl/tinker/entrypoints/main.py @@ -1,11 +1,14 @@ import time import argparse import ray +import threading +import uvicorn from skyrl.tinker.config import APIConfig, EngineConfig, SkyRLTxConfig, add_model from skyrl.backends.jax import JaxBackendConfig -from skyrl.tinker.ray import run_ray_detached_actors from skyrl.utils.log import logger +from skyrl.tinker.api import app +from skyrl.tinker.engine import TinkerEngine import tinker import numpy as np @@ -39,16 +42,37 @@ def main(): engine_config.backend_config = jax_backend_config.model_dump() engine_config.database_url = "sqlite:////tmp/tinker.db" config = SkyRLTxConfig(api=api_config, engine=engine_config, jax_backend=jax_backend_config) - + # Force Ray orchestrated mode and ray_jax backend config.engine.use_ray = True config.engine.backend = "ray_jax" - + logger.info(f"Initializing Ray with address: {config.engine.ray_address or 'local'}") ray.init(address=config.engine.ray_address) - tinker_address = run_ray_detached_actors(config) - logger.info(f"Tinker address: {tinker_address}") - time.sleep(120) + logger.info("Starting Tinker API and Engine in driver process threads...") + + app.state.engine_config = config.engine + + def run_api(): + from skyrl.utils.log import get_uvicorn_log_config + uvicorn.run(app, host=config.api.host, port=config.api.port, log_config=get_uvicorn_log_config()) + + api_thread = threading.Thread(target=run_api, daemon=True) + api_thread.start() + + engine_instance = TinkerEngine(config.engine) + + def run_engine(): + engine_instance.run() + + engine_thread = threading.Thread(target=run_engine, daemon=True) + engine_thread.start() + + tinker_address = "localhost" + logger.info(f"Tinker API and Engine started. API address: {tinker_address}:{config.api.port}") + + logger.info("Waiting for services to initialize...") + time.sleep(30) service_client = tinker.ServiceClient(base_url=f"http://{tinker_address}:8000", api_key="tml-dummy") training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-0.6B") tokenizer = training_client.get_tokenizer() From e5b3778e04b612810a03da9a3ab289f30ef8cdfe Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Tue, 31 Mar 2026 02:19:49 +0000 Subject: [PATCH 11/17] add ray-jax to pyproject.toml Signed-off-by: Andrew Sy Kim --- pyproject.toml | 4 ++++ skyrl/backends/jax.py | 2 +- skyrl/backends/ray_jax.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a2b3f15f93..23f782f760 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,10 @@ tinker = [ "ray[default]==2.51.1", ] +ray-jax = [ + "ray[default]==2.51.1", +] + aws = [ "cloudpathlib[s3]", ] diff --git a/skyrl/backends/jax.py b/skyrl/backends/jax.py index ec1131de16..f91412bf9d 100644 --- a/skyrl/backends/jax.py +++ b/skyrl/backends/jax.py @@ -118,7 +118,7 @@ class JaxBackendConfig(BaseModel, extra="forbid"): description="Options to pass to Ray actors (e.g., resources, num_cpus)", json_schema_extra={"argparse_type": json.loads}, ) - ray_placement_group_bundles: list = Field( + ray_pg_bundles: list = Field( default_factory=list, description="Bundles for the Ray placement group (e.g., [{'CPU': 1}] * num_processes)", json_schema_extra={"argparse_type": json.loads}, diff --git a/skyrl/backends/ray_jax.py b/skyrl/backends/ray_jax.py index f062146d6d..e216ae47cd 100644 --- a/skyrl/backends/ray_jax.py +++ b/skyrl/backends/ray_jax.py @@ -107,7 +107,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig): # Instantiate a Ray placement group from ray.util.placement_group import placement_group logger.info("Instantiating Ray placement group for JAX workers...") - bundles = self.config.ray_placement_group_bundles + bundles = self.config.ray_pg_bundles if not bundles: bundles = [{"CPU": 1}] * num_processes self.pg = placement_group(bundles, strategy="SPREAD") From 6bdb799ea3a97e6b218702483326d04fba099e89 Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Tue, 31 Mar 2026 03:09:40 +0000 Subject: [PATCH 12/17] revert actor definitiosn for TinkerAPI and TinkerEngine, remove entrypoint script Signed-off-by: Andrew Sy Kim --- pyproject.toml | 1 - skyrl/tinker/api.py | 96 ++++++++++++++--------------- skyrl/tinker/config.py | 25 -------- skyrl/tinker/entrypoints/main.py | 100 ------------------------------- skyrl/tinker/ray.py | 66 -------------------- 5 files changed, 44 insertions(+), 244 deletions(-) delete mode 100644 skyrl/tinker/entrypoints/main.py delete mode 100644 skyrl/tinker/ray.py diff --git a/pyproject.toml b/pyproject.toml index 23f782f760..7d12e1ceee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,6 @@ tinker = [ "aiosqlite", "asyncpg", "psycopg2-binary", - "ray[default]==2.51.1", ] ray-jax = [ diff --git a/skyrl/tinker/api.py b/skyrl/tinker/api.py index ecd5695236..f505be77a4 100644 --- a/skyrl/tinker/api.py +++ b/skyrl/tinker/api.py @@ -120,62 +120,54 @@ async def lifespan(app: FastAPI): logger.info("Using internal engine for inference") # Build subprocess command with engine config parameters. - if not app.state.engine_config.use_ray: - parent_cmd = psutil.Process(os.getppid()).cmdline() - cmd = _build_uv_run_cmd_engine(parent_cmd, app.state.engine_config) - - background_engine = await asyncio.create_subprocess_exec(*cmd) - app.state.background_engine = background_engine - logger.info(f"Started background engine with PID {background_engine.pid}: {' '.join(cmd)}") - - shutting_down = False - - async def monitor_engine(): - """Monitor engine process and exit API server if it crashes.""" - exit_code = await background_engine.wait() - if not shutting_down: - logger.error(f"Background engine crashed with exit code {exit_code}, exiting API server") - - # Start a background timer that force-exits after timeout. - # Using a thread instead of asyncio task because SIGTERM handling - # may wait for pending asyncio tasks to complete before exiting. - def force_exit(): - logger.warning("Graceful shutdown timed out, forcing exit") - os._exit(1) - - timer = threading.Timer(SHUTDOWN_TIMEOUT_SECONDS, force_exit) - timer.daemon = True - timer.start() - - # Request graceful shutdown. Uvicorn will stop accepting new - # connections and wait for active requests to complete. - # If shutdown doesn't complete in time, force_exit() will terminate. - os.kill(os.getpid(), signal.SIGTERM) - - monitor_task = asyncio.create_task(monitor_engine()) - else: - logger.info("Running in Ray orchestrated mode. Background engine will not be started here.") - shutting_down = False - monitor_task = None + parent_cmd = psutil.Process(os.getppid()).cmdline() + cmd = _build_uv_run_cmd_engine(parent_cmd, app.state.engine_config) + + background_engine = await asyncio.create_subprocess_exec(*cmd) + app.state.background_engine = background_engine + logger.info(f"Started background engine with PID {background_engine.pid}: {' '.join(cmd)}") + + shutting_down = False + + async def monitor_engine(): + """Monitor engine process and exit API server if it crashes.""" + exit_code = await background_engine.wait() + if not shutting_down: + logger.error(f"Background engine crashed with exit code {exit_code}, exiting API server") + + # Start a background timer that force-exits after timeout. + # Using a thread instead of asyncio task because SIGTERM handling + # may wait for pending asyncio tasks to complete before exiting. + def force_exit(): + logger.warning("Graceful shutdown timed out, forcing exit") + os._exit(1) + + timer = threading.Timer(SHUTDOWN_TIMEOUT_SECONDS, force_exit) + timer.daemon = True + timer.start() + + # Request graceful shutdown. Uvicorn will stop accepting new + # connections and wait for active requests to complete. + # If shutdown doesn't complete in time, force_exit() will terminate. + os.kill(os.getpid(), signal.SIGTERM) + + monitor_task = asyncio.create_task(monitor_engine()) yield shutting_down = True - if monitor_task: - monitor_task.cancel() - - if getattr(app.state, "background_engine", None): - logger.info(f"Stopping background engine (PID {app.state.background_engine.pid})") - with suppress(ProcessLookupError): - background_engine = app.state.background_engine - background_engine.terminate() - try: - await asyncio.wait_for(background_engine.wait(), timeout=5) - except asyncio.TimeoutError: - logger.warning(f"Background engine (PID {background_engine.pid}) did not terminate gracefully, killing") - background_engine.kill() - await background_engine.wait() - logger.info("Background engine stopped") + monitor_task.cancel() + + logger.info(f"Stopping background engine (PID {app.state.background_engine.pid})") + with suppress(ProcessLookupError): + background_engine.terminate() + try: + await asyncio.wait_for(background_engine.wait(), timeout=5) + except asyncio.TimeoutError: + logger.warning(f"Background engine (PID {background_engine.pid}) did not terminate gracefully, killing") + background_engine.kill() + await background_engine.wait() + logger.info("Background engine stopped") app = FastAPI(title="Tinker API Mock", version="0.0.1", lifespan=lifespan) diff --git a/skyrl/tinker/config.py b/skyrl/tinker/config.py index 1c77312f32..8b3b527348 100644 --- a/skyrl/tinker/config.py +++ b/skyrl/tinker/config.py @@ -10,22 +10,12 @@ from cloudpathlib import AnyPath from pydantic import BaseModel, ConfigDict, Field -from skyrl.backends.jax import JaxBackendConfig class EngineConfig(BaseModel): """Configuration for the Tinker engine.""" model_config = ConfigDict(extra="forbid") - use_ray: bool = Field( - default=False, - description="Whether to use Ray for orchestration (Ray Actors for components)", - ) - ray_address: str = Field( - default="auto", - description="Address of an existing Ray cluster to connect to. If not set, Ray will be initialized locally.", - ) - base_model: str = Field(..., description="Base model name (e.g., Qwen/Qwen3-0.6B)") backend: str = Field(default="jax", description="Backend to use for training and inference") backend_config: dict = Field( @@ -159,18 +149,3 @@ def config_to_argv(cfg: BaseModel) -> list[str]: argv.append(f"--{arg_name}") argv.append(str(value)) return argv - - -class APIConfig(BaseModel): - """Configuration for the Tinker API server.""" - - host: str = Field(default="0.0.0.0", description="Host to bind to for API Server") - port: int = Field(default=8000, description="Port to bind to for API Server") - - -class SkyRLTxConfig(BaseModel): - """Top-level configuration for the SkyRL Tinker orchestration.""" - - api: APIConfig = Field(default_factory=APIConfig, description="API server configuration") - engine: EngineConfig = Field(description="Engine configuration") - backend: JaxBackendConfig = Field(default_factory=JaxBackendConfig, description="JAX backend configuration") diff --git a/skyrl/tinker/entrypoints/main.py b/skyrl/tinker/entrypoints/main.py deleted file mode 100644 index 63122ee7c4..0000000000 --- a/skyrl/tinker/entrypoints/main.py +++ /dev/null @@ -1,100 +0,0 @@ -import time -import argparse -import ray -import threading -import uvicorn - -from skyrl.tinker.config import APIConfig, EngineConfig, SkyRLTxConfig, add_model -from skyrl.backends.jax import JaxBackendConfig -from skyrl.utils.log import logger -from skyrl.tinker.api import app -from skyrl.tinker.engine import TinkerEngine - -import tinker -import numpy as np -from tinker import types - - -def process_example(example, tokenizer): - prompt = f"English: {example['input']}\nPig Latin:" - prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True) - completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False) - - tokens = prompt_tokens + completion_tokens - weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens) - - return types.Datum( - model_input=types.ModelInput.from_ints(tokens=tokens[:-1]), - loss_fn_inputs=dict(weights=weights[1:], target_tokens=tokens[1:]) - ) - -def main(): - logger.info("Starting entrypoint...") - parser = argparse.ArgumentParser(description="SkyRL Tinker Ray Orchestrator") - add_model(parser, SkyRLTxConfig) - args = parser.parse_args() - - # Create config from parsed arguments - api_config = APIConfig.model_validate({k: v for k, v in vars(args).items() if k in APIConfig.model_fields}) - engine_config = EngineConfig.model_validate({k: v for k, v in vars(args).items() if k in EngineConfig.model_fields}) - jax_backend_config = JaxBackendConfig.model_validate({k: v for k, v in vars(args).items() if k in JaxBackendConfig.model_fields}) - - engine_config.backend_config = jax_backend_config.model_dump() - engine_config.database_url = "sqlite:////tmp/tinker.db" - config = SkyRLTxConfig(api=api_config, engine=engine_config, jax_backend=jax_backend_config) - - # Force Ray orchestrated mode and ray_jax backend - config.engine.use_ray = True - config.engine.backend = "ray_jax" - - logger.info(f"Initializing Ray with address: {config.engine.ray_address or 'local'}") - ray.init(address=config.engine.ray_address) - logger.info("Starting Tinker API and Engine in driver process threads...") - - app.state.engine_config = config.engine - - def run_api(): - from skyrl.utils.log import get_uvicorn_log_config - uvicorn.run(app, host=config.api.host, port=config.api.port, log_config=get_uvicorn_log_config()) - - api_thread = threading.Thread(target=run_api, daemon=True) - api_thread.start() - - engine_instance = TinkerEngine(config.engine) - - def run_engine(): - engine_instance.run() - - engine_thread = threading.Thread(target=run_engine, daemon=True) - engine_thread.start() - - tinker_address = "localhost" - logger.info(f"Tinker API and Engine started. API address: {tinker_address}:{config.api.port}") - - logger.info("Waiting for services to initialize...") - time.sleep(30) - service_client = tinker.ServiceClient(base_url=f"http://{tinker_address}:8000", api_key="tml-dummy") - training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-0.6B") - tokenizer = training_client.get_tokenizer() - - # Training examples - examples = [ - {"input": "banana split", "output": "anana-bay plit-say"}, - {"input": "quantum physics", "output": "uantum-qay ysics-phay"}, - {"input": "coding wizard", "output": "oding-cay izard-way"}, - ] - - processed = [process_example(ex, tokenizer) for ex in examples] - - # Training loop - for _ in range(6): - fwdbwd = training_client.forward_backward(processed, "cross_entropy").result() - training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result() - - logprobs = np.concatenate([o['logprobs'].tolist() for o in fwdbwd.loss_fn_outputs]) - weights = np.concatenate([e.loss_fn_inputs['weights'].tolist() for e in processed]) - print(f"Loss: {-np.dot(logprobs, weights) / weights.sum():.4f}") - - -if __name__ == "__main__": - main() diff --git a/skyrl/tinker/ray.py b/skyrl/tinker/ray.py deleted file mode 100644 index 1934b7118b..0000000000 --- a/skyrl/tinker/ray.py +++ /dev/null @@ -1,66 +0,0 @@ -import ray -import uvicorn - -from skyrl.tinker.config import EngineConfig, SkyRLTxConfig -from ray.util.placement_group import placement_group -from skyrl.utils.log import logger - - -def run_ray_detached_actors(config: SkyRLTxConfig): - logger.info("Creating STRICT_PACK placement group to colocate API and Engine actors...") - - pg = placement_group([{"CPU": 4, "TPU": 8}], strategy="STRICT_PACK") - ray.get(pg.ready()) - - logger.info(f"Starting Tinker API Actor on {config.api.host}:{config.api.port} (Detached)...") - api_actor = TinkerAPIActor.options( - num_cpus=1, - placement_group=pg, - name="tinker_api", - lifetime="detached" - ).remote(config.engine) - address = ray.get(api_actor.get_ip_address.remote()) - api_actor.run.remote(config.api.host, config.api.port) - - logger.info("Starting Tinker Engine Actor (Detached)...") - engine_actor = TinkerEngineActor.options( - num_cpus=1, - placement_group=pg, - name="tinker_engine", - lifetime="detached" - ).remote(config.engine) - engine_actor.run.remote() - - logger.info("Ray actors started in detached mode. They will keep running. You can now run your training script.") - return address - - -@ray.remote -class TinkerAPIActor: - """Ray Actor wrapper for the Tinker API server (FastAPI + Uvicorn).""" - def __init__(self, config: EngineConfig): - self.config = config - - def run(self, host: str, port: int): - from skyrl.tinker.api import app - app.state.engine_config = self.config - # Logging config can be customized if needed - from skyrl.utils.log import get_uvicorn_log_config - uvicorn.run(app, host=host, port=port, log_config=get_uvicorn_log_config()) - - def get_ip_address(self): - return ray.util.get_node_ip_address() - - - -@ray.remote -class TinkerEngineActor: - """Ray Actor wrapper for the Tinker background engine loop.""" - def __init__(self, config: EngineConfig): - self.config = config - - def run(self): - from skyrl.tinker.engine import TinkerEngine - engine = TinkerEngine(self.config) - engine.run() - From 2748609d9ba2241e4a409fc1e206651db9e32d5e Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Tue, 31 Mar 2026 03:25:00 +0000 Subject: [PATCH 13/17] revert config parsing changes that are no longer needed Signed-off-by: Andrew Sy Kim --- skyrl/tinker/config.py | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/skyrl/tinker/config.py b/skyrl/tinker/config.py index 8b3b527348..fff7eab256 100644 --- a/skyrl/tinker/config.py +++ b/skyrl/tinker/config.py @@ -3,8 +3,6 @@ import argparse import json import os -import types -import typing from pathlib import Path from cloudpathlib import AnyPath @@ -80,10 +78,6 @@ def add_model(parser: argparse.ArgumentParser, model: type[BaseModel]) -> None: model: The Pydantic model class """ for name, field in model.model_fields.items(): - if isinstance(field.annotation, type) and issubclass(field.annotation, BaseModel): - add_model(parser, field.annotation) - continue - arg_name = name.replace("_", "-") kwargs = { "help": field.description, @@ -105,13 +99,7 @@ def add_model(parser: argparse.ArgumentParser, model: type[BaseModel]) -> None: if argparse_type is not None: kwargs["type"] = argparse_type elif field.annotation is not None: - # Extract base type for argparse (unwrap from Optional / Union) - origin = typing.get_origin(field.annotation) - if origin is typing.Union or (hasattr(types, "UnionType") and origin is types.UnionType): - base_type = next((arg for arg in typing.get_args(field.annotation) if arg is not type(None)), str) - kwargs["type"] = base_type - else: - kwargs["type"] = field.annotation + kwargs["type"] = field.annotation if field.is_required(): # Mark as required in argparse if no default is provided @@ -128,13 +116,6 @@ def config_to_argv(cfg: BaseModel) -> list[str]: argv = [] for field_name, value in cfg.model_dump().items(): field = cfg.model_fields[field_name] - - if isinstance(field.annotation, type) and issubclass(field.annotation, BaseModel): - nested_cfg = getattr(cfg, field_name) - if nested_cfg is not None: - argv.extend(config_to_argv(nested_cfg)) - continue - arg_name = field_name.replace("_", "-") if field.annotation is bool: @@ -146,6 +127,7 @@ def config_to_argv(cfg: BaseModel) -> list[str]: argv.append(json.dumps(value)) else: # Skip None values - let them use defaults or environment variables + if value is not None: argv.append(f"--{arg_name}") argv.append(str(value)) return argv From 7a06bba07002c190f9eb1a60048f986f8e83ff8c Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Tue, 31 Mar 2026 03:43:22 +0000 Subject: [PATCH 14/17] improve class and variable naming Signed-off-by: Andrew Sy Kim --- skyrl/backends/jax.py | 1 + skyrl/backends/ray_jax.py | 84 +++++++++++++++++++-------------------- 2 files changed, 42 insertions(+), 43 deletions(-) diff --git a/skyrl/backends/jax.py b/skyrl/backends/jax.py index f91412bf9d..b0311e8c80 100644 --- a/skyrl/backends/jax.py +++ b/skyrl/backends/jax.py @@ -113,6 +113,7 @@ class JaxBackendConfig(BaseModel, extra="forbid"): default=None, description="Total number of processes in the multi-node cluster", ) + # RayJaxBackend configuration ray_actor_options: dict = Field( default_factory=dict, description="Options to pass to Ray actors (e.g., resources, num_cpus)", diff --git a/skyrl/backends/ray_jax.py b/skyrl/backends/ray_jax.py index e216ae47cd..32540d7e28 100644 --- a/skyrl/backends/ray_jax.py +++ b/skyrl/backends/ray_jax.py @@ -1,4 +1,3 @@ -import socket import ray from cloudpathlib import AnyPath @@ -7,19 +6,14 @@ from skyrl.tinker import types from skyrl.utils.log import logger -def get_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - @ray.remote -class RayJaxActor: - """Ray Actor wrapper for JaxBackendImpl. - - Each actor runs JaxBackendImpl and communicates with other actors - via JAX distributed (NCCL) for data parallel operations. +class RayJaxBackendImpl: + """RayJaxBackendImpl is a Ray wrapper for JaxBackendImpl. + + Each actor calls jax.distributed.initialize() and holds an instance of JaxBackendImpl. """ + def __init__(self, base_model: str, config: JaxBackendConfig, process_id: int): self.base_model = base_model self.config = config.model_copy() @@ -28,7 +22,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig, process_id: int): if process_id == 0: self.node_ip = ray.util.get_node_ip_address() - self.port = get_free_port() + self.port = 7777 self.coordinator_address = f"{self.node_ip}:{self.port}" else: self.coordinator_address = None @@ -38,7 +32,7 @@ def get_coordinator_address(self) -> str: def setup(self, coordinator_address: str | None = None): """Initializes JAX distributed and creates JaxBackendImpl.""" - import jax # Import here to avoid issues or ensure it's loaded in the actor + import jax addr = coordinator_address or self.coordinator_address logger.info(f"Worker {self.process_id} initializing JAX distributed with coordinator {addr}") @@ -51,10 +45,6 @@ def setup(self, coordinator_address: str | None = None): self.backend = JaxBackendImpl(self.base_model, self.config, self.process_id) logger.info(f"Worker {self.process_id} JaxBackendImpl initialized.") - # ========================================================================= - # Proxied Backend Methods - # ========================================================================= - def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: self.backend.create_model(model_id, lora_config) @@ -90,23 +80,25 @@ def get_metrics(self) -> types.EngineMetrics: class RayJaxBackend(AbstractBackend): - """Proxy Backend that orchestrates Ray actors for multi-node JAX execution. - - This class runs in the driver program (Tinker Engine process) and proxies + """RayJaxBackend is a proxy Backend that orchestrates Ray actors for multi-node JAX execution. + + This class runs in the driver program (along with Tinker API / Engine) and proxies commands to all JAX workers running as Ray actors. """ + def __init__(self, base_model: str, config: JaxBackendConfig): self.base_model = base_model self.config = config.model_copy() - + num_processes = self.config.num_processes or 1 self.config.num_processes = num_processes - + logger.info(f"Initializing RayJaxBackend with num_processes={num_processes}") - # Instantiate a Ray placement group + # Initialize a Ray placement group based on ray_pg_bundles in JaxBackendConfig from ray.util.placement_group import placement_group - logger.info("Instantiating Ray placement group for JAX workers...") + + logger.info("Creating Ray placement group for JAX backend") bundles = self.config.ray_pg_bundles if not bundles: bundles = [{"CPU": 1}] * num_processes @@ -114,37 +106,41 @@ def __init__(self, base_model: str, config: JaxBackendConfig): ray.get(self.pg.ready()) self.workers = [] - - # Create worker 0 (coordinator) + + # node0 (coordinator) + logger.info("Scheduling Ray actor for node0 (JAX coordinator)") w0_options = self.config.ray_actor_options.copy() - w0_options.update({ - "placement_group": self.pg, - "placement_group_bundle_index": 0, - }) - w0 = RayJaxActor.options(**w0_options).remote(self.base_model, self.config, 0) + w0_options.update( + { + "placement_group": self.pg, + "placement_group_bundle_index": 0, + } + ) + w0 = RayJaxBackendImpl.options(**w0_options).remote(self.base_model, self.config, 0) self.workers.append(w0) - # Retrieve dynamically allocated coordinator address from actor 0 coordinator_address = ray.get(w0.get_coordinator_address.remote()) - - # Create other workers + + # Create remaining node1 - nodeN for multi-node training. + logger.info("Scheduling remaining Ray actors (JAX workers)") for i in range(1, num_processes): wi_options = self.config.ray_actor_options.copy() - wi_options.update({ - "placement_group": self.pg, - "placement_group_bundle_index": i, - }) - w = RayJaxActor.options(**wi_options).remote(self.base_model, self.config, i) + wi_options.update( + { + "placement_group": self.pg, + "placement_group_bundle_index": i, + } + ) + w = RayJaxBackendImpl.options(**wi_options).remote(self.base_model, self.config, i) self.workers.append(w) - # Trigger setup on all workers - # This will block until JAX distributed is initialized on all workers + # This will block until JAX distributed is initialized on all workers. setup_refs = [w0.setup.remote()] for w in self.workers[1:]: setup_refs.append(w.setup.remote(coordinator_address)) ray.get(setup_refs) - logger.info("RayJaxBackend is fully initialized and distributed cluster is ready.") + logger.info("RayJaxBackend is fully initialized and distributed JAX cluster is ready.") @property def metrics(self) -> types.EngineMetrics: @@ -167,7 +163,9 @@ def forward(self, prepared_batch: types.PreparedModelPassBatch): results = ray.get([w.forward.remote(prepared_batch) for w in self.workers]) return results[0] - def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput | types.ErrorResponse: + def optim_step( + self, model_id: str, request_data: types.OptimStepInput + ) -> types.OptimStepOutput | types.ErrorResponse: results = ray.get([w.optim_step.remote(model_id, request_data) for w in self.workers]) return results[0] From c4947065c01bb4ea6786bfadeedf18a13e8ba333 Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Thu, 16 Apr 2026 18:39:10 +0000 Subject: [PATCH 15/17] address comments from Phillip Signed-off-by: Andrew Sy Kim --- pyproject.toml | 2 +- skyrl/backends/jax.py | 5 +++++ skyrl/backends/ray_jax.py | 15 ++++++++++++--- skyrl/tinker/engine.py | 22 ++++++++++++---------- 4 files changed, 30 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7d12e1ceee..a31530f413 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ tinker = [ "psycopg2-binary", ] -ray-jax = [ +ray = [ "ray[default]==2.51.1", ] diff --git a/skyrl/backends/jax.py b/skyrl/backends/jax.py index b0311e8c80..fe4afd9f47 100644 --- a/skyrl/backends/jax.py +++ b/skyrl/backends/jax.py @@ -114,6 +114,11 @@ class JaxBackendConfig(BaseModel, extra="forbid"): description="Total number of processes in the multi-node cluster", ) # RayJaxBackend configuration + use_ray: bool = Field( + default=False, + description="Use Ray to schedule JAX workers", + ) + ray_actor_options: dict = Field( default_factory=dict, description="Options to pass to Ray actors (e.g., resources, num_cpus)", diff --git a/skyrl/backends/ray_jax.py b/skyrl/backends/ray_jax.py index 32540d7e28..52cf97937e 100644 --- a/skyrl/backends/ray_jax.py +++ b/skyrl/backends/ray_jax.py @@ -1,3 +1,5 @@ +import socket + import ray from cloudpathlib import AnyPath @@ -7,6 +9,12 @@ from skyrl.utils.log import logger +def _get_random_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + @ray.remote class RayJaxBackendImpl: """RayJaxBackendImpl is a Ray wrapper for JaxBackendImpl. @@ -22,7 +30,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig, process_id: int): if process_id == 0: self.node_ip = ray.util.get_node_ip_address() - self.port = 7777 + self.port = _get_random_port() self.coordinator_address = f"{self.node_ip}:{self.port}" else: self.coordinator_address = None @@ -90,8 +98,9 @@ def __init__(self, base_model: str, config: JaxBackendConfig): self.base_model = base_model self.config = config.model_copy() - num_processes = self.config.num_processes or 1 - self.config.num_processes = num_processes + if not self.config.num_processes: + raise ValueError("num_processes must be specified and > 0 when using Ray JAX backend") + num_processes = self.config.num_processes logger.info(f"Initializing RayJaxBackend with num_processes={num_processes}") diff --git a/skyrl/tinker/engine.py b/skyrl/tinker/engine.py index 9514d4cb95..458f86eb41 100644 --- a/skyrl/tinker/engine.py +++ b/skyrl/tinker/engine.py @@ -155,12 +155,18 @@ def prepare_model_pass_batch( ) -def get_backend_classes(backend_name: str): +def get_backend_classes(backend_name: str, use_ray: bool = False): """Lazy import backends to avoid importing unused backend dependencies (e.g., JAX, Ray).""" if backend_name == "jax": - from skyrl.backends.jax import JaxBackend, JaxBackendConfig + if use_ray: + from skyrl.backends.jax import JaxBackendConfig + from skyrl.backends.ray_jax import RayJaxBackend - return JaxBackend, JaxBackendConfig + return RayJaxBackend, JaxBackendConfig + else: + from skyrl.backends.jax import JaxBackend, JaxBackendConfig + + return JaxBackend, JaxBackendConfig elif backend_name == "fsdp": from skyrl.backends.skyrl_train_backend import ( FSDPBackendOverrides, @@ -175,14 +181,9 @@ def get_backend_classes(backend_name: str): ) return SkyRLTrainBackend, MegatronBackendOverrides - elif backend_name == "ray-jax": - from skyrl.backends.jax import JaxBackendConfig - from skyrl.backends.ray_jax import RayJaxBackend - - return RayJaxBackend, JaxBackendConfig else: raise ValueError( - f"Unknown backend: {backend_name}. Available backends: jax, ray_jax, fsdp, megatron. " + f"Unknown backend: {backend_name}. Available backends: jax, fsdp, megatron. " f"Make sure the backend's dependencies are installed (e.g., pip install skyrl[jax])" ) @@ -241,7 +242,8 @@ def __init__( enable_sqlite_wal(self.db_engine) # Initialize the backend (handles model state, computation, and adapter management) - backend_class, backend_config_class = get_backend_classes(config.backend) + use_ray = config.backend_config.get("use_ray", False) + backend_class, backend_config_class = get_backend_classes(config.backend, use_ray=use_ray) backend_config = backend_config_class(**config.backend_config) self.backend = backend_class(config.base_model, backend_config) From 25db033ea3467d7df8a63ff6982d91374f866a7d Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Thu, 16 Apr 2026 23:53:10 +0000 Subject: [PATCH 16/17] bump ray version to 2.54.1 Signed-off-by: Andrew Sy Kim --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a31530f413..a73d246202 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ tinker = [ ] ray = [ - "ray[default]==2.51.1", + "ray[default]==2.54.1", ] aws = [ From febcb9cc0cc54d9e8c3b3c01e4fe6950c7af142f Mon Sep 17 00:00:00 2001 From: Andrew Sy Kim Date: Fri, 17 Apr 2026 00:00:22 +0000 Subject: [PATCH 17/17] revert ray versoin Signed-off-by: Andrew Sy Kim --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a73d246202..a31530f413 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ tinker = [ ] ray = [ - "ray[default]==2.54.1", + "ray[default]==2.51.1", ] aws = [