diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index 43a73fc28c..d0ae5af0cc 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -1,7 +1,9 @@ import asyncio import logging import os +import signal import socket +import sys from collections import defaultdict from ctypes import CDLL, POINTER, Structure, c_char_p, c_int, c_ulong, c_void_p from datetime import timedelta @@ -121,6 +123,29 @@ def init_worker_process_group(self): backend="cpu:gloo,cuda:nccl", timeout=timedelta(seconds=SKYRL_WORKER_NCCL_TIMEOUT_IN_S) ) + # Clean teardown on k8s SIGTERM: drain CUDA streams + release NCCL + # communicators before the 25s grace period elapses. + rank = self._rank + + def _sigterm_cleanup(signum, frame): + logger.warning(f"SIGTERM received in worker rank={rank}, cleaning up...") + + if torch.cuda.is_available(): + try: + torch.cuda.synchronize() + except Exception as e: + logger.warning(f"cuda.synchronize() failed: {e}") + + try: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + except Exception as e: + logger.warning(f"destroy_process_group() failed: {e}") + + sys.exit(0) + + signal.signal(signal.SIGTERM, _sigterm_cleanup) + # setup device mesh # TODO: Support TP / PP for additional backends # NOTE (sumanthrh): Device mesh and mesh rank are rank specific attributes. For the current way the strategy is defined,