From 21fd3c5243734ae4c755ef51be58302aa96c9572 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 26 Mar 2026 07:55:48 -0700 Subject: [PATCH 1/9] remove deprecated IFS class -- this was replaced by generate_fractal_points.py --- ScaFFold/datagen/IteratedFunctionSystem.py | 101 --------------------- 1 file changed, 101 deletions(-) delete mode 100644 ScaFFold/datagen/IteratedFunctionSystem.py diff --git a/ScaFFold/datagen/IteratedFunctionSystem.py b/ScaFFold/datagen/IteratedFunctionSystem.py deleted file mode 100644 index aded2f2..0000000 --- a/ScaFFold/datagen/IteratedFunctionSystem.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) 2014-2026, Lawrence Livermore National Security, LLC. -# Produced at the Lawrence Livermore National Laboratory. -# Written by the LBANN Research Team (B. Van Essen, et al.) listed in -# the CONTRIBUTORS file. See the top-level LICENSE file for details. -# -# LLNL-CODE-697807. -# All rights reserved. -# -# This file is part of LBANN: Livermore Big Artificial Neural Network -# Toolkit. For details, see http://software.llnl.gov/LBANN or -# https://github.com/LBANN and https://github.com/LBANN/ScaFFold. 
-# -# SPDX-License-Identifier: (Apache-2.0) - -import numpy as np - - -class ifs_function: - def __init__(self): - self.prev_x, self.prev_y, self.prev_z = 0.0, 0.0, 0.0 - self.function = [] - self.xs, self.ys, self.zs = [], [], [] - self.select_function = [] - self.temp_proba = 0.0 - - def set_param(self, a, b, c, d, e, f, g, h, i, j, k, l, proba, **kwargs): # noqa: E741 - if "weight_a" in kwargs: - a *= kwargs["weight_a"] - if "weight_b" in kwargs: - b *= kwargs["weight_b"] - if "weight_c" in kwargs: - c *= kwargs["weight_c"] - if "weight_d" in kwargs: - d *= kwargs["weight_d"] - if "weight_e" in kwargs: - e *= kwargs["weight_e"] - if "weight_f" in kwargs: - f *= kwargs["weight_f"] - if "weight_g" in kwargs: - g *= kwargs["weight_g"] - if "weight_h" in kwargs: - h *= kwargs["weight_h"] - if "weight_i" in kwargs: - i *= kwargs["weight_i"] - if "weight_j" in kwargs: - j *= kwargs["weight_j"] - if "weight_k" in kwargs: - k *= kwargs["weight_k"] - if "weight_l" in kwargs: - l *= kwargs["weight_l"] # noqa: E741 - temp_function = { - "a": a, - "b": b, - "c": c, - "d": d, - "e": e, - "f": f, - "g": g, - "h": h, - "i": i, - "j": j, - "k": k, - "l": l, - "proba": proba, - } - self.function.append(temp_function) - self.temp_proba += proba - self.select_function.append(self.temp_proba) - - def calculate(self, iteration): - """Recursively calculate coordinates for args.iteration""" - rand = np.random.random(iteration) - select_function = self.select_function - function = self.function - prev_x, prev_y, prev_z = self.prev_x, self.prev_y, self.prev_z - for i in range(iteration - 1): - for j in range(len(select_function)): - if rand[i] <= select_function[j]: - next_x = ( - prev_x * function[j]["a"] - + prev_y * function[j]["b"] - + prev_z * function[j]["c"] - + function[j]["j"] - ) - next_y = ( - prev_x * function[j]["d"] - + prev_y * function[j]["e"] - + prev_z * function[j]["f"] - + function[j]["k"] - ) - next_z = ( - prev_x * function[j]["g"] - + prev_y * function[j]["h"] - + 
prev_z * function[j]["i"] - + function[j]["l"] - ) - break - self.xs.append(next_x), self.ys.append(next_y), self.zs.append(next_z) - prev_x, prev_y, prev_z = next_x, next_y, next_z - point_data = np.array((self.xs, self.ys, self.zs), dtype=float) - return point_data From b545950cf2cf467c0260de3c02fb4964d2631327 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 26 Mar 2026 08:02:04 -0700 Subject: [PATCH 2/9] valueerror should say config_type --- ScaFFold/utils/config_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 08cb481..1a044a5 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -103,5 +103,5 @@ def load_config(file_path: str, config_type: str): return RunConfig(config_dict) else: raise ValueError( - f"Invalid config type specified: {type}. Must be either 'sweep' or 'run'" + f"Invalid config_type specified: {config_type}. Must be either 'sweep' or 'run'" ) From a9749a0d8f9260cb3c41b0dd27dc7148c36c8039 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 26 Mar 2026 08:02:30 -0700 Subject: [PATCH 3/9] remove unnecessary copy --- ScaFFold/utils/data_loading.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ScaFFold/utils/data_loading.py b/ScaFFold/utils/data_loading.py index 725854c..27ce299 100644 --- a/ScaFFold/utils/data_loading.py +++ b/ScaFFold/utils/data_loading.py @@ -91,8 +91,8 @@ def __getitem__(self, idx): mask = self.preprocess(self.mask_values, mask, is_mask=True) return { - "image": torch.as_tensor(img.copy()).float().contiguous(), - "mask": torch.as_tensor(mask.copy()).long().contiguous(), + "image": torch.as_tensor(img).float().contiguous(), + "mask": torch.as_tensor(mask).long().contiguous(), } From 8626b84b793f687edf39ebeafdc93c749969fc75 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 26 Mar 2026 08:03:08 -0700 Subject: [PATCH 4/9] images.to device should be non-blocking ---
ScaFFold/utils/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 29d6807..aa4313d 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -342,7 +342,7 @@ def train(self): device=self.device, dtype=torch.float32, memory_format=torch.channels_last_3d, - non_blocking=False, + non_blocking=True, ) images_dc = DCTensor.distribute(images, ps) From 123c90d8cd8e129277365d995ad496d1e0fd6f69 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 26 Mar 2026 08:46:16 -0700 Subject: [PATCH 5/9] apply optimizer every batch, not every epoch; unscale gradients before clipping --- ScaFFold/utils/trainer.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index aa4313d..d0bdca7 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -131,7 +131,7 @@ def create_dataloaders(self): self.create_sampler() loader_args = dict( - batch_size=self.config.batch_size, num_workers=1, pin_memory=True + batch_size=self.config.batch_size, num_workers=4, pin_memory=True ) self.log.debug( f"dataloader num_workers={loader_args['num_workers']}, os.cpu_count()={os.cpu_count()}, self.world_size={self.world_size} " @@ -434,6 +434,7 @@ def train(self): self.train_loader.sampler.set_epoch(epoch) self.val_loader.sampler.set_epoch(epoch) self.model.train() + self.optimizer.zero_grad(set_to_none=False) estr = ( f"{epoch}" @@ -447,8 +448,6 @@ def train(self): unit="img", disable=True if self.world_rank != 0 else False, ) as pbar: - batch_step = 0 - begin_code_region("batch_loop") for batch in self.train_loader: images, true_masks = batch["image"], batch["mask"] @@ -561,22 +560,20 @@ def train(self): gather_and_print_mem(self.log, "post_backward") begin_code_region("step_and_update") - if batch_step + 1 == len(self.train_loader): - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), 
max_norm=1.0 - ) - self.grad_scaler.step(self.optimizer) - gather_and_print_mem(self.log, "after_optim_step") - - self.grad_scaler.update() - self.optimizer.zero_grad(set_to_none=False) + self.grad_scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), max_norm=1.0 + ) + self.grad_scaler.step(self.optimizer) + gather_and_print_mem(self.log, "after_optim_step") + self.grad_scaler.update() + self.optimizer.zero_grad(set_to_none=False) end_code_region("step_and_update") # Update the loss begin_code_region("update_loss") pbar.update(images_dp.shape[0]) self.global_step += 1 - batch_step += 1 # Stay on GPU epoch_loss += loss.detach() end_code_region("update_loss") From 4a8e7a288b90948c1a5843098d19895aa217b73d Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 26 Mar 2026 08:56:17 -0700 Subject: [PATCH 6/9] shift dataloader preprocessing work into dataset generation for speedup, maintaining support for old datasets --- ScaFFold/datagen/get_dataset.py | 9 +++- ScaFFold/datagen/volumegen.py | 10 +++-- ScaFFold/utils/data_loading.py | 80 ++++++++++++++++++++++++--------- 3 files changed, 73 insertions(+), 26 deletions(-) diff --git a/ScaFFold/datagen/get_dataset.py b/ScaFFold/datagen/get_dataset.py index fc19f8c..2e74abe 100644 --- a/ScaFFold/datagen/get_dataset.py +++ b/ScaFFold/datagen/get_dataset.py @@ -29,7 +29,9 @@ from ScaFFold.datagen import volumegen META_FILENAME = "meta.yaml" +DATASET_FORMAT_VERSION = 2 INCLUDE_KEYS = [ + "dataset_format_version", "n_categories", "n_instances_used_per_fractal", "problem_scale", @@ -116,8 +118,10 @@ def get_dataset( root.mkdir(exist_ok=True) # Get dict of required keys and compute config_id + config_dict = vars(config).copy() + config_dict["dataset_format_version"] = DATASET_FORMAT_VERSION volume_config = _get_required_keys_dict( - config=vars(config), include_keys=INCLUDE_KEYS + config=config_dict, include_keys=INCLUDE_KEYS ) config_id = _hash_volume_config(volume_config) commit = 
_git_commit_short() @@ -136,6 +140,8 @@ def get_dataset( meta = yaml.safe_load(meta_path.read_text()) if meta.get("config_id") != config_id: continue + if meta.get("dataset_format_version", 1) != DATASET_FORMAT_VERSION: + continue if require_commit and meta.get("code_commit") != commit: continue # If we pass the above checks, this dataset can be reused @@ -186,6 +192,7 @@ def get_dataset( # Write to tmp, then move, so readers never see half-written dataset meta = { "config_id": config_id, + "dataset_format_version": DATASET_FORMAT_VERSION, "config_subset": volume_config, "include_keys": INCLUDE_KEYS, "code_commit": commit, diff --git a/ScaFFold/datagen/volumegen.py b/ScaFFold/datagen/volumegen.py index 479e67e..59dae15 100644 --- a/ScaFFold/datagen/volumegen.py +++ b/ScaFFold/datagen/volumegen.py @@ -180,7 +180,7 @@ def main(config: Dict): dtype=np.float32, ) mask = np.full( - (config.vol_size, config.vol_size, config.vol_size), 0, dtype=np.short + (config.vol_size, config.vol_size, config.vol_size), 0, dtype=np.int64 ) global_vol_idx = curr_vol[0] @@ -223,14 +223,18 @@ def main(config: Dict): # Determine destination folder subdir = "validation" if global_vol_idx in val_indices else "training" + volume_to_save = np.ascontiguousarray( + volume.transpose((3, 0, 1, 2)), dtype=np.float32 + ) + mask_to_save = np.ascontiguousarray(mask, dtype=np.int64) vol_file = os.path.join(vol_path, subdir, f"{global_vol_idx}.npy") with open(vol_file, "wb") as f: - np.save(f, volume) + np.save(f, volume_to_save) mask_file = os.path.join(mask_path, subdir, f"{global_vol_idx}_mask.npy") with open(mask_file, "wb") as f: - np.save(f, mask) + np.save(f, mask_to_save) end_time = time.time() total_time = end_time - start_time diff --git a/ScaFFold/utils/data_loading.py b/ScaFFold/utils/data_loading.py index 27ce299..962f01f 100644 --- a/ScaFFold/utils/data_loading.py +++ b/ScaFFold/utils/data_loading.py @@ -20,9 +20,14 @@ import numpy as np import torch from torch.utils.data import Dataset 
+import yaml from ScaFFold.utils.utils import customlog +DATASET_FORMAT_VERSION = 2 +LEGACY_DATASET_FORMAT_VERSION = 1 +META_FILENAME = "meta.yaml" + class BasicDataset(Dataset): def __init__( @@ -31,6 +36,8 @@ def __init__( self.images_dir = Path(images_dir) self.mask_dir = Path(mask_dir) self.mask_suffix = mask_suffix + self.dataset_root = self.images_dir.parents[1] + self.dataset_format_version = self._load_dataset_format_version() self.ids = [ splitext(file)[0] @@ -49,25 +56,54 @@ def __init__( data = pickle.load(data_file) self.mask_values = data["mask_values"] customlog(f"Unique mask values: {self.mask_values}") + customlog(f"Dataset format version: {self.dataset_format_version}") def __len__(self): return len(self.ids) @staticmethod - def preprocess(mask_values, img, is_mask): - if is_mask: - mask = np.zeros((img.shape[0], img.shape[1], img.shape[2]), dtype=np.short) - for i, v in enumerate(mask_values): - if img.ndim == 3: - mask[img == v] = i - else: - mask[(img == v).all(-1)] = i + def _load_numpy_array(path): + with open(path, "rb") as handle: + return np.load(handle) + + def _load_dataset_format_version(self): + meta_path = self.dataset_root / META_FILENAME + if not meta_path.exists(): + return LEGACY_DATASET_FORMAT_VERSION + + try: + with open(meta_path, "r") as meta_file: + meta = yaml.safe_load(meta_file) or {} + except Exception as exc: + customlog( + f"Failed to read dataset metadata from {meta_path}: {exc}. Falling back to legacy loader." 
+ ) + return LEGACY_DATASET_FORMAT_VERSION - return mask + return int(meta.get("dataset_format_version", LEGACY_DATASET_FORMAT_VERSION)) - else: - img = img.transpose((3, 0, 1, 2)) - return img + @staticmethod + def _prepare_legacy_image(img): + return np.ascontiguousarray(img.transpose((3, 0, 1, 2)), dtype=np.float32) + + @staticmethod + def _prepare_legacy_mask(mask_values, mask): + remapped = np.zeros((mask.shape[0], mask.shape[1], mask.shape[2]), dtype=np.int64) + for i, value in enumerate(mask_values): + if mask.ndim == 3: + remapped[mask == value] = i + else: + remapped[(mask == value).all(-1)] = i + + return remapped + + @staticmethod + def _prepare_optimized_image(img): + return np.ascontiguousarray(img, dtype=np.float32) + + @staticmethod + def _prepare_optimized_mask(mask): + return np.ascontiguousarray(mask, dtype=np.int64) def __getitem__(self, idx): name = self.ids[idx] @@ -80,19 +116,19 @@ def __getitem__(self, idx): assert len(mask_file) == 1, ( f"Either no mask or multiple masks found for the ID {name}: {mask_file}" ) - with open(mask_file[0], "rb") as f: - mask = np.load(f) - f.close() - with open(img_file[0], "rb") as f: - img = np.load(f) - f.close() + mask = self._load_numpy_array(mask_file[0]) + img = self._load_numpy_array(img_file[0]) - img = self.preprocess(self.mask_values, img, is_mask=False) - mask = self.preprocess(self.mask_values, mask, is_mask=True) + if self.dataset_format_version >= DATASET_FORMAT_VERSION: + img = self._prepare_optimized_image(img) + mask = self._prepare_optimized_mask(mask) + else: + img = self._prepare_legacy_image(img) + mask = self._prepare_legacy_mask(self.mask_values, mask) return { - "image": torch.as_tensor(img).float().contiguous(), - "mask": torch.as_tensor(mask).long().contiguous(), + "image": torch.from_numpy(img), + "mask": torch.from_numpy(mask), } From 2968cc1b625747cce0805fe8c4c71daa9f0fd264 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 26 Mar 2026 09:00:54 -0700 Subject: [PATCH 7/9] make 
dataloader num_workers user-configurable --- ScaFFold/cli.py | 5 +++++ ScaFFold/configs/benchmark_default.yml | 1 + ScaFFold/utils/config_utils.py | 1 + ScaFFold/utils/trainer.py | 10 ++++++++-- 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 840a7e6..588c0cd 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -140,6 +140,11 @@ def main(): benchmark_parser.add_argument( "--batch-size", type=int, nargs="+", help="Batch sizes for each volume size." ) + benchmark_parser.add_argument( + "--dataloader-num-workers", + type=int, + help="Number of DataLoader worker processes per rank.", + ) benchmark_parser.add_argument( "--optimizer", type=str, diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index fce1042..5fa12e8 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -8,6 +8,7 @@ problem_scale: 6 # Determines dataset resolution and number of unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. +dataloader_num_workers: 4 # Number of DataLoader worker processes per rank. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. num_shards: 2 # DistConv param: number of shards to divide the tensor into. 
It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum shard_dim: 2 # DistConv param: dimension on which to shard diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 1a044a5..aae85b2 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -50,6 +50,7 @@ def __init__(self, config_dict): self.n_instances_used_per_fractal = config_dict["n_instances_used_per_fractal"] self.scale = 1 self.batch_size = config_dict["batch_size"] + self.dataloader_num_workers = config_dict.get("dataloader_num_workers", 4) self.epochs = config_dict["epochs"] self.optimizer = config_dict["optimizer"] self.disable_scheduler = bool(config_dict["disable_scheduler"]) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index d0bdca7..b943243 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -130,11 +130,17 @@ def create_dataloaders(self): self.create_dataset() self.create_sampler() + num_workers = self.config.dataloader_num_workers loader_args = dict( - batch_size=self.config.batch_size, num_workers=4, pin_memory=True + batch_size=self.config.batch_size, + num_workers=num_workers, + pin_memory=True, ) + if num_workers > 0: + loader_args["persistent_workers"] = True + loader_args["prefetch_factor"] = 2 self.log.debug( - f"dataloader num_workers={loader_args['num_workers']}, os.cpu_count()={os.cpu_count()}, self.world_size={self.world_size} " + f"dataloader num_workers={loader_args['num_workers']}, prefetch_factor={loader_args.get('prefetch_factor')}, persistent_workers={loader_args.get('persistent_workers', False)}, os.cpu_count()={os.cpu_count()}, self.world_size={self.world_size} " ) self.train_loader = DataLoader( self.train_set, sampler=self.train_sampler, **loader_args From 81a54b5be6a9cb96efc8e79f19aa5fc35609b286 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 26 Mar 2026 09:06:37 -0700 Subject: [PATCH 8/9] make dataloader 
num_workers user-configurable --- ScaFFold/configs/benchmark_testing.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml index aa97106..3086265 100644 --- a/ScaFFold/configs/benchmark_testing.yml +++ b/ScaFFold/configs/benchmark_testing.yml @@ -8,6 +8,7 @@ problem_scale: 6 # Determines dataset resolution and number of unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. +dataloader_num_workers: 4 # Number of DataLoader worker processes per rank. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum shard_dim: 2 # DistConv param: dimension on which to shard From d6429d3012bed6ea45fc39fdba86f4acf7756608 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 26 Mar 2026 09:20:18 -0700 Subject: [PATCH 9/9] extract warmup to separate method; switch to warming up set number of batches (user configurable) --- ScaFFold/cli.py | 5 + ScaFFold/configs/benchmark_default.yml | 4 +- ScaFFold/configs/benchmark_testing.yml | 2 +- ScaFFold/utils/config_utils.py | 7 +- ScaFFold/utils/trainer.py | 181 ++++++++++++++----------- ScaFFold/worker.py | 6 + 6 files changed, 123 insertions(+), 82 deletions(-) diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 588c0cd..585d0ac 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -140,6 +140,11 @@ def main(): benchmark_parser.add_argument( "--batch-size", type=int, nargs="+", help="Batch sizes for each volume size." 
) + benchmark_parser.add_argument( + "--warmup-batches", + type=int, + help="Number of warmup batches to run per rank before training.", + ) benchmark_parser.add_argument( "--dataloader-num-workers", type=int, diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index 5fa12e8..7e5a26d 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -30,6 +30,6 @@ framework: "torch" # The DL framework to train with. Only valid checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints. loss_freq: 1 # Number of epochs between logging the overall loss. normalize: 1 # Cateogry search normalization parameter -warmup_epochs: 1 # How many warmup epochs before training +warmup_batches: 5 # How many warmup batches per rank to run before training. dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse. -target_dice: 0.95 \ No newline at end of file +target_dice: 0.95 diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml index 3086265..8f3cf0d 100644 --- a/ScaFFold/configs/benchmark_testing.yml +++ b/ScaFFold/configs/benchmark_testing.yml @@ -30,6 +30,6 @@ framework: "torch" # The DL framework to train with. Only valid checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints. loss_freq: 1 # Number of epochs between logging the overall loss. normalize: 1 # Cateogry search normalization parameter -warmup_epochs: 1 # How many warmup epochs before training +warmup_batches: 5 # How many warmup batches per rank to run before training. dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse. 
target_dice: 0.95 diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index aae85b2..74e933a 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -67,7 +67,12 @@ def __init__(self, config_dict): self.loss_freq = config_dict["loss_freq"] self.checkpoint_dir = config_dict["checkpoint_dir"] self.normalize = config_dict["normalize"] - self.warmup_epochs = config_dict["warmup_epochs"] + warmup_batches = config_dict.get("warmup_batches") + if warmup_batches is None: + legacy_warmup_epochs = config_dict.get("warmup_epochs", 0) + warmup_batches = legacy_warmup_epochs if legacy_warmup_epochs > 0 else 1 + self.warmup_batches = warmup_batches + self.warmup_epochs = config_dict.get("warmup_epochs", 0) self.num_shards = config_dict["num_shards"] self.shard_dim = config_dict["shard_dim"] self.dataset_reuse_enforce_commit_id = config_dict[ diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index b943243..32e1fd1 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -323,92 +323,117 @@ def _truncate_stats_file(self, start_epoch): except Exception as e: self.log.warning(f"Failed to truncate stats file: {e}") - def train(self): - """ - Execute model training - """ - + def prepare_training(self): + """Prepare checkpoints, resume state, and output files before training.""" self.cleanup_or_resume() - warmup_epochs = self.config.warmup_epochs - if warmup_epochs > 0: - begin_code_region("warmup") - # Keep BN/Dropout from changing behavior/statistics - self.model.eval() - start_warmup = time.time() - self.log.info(f"Running {warmup_epochs} warmup epoch(s)") - - ps = getattr(self.config, "_parallel_strategy", None) - - for _ in range(warmup_epochs): - for batch in self.train_loader: - images, true_masks = batch["image"], batch["mask"] - - images = images.to( - device=self.device, - dtype=torch.float32, - memory_format=torch.channels_last_3d, - non_blocking=True, - ) - images_dc = 
DCTensor.distribute(images, ps) - - true_masks = true_masks.to( - device=self.device, dtype=torch.long, non_blocking=True - ) + def warmup(self): + """Run warmup iterations before the main training loop.""" + warmup_batches = self.config.warmup_batches + if warmup_batches <= 0: + return - with torch.autocast( - self.device.type if self.device.type != "mps" else "cpu", - enabled=self.config.torch_amp, - ): - # Forward on DCTensor - masks_pred_dc = self.model(images_dc) - - # Convert predictions for loss - if isinstance(ps.num_shards, tuple) and len(ps.num_shards) == 1: - n_shards = ps.num_shards[0] - else: - n_shards = ps.num_shards - if images.size(0) < n_shards: - # For small batches (e.g., N=1 with dc_num_shards=2), replicate outputs - masks_pred = masks_pred_dc.to_replicate() - labels_for_loss = true_masks - else: - # Otherwise, shard labels across batch dim to match to_ddp layout - masks_pred = masks_pred_dc.to_ddp() - dt_labels = distribute_tensor( - true_masks, - device_mesh=ps.device_mesh[ - f"dc{self.config.shard_dim + 2}" - ], - placements=[Shard(0)], - ) - labels_for_loss = dt_labels.to_local() + ps = getattr(self.config, "_parallel_strategy", None) + if ps is None: + raise RuntimeError( + "ParallelStrategy not found in config. Set config._parallel_strategy when wrapping model with DistConvDDP." + ) - CE_loss = self.criterion(masks_pred, labels_for_loss) + if self.config.dist: + self.train_loader.sampler.set_epoch(0) + + # Match the main training path as closely as possible. 
+ self.model.train() + self.optimizer.zero_grad(set_to_none=False) + start_warmup = time.time() + max_batches = min(warmup_batches, len(self.train_loader)) + self.log.info(f"Running {max_batches} warmup batch(es) per rank") + + for batch_idx, batch in enumerate(self.train_loader): + if batch_idx >= max_batches: + break + + images, true_masks = batch["image"], batch["mask"] + + images = images.to( + device=self.device, + dtype=torch.float32, + memory_format=torch.channels_last_3d, + non_blocking=True, + ) + true_masks = true_masks.to( + device=self.device, dtype=torch.long, non_blocking=True + ).contiguous() + + # Replicate batch across dc mesh, shard batch across ddp mesh. + images_dp = distribute_tensor( + images, ps.device_mesh, placements=[Shard(0), Replicate()] + ).to_local() + true_masks_dp = distribute_tensor( + true_masks, + ps.device_mesh, + placements=[Shard(0), Replicate()], + ).to_local() + + with torch.autocast( + self.device.type if self.device.type != "mps" else "cpu", + enabled=self.config.torch_amp, + ): + dcx = DCTensor.distribute(images_dp, ps) + dcy = self.model(dcx) + masks_pred = dcy.to_ddp() + + true_masks_ddp = ( + DTensor.from_local( + true_masks_dp, + device_mesh=ps.device_mesh[f"dc{self.config.shard_dim + 2}"], + placements=[Replicate()], + ) + .redistribute( + device_mesh=ps.device_mesh[f"dc{self.config.shard_dim + 2}"], + placements=[Shard(0)], + ) + .to_local() + ) - # Calculate the train dice loss - masks_pred_softmax = F.softmax(masks_pred, dim=1).float() - true_masks_onehot = ( - F.one_hot(labels_for_loss, self.config.n_categories + 1) - .permute(0, 4, 1, 2, 3) - .float() - ) - train_dice_curr = dice_loss( - masks_pred_softmax, - true_masks_onehot, - multiclass=True, - ) - loss = CE_loss + train_dice_curr + if self.config.n_categories + 1 == 1: + loss = self.criterion( + masks_pred.squeeze(1), true_masks_ddp.float() + ) + loss += dice_loss( + F.sigmoid(masks_pred.squeeze(1)), + true_masks_ddp.float(), + multiclass=False, + ) + else: 
+ CE_loss = self.criterion(masks_pred, true_masks_ddp) + masks_pred_softmax = F.softmax(masks_pred, dim=1).float() + true_masks_onehot = ( + F.one_hot(true_masks_ddp, self.config.n_categories + 1) + .permute(0, 4, 1, 2, 3) + .float() + ) + train_dice_curr = dice_loss( + masks_pred_softmax, + true_masks_onehot, + multiclass=True, + ) + loss = CE_loss + train_dice_curr - # Fine as long as we don't step/update - self.grad_scaler.scale(loss).backward() + # Fine as long as we don't step/update + self.grad_scaler.scale(loss).backward() - # Nuke any accumulated grads so the first real step starts clean - for p in self.model.parameters(): - p.grad = None + # Nuke any accumulated grads so the first real step starts clean + for p in self.model.parameters(): + p.grad = None + if self.config.dist: torch.distributed.barrier() - end_code_region("warmup") - self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s") + self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s") + + def train(self): + """ + Execute model training + """ epoch = 1 dice_score_train = 0 diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 431c2eb..6495214 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -224,6 +224,12 @@ def main(kwargs_dict: dict = {}): ranks_per_node = get_local_size() prof_ctx, TORCH_PERF_LOCAL = get_torch_context(ranks_per_node, rank) with prof_ctx as prof: + begin_code_region("prepare_training") + trainer.prepare_training() + end_code_region("prepare_training") + begin_code_region("warmup") + trainer.warmup() + end_code_region("warmup") begin_code_region("train") trainer.train() end_code_region("train")