From 4995d8f299e8a128d1a1db4699d3bdbf97b7190e Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Tue, 10 Mar 2026 11:05:16 -0700 Subject: [PATCH 01/25] update config with 3D num_shards and shard_dim --- ScaFFold/configs/benchmark_default.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index fce1042..df8f28f 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -4,13 +4,13 @@ dataset_dir: "datasets" # Directory in which to store and query for d fract_base_dir: "fractals" # Base directory for fractal IFS and instances. n_categories: 5 # Number of fractal categories present in the dataset. n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances. -problem_scale: 6 # Determines dataset resolution and number of unet layers. Default is 6. +problem_scale: 7 # Determines dataset resolution and number of unet layers. Default is 6. unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. -num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum -shard_dim: 2 # DistConv param: dimension on which to shard +num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum +shard_dim: [2, 3, 4] # DistConv param: dimension on which to shard checkpoint_interval: 10 # Checkpoint every C epochs. 
More frequent checkpointing can be very expensive on slow filesystems. # Internal/dev use only From 62f13238bbfb83a99e774c3d5a64fbb78b521ccb Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Tue, 10 Mar 2026 11:06:02 -0700 Subject: [PATCH 02/25] update config util to expect 3D num_shards and shard_dim, add helper to deal with 1D inputs --- ScaFFold/utils/config_utils.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 08cb481..b66659e 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -67,14 +67,22 @@ def __init__(self, config_dict): self.checkpoint_dir = config_dict["checkpoint_dir"] self.normalize = config_dict["normalize"] self.warmup_epochs = config_dict["warmup_epochs"] - self.num_shards = config_dict["num_shards"] - self.shard_dim = config_dict["shard_dim"] self.dataset_reuse_enforce_commit_id = config_dict[ "dataset_reuse_enforce_commit_id" ] self.target_dice = config_dict["target_dice"] self.checkpoint_interval = config_dict["checkpoint_interval"] + self.dc_num_shards = _ensure_tuple(config_dict.get("num_shards", 1)) + self.dc_shard_dims = _ensure_tuple(config_dict.get("shard_dim", 2)) + self.dc_total_shards = math.prod(self.dc_num_shards) + # Safety Check: Length mismatch + if len(self.dc_num_shards) != len(self.dc_shard_dims): + raise ValueError( + f"Configuration Mismatch: num_shards {self.dc_num_shards} " + f"must have same length as shard_dim {self.dc_shard_dims}" + ) + class RunConfig(Config): def __init__(self, config_dict): @@ -105,3 +113,19 @@ def load_config(file_path: str, config_type: str): raise ValueError( f"Invalid config type specified: {type}. Must be either 'sweep' or 'run'" ) + +def _ensure_tuple(val): + """ + Ensures the input value is converted to a tuple of integers. + Handles: int, list, tuple, and string representations like "[2,2]" or "2,2". 
+ """ + if val is None: + return (1,) # Default safety + if isinstance(val, (list, tuple)): + return tuple(int(i) for i in val) + if isinstance(val, str): + # Handle cases where user might type literal "(2, 2, 2)" in YAML or "2,2" in CLI + val = val.strip("()[]").split(",") + return tuple(int(i.strip()) for i in val if i.strip()) + # Fallback for single integer + return (int(val),) \ No newline at end of file From 85acd07d6350045a12bb9e1536d9d683a2564f32 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Tue, 10 Mar 2026 11:06:53 -0700 Subject: [PATCH 03/25] worker no longer needs to modify distconv params set in config -- just pass them as-is to the ParallelStrategy call --- ScaFFold/worker.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 431c2eb..d007970 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -180,13 +180,13 @@ def main(kwargs_dict: dict = {}): assert dist.get_world_size() % config.dc_num_shards == 0, ( f"world_size={dist.get_world_size()} must be divisible by dc_num_shards={config.dc_num_shards}" ) - # Select which full-tensor dim to shard: 2 + dc_shard_dim - shard_dim = 2 + int(config.dc_shard_dim) + ps = ParallelStrategy( - num_shards=int(config.dc_num_shards), - shard_dim=shard_dim, + num_shards=config.dc_num_shards, + shard_dim=config.dc_shard_dims, device_type=device.type, ) + model = model.to(device, memory_format=torch.channels_last_3d) # Wrap with DistConvDDP that corrects gradient scaling for dc submesh model = DistConvDDP( From 017377a1388d3f873f0ef3d5819530207e617c29 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Tue, 10 Mar 2026 15:09:53 -0700 Subject: [PATCH 04/25] implement multi-dimensional sharding for distconv --- ScaFFold/utils/trainer.py | 338 ++++++++++++++++++++++++-------------- 1 file changed, 214 insertions(+), 124 deletions(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 29d6807..cdb6526 100644 --- 
a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -44,6 +44,57 @@ from ScaFFold.utils.utils import gather_and_print_mem +class SpatialAllReduce(torch.autograd.Function): + """ + Custom Autograd function to sum partial loss components sequentially across + an N-dimensional spatial device mesh. + """ + @staticmethod + def forward(ctx, input, spatial_mesh): + output = input.clone() + + # Sequentially reduce across every spatial dimension in the mesh + for mesh_dim in range(spatial_mesh.ndim): + # Get the 1D ProcessGroup for this specific dimension + pg = spatial_mesh.get_group(mesh_dim) + # Sum the tensor across this dimension + dist.all_reduce(output, op=dist.ReduceOp.SUM, group=pg) + + return output + + @staticmethod + def backward(ctx, grad_output): + # Pass the identical gradient straight through. Zero communication. + return grad_output, None + +def sharded_dice_loss( + input: torch.Tensor, + target: torch.Tensor, + spatial_mesh, # <-- Update type signature here + epsilon: float = 1e-6, +): + assert input.size() == target.size(), f"Shape mismatch: {input.size()} vs {target.size()}" + assert input.dim() == 5, f"Expected 5D tensor, got {input.dim()}D" + + sum_dim = (-1, -2, -3) + + local_inter = 2.0 * (input * target).sum(dim=sum_dim) + local_sets_sum_raw = input.sum(dim=sum_dim) + target.sum(dim=sum_dim) + + packed = torch.stack([local_inter, local_sets_sum_raw]) + + # Pass the entire mesh into our updated custom autograd function + packed_global = SpatialAllReduce.apply(packed, spatial_mesh) + + global_inter = packed_global[0] + global_sets_sum_raw = packed_global[1] + + global_sets_sum = torch.where(global_sets_sum_raw == 0, global_inter, global_sets_sum_raw) + dice_score = (global_inter + epsilon) / (global_sets_sum + epsilon) + + return 1.0 - dice_score.mean() + + class BaseTrainer: """ A class that encapsulates some basic functionality for training our model. 
@@ -317,6 +368,15 @@ def _truncate_stats_file(self, start_epoch): except Exception as e: self.log.warning(f"Failed to truncate stats file: {e}") + def _get_memsize(self, tensor, tensor_label: str, verbosity: int = 0): + """Log size of tensor in memory""" + + if verbosity < 2: + return + tensor_memory_bytes = tensor[0].element_size() * tensor[0].nelement() + tensor_memory_gb = tensor_memory_bytes / (1024 ** 3) + self.log.info(f"{tensor_label} size on GPU: {tensor_memory_gb:.2f} GB") + def train(self): """ Execute model training @@ -324,78 +384,131 @@ def train(self): self.cleanup_or_resume() + # DistConv ParallelStrategy + ps = getattr(self.config, "_parallel_strategy", None) + if ps is None: + raise RuntimeError( + "ParallelStrategy not found in config. Set config._parallel_strategy when wrapping model with DistConvDDP." + ) + warmup_epochs = self.config.warmup_epochs if warmup_epochs > 0: begin_code_region("warmup") # Keep BN/Dropout from changing behavior/statistics - self.model.eval() + self.model.train() start_warmup = time.time() self.log.info(f"Running {warmup_epochs} warmup epoch(s)") - ps = getattr(self.config, "_parallel_strategy", None) + # Get the process group for spatial sharding mesh + spatial_mesh = ps.device_mesh[ps.distconv_dim_names] + + # Get placements for DDP sharding + num_spatial_dims = len(ps.shard_dim) + ddp_placements = [Shard(0)] + [Replicate()] * num_spatial_dims for _ in range(warmup_epochs): for batch in self.train_loader: + # Load initial samples and labels images, true_masks = batch["image"], batch["mask"] + # Move samples and labels to GPU images = images.to( device=self.device, dtype=torch.float32, memory_format=torch.channels_last_3d, - non_blocking=False, + non_blocking=True ) - images_dc = DCTensor.distribute(images, ps) - + self._get_memsize(images, "Original image", config.verbosity) true_masks = true_masks.to( - device=self.device, dtype=torch.long, non_blocking=True + device=self.device, + dtype=torch.long, + 
non_blocking=True ) + self._get_memsize(images, "Original label", config.verbosity) + + # Add a dummy channel dimension to get 5D [B, 1, D, H, W] + true_masks = true_masks.unsqueeze(1) + + # Data parallel sharding + images_dp = DTensor.from_local( + images, ps.device_mesh, placements=ddp_placements + ).to_local() + + true_masks_dp = DTensor.from_local( + true_masks, ps.device_mesh, placements=ddp_placements + ).to_local() + + # Delete source tensors immediately after use to keep memory down + del images, true_masks + + # Spatial sharding via DistConv + images_dc = DCTensor.distribute(images_dp, ps) + true_masks_dc = DCTensor.distribute(true_masks_dp, ps) + self._get_memsize(images_dc, "Sharded image", config.verbosity) with torch.autocast( self.device.type if self.device.type != "mps" else "cpu", enabled=self.config.torch_amp, ): + # Forward on DCTensor masks_pred_dc = self.model(images_dc) - - # Convert predictions for loss - if isinstance(ps.num_shards, tuple) and len(ps.num_shards) == 1: - n_shards = ps.num_shards[0] - else: - n_shards = ps.num_shards - if images.size(0) < n_shards: - # For small batches (e.g., N=1 with dc_num_shards=2), replicate outputs - masks_pred = masks_pred_dc.to_replicate() - labels_for_loss = true_masks - else: - # Otherwise, shard labels across batch dim to match to_ddp layout - masks_pred = masks_pred_dc.to_ddp() - dt_labels = distribute_tensor( - true_masks, - device_mesh=ps.device_mesh[ - f"dc{self.config.shard_dim + 2}" - ], - placements=[Shard(0)], - ) - labels_for_loss = dt_labels.to_local() - - CE_loss = self.criterion(masks_pred, labels_for_loss) - - # Calculate the train dice loss - masks_pred_softmax = F.softmax(masks_pred, dim=1).float() - true_masks_onehot = ( - F.one_hot(labels_for_loss, self.config.n_categories + 1) - .permute(0, 4, 1, 2, 3) - .float() - ) - train_dice_curr = dice_loss( - masks_pred_softmax, - true_masks_onehot, - multiclass=True, - ) - loss = CE_loss + train_dice_curr - - # Fine as long as we don't 
step/update - self.grad_scaler.scale(loss).backward() + self.log.debug(f"forward pass complete") + + # Extract the underlying PyTorch local tensors + local_preds = masks_pred_dc + local_labels_5d = true_masks_dc + + # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] + local_labels = local_labels_5d.squeeze(1) + if self.world_rank == 0: + self.log.debug(f"Local Preds Shape: {local_preds.shape}") + # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 + self.log.debug(f"Local Labels Shape: {local_labels.shape}") + # Should be something like [1, 128, 128, 64] + + # --- SHARDED LOSS CALCULATION --- + current_mem = torch.cuda.memory_allocated() / (1024 ** 3) + self.log.debug(f"Calculating sharded loss. Mem: {current_mem:.2f} GB.") + + # 1. Sharded Cross Entropy + local_ce_sum = F.cross_entropy(local_preds, local_labels, reduction='sum') + + # Pass the spatial_mesh directly + global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) + + global_total_voxels = local_labels.numel() * math.prod(self.config.dc_num_shards) + loss_ce = global_ce_sum / global_total_voxels + + # 2. Sharded Dice Loss + local_preds_softmax = F.softmax(local_preds, dim=1) + local_labels_one_hot = F.one_hot(local_labels, num_classes=self.config.n_categories+1).permute(0, 4, 1, 2, 3).float() + + # Pass the spatial_mesh directly + loss_dice = sharded_dice_loss(local_preds_softmax, local_labels_one_hot, spatial_mesh) + + # 3. Combine Loss + loss = loss_ce + loss_dice + + self.log.debug(f"loss calculation complete. Proceeding to backward pass") + + # Backward pass + self.grad_scaler.scale(loss).backward() + self.log.debug(f"backward pass complete. 
Stepping optimizer") + + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + self.optimizer.zero_grad(set_to_none=True) + + # Free memory aggressively + del images_dc, true_masks_dc, masks_pred_dc + del local_preds, local_labels, local_preds_softmax, local_labels_one_hot + del loss_ce, loss_dice, loss, images_dp, true_masks_dp + + if self.world_rank == 0: + peak_alloc = torch.cuda.max_memory_allocated() / (1024**3) + peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) + self.log.debug(f"[MEM-PEAK] Peak alloc: {peak_alloc:.2f} GiB | Peak reserved: {peak_reserved:.2f} GiB", flush=True) # Nuke any accumulated grads so the first real step starts clean for p in self.model.parameters(): @@ -415,13 +528,6 @@ def train(self): ) break - # DistConv ParallelStrategy - ps = getattr(self.config, "_parallel_strategy", None) - if ps is None: - raise RuntimeError( - "ParallelStrategy not found in config. Set config._parallel_strategy when wrapping model with DistConvDDP." - ) - # Timer and tracking variables epoch_start_time = time.time() train_dice_curr = 0 @@ -451,6 +557,7 @@ def train(self): begin_code_region("batch_loop") for batch in self.train_loader: + # Load initial samples and labels images, true_masks = batch["image"], batch["mask"] begin_code_region("image_to_device") @@ -466,18 +573,26 @@ def train(self): end_code_region("image_to_device") gather_and_print_mem(self.log, "after_batch_to_device") - # Replicate batch across dc mesh, shard batch across ddp mesh. - # This ensures all dc ranks in the same ddp group see the same samples, - # and ddp ranks see disjoint samples. 
- images_dp = distribute_tensor( - images, ps.device_mesh, placements=[Shard(0), Replicate()] + # Add a dummy channel dimension to get 5D [B, 1, D, H, W] + true_masks = true_masks.unsqueeze(1) + + # Data parallel sharding + images_dp = DTensor.from_local( + images, ps.device_mesh, placements=ddp_placements ).to_local() - true_masks_dp = distribute_tensor( - true_masks, - ps.device_mesh, - placements=[Shard(0), Replicate()], + + true_masks_dp = DTensor.from_local( + true_masks, ps.device_mesh, placements=ddp_placements ).to_local() + # Delete source tensors immediately after use to keep memory down + del images, true_masks + + # Spatial sharding via DistConv + images_dc = DCTensor.distribute(images_dp, ps) + true_masks_dc = DCTensor.distribute(true_masks_dp, ps) + self._get_memsize(images_dc, "Sharded image", config.verbosity) + with torch.autocast( self.device.type if self.device.type != "mps" else "cpu", enabled=self.config.torch_amp, @@ -486,72 +601,47 @@ def train(self): torch.cuda.reset_peak_memory_stats() gather_and_print_mem(self.log, "pre_forward") begin_code_region("predict") - - # Spatially shard the chosen dimension across dc mesh - dcx = DCTensor.distribute(images_dp, ps) - dcy = self.model(dcx) - # Convert back to batch-sharded layout across the dc mesh - masks_pred = dcy.to_ddp() - + masks_pred_dc = self.model(images_dc) end_code_region("predict") gather_and_print_mem(self.log, "post_forward") - # Reshard labels across dc mesh to match masks_pred's batch partition - # Start from dc-replicated labels, then shard batch across dc - true_masks_ddp = ( - DTensor.from_local( - true_masks_dp, - device_mesh=ps.device_mesh[ - f"dc{self.config.shard_dim + 2}" - ], - placements=[Replicate()], - ) - .redistribute( - device_mesh=ps.device_mesh[ - f"dc{self.config.shard_dim + 2}" - ], - placements=[Shard(0)], - ) - .to_local() - ) + # Extract the underlying PyTorch local tensors + local_preds = masks_pred_dc + local_labels_5d = true_masks_dc + + # Remove the dummy 
channel dimension so CE Loss is happy [B, D, H, W] + local_labels = local_labels_5d.squeeze(1) + if self.world_rank == 0: + self.log.debug(f"Local Preds Shape: {local_preds.shape}") + # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 + self.log.debug(f"Local Labels Shape: {local_labels.shape}") + # Should be something like [1, 128, 128, 64] begin_code_region("calculate_loss") - # Calculate the loss - if self.config.n_categories + 1 == 1: - loss = self.criterion( - masks_pred.squeeze(1), true_masks_ddp.float() - ) - loss += dice_loss( - F.sigmoid(masks_pred.squeeze(1)), - true_masks_ddp.float(), - multiclass=False, - ) - else: - # Calculate the CrossEntropy loss - CE_loss = self.criterion(masks_pred, true_masks_ddp) - - # Calculate the train dice loss - masks_pred_softmax = F.softmax( - masks_pred, dim=1 - ).float() - true_masks_onehot = ( - F.one_hot( - true_masks_ddp, self.config.n_categories + 1 - ) - .permute(0, 4, 1, 2, 3) - .float() - ) - train_dice_curr = dice_loss( - masks_pred_softmax, - true_masks_onehot, - multiclass=True, - ) - - # Our loss function is CE loss + dice loss - loss = CE_loss + train_dice_curr - - # Track the train dice loss separately for debugging - train_dice_total += train_dice_curr + # --- SHARDED LOSS CALCULATION --- + current_mem = torch.cuda.memory_allocated() / (1024 ** 3) + self.log.debug(f"Calculating sharded loss. Mem: {current_mem:.2f} GB.") + + # 1. Sharded Cross Entropy + local_ce_sum = F.cross_entropy(local_preds, local_labels, reduction='sum') + + # Pass the spatial_mesh directly + global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) + + global_total_voxels = local_labels.numel() * math.prod(self.config.dc_num_shards) + loss_ce = global_ce_sum / global_total_voxels + + # 2. 
Sharded Dice Loss + local_preds_softmax = F.softmax(local_preds, dim=1) + local_labels_one_hot = F.one_hot(local_labels, num_classes=self.config.n_categories+1).permute(0, 4, 1, 2, 3).float() + + # Pass the spatial_mesh directly + loss_dice = sharded_dice_loss(local_preds_softmax, local_labels_one_hot, spatial_mesh) + + # 3. Combine Loss + loss = loss_ce + + train_dice_total += (1.0 - loss_dice.detach().item()) # Note that this is dice *score*, hence the 1 - loss + end_code_region("calculate_loss") gather_and_print_mem(self.log, "pre_backward") @@ -574,7 +664,7 @@ def train(self): # Update the loss begin_code_region("update_loss") - pbar.update(images_dp.shape[0]) + pbar.update(images_dc.shape[0]) self.global_step += 1 batch_step += 1 # Stay on GPU From e706d27dbaf0e852fab6cc1a8b8ce38734fc157a Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Tue, 10 Mar 2026 15:59:56 -0700 Subject: [PATCH 05/25] update distconv param name scheme --- ScaFFold/configs/benchmark_default.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index df8f28f..96977dd 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -9,8 +9,8 @@ unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dim seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. -num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum -shard_dim: [2, 3, 4] # DistConv param: dimension on which to shard +dc_num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. 
It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum +dc_shard_dim: [2, 3, 4] # DistConv param: dimension on which to shard checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. # Internal/dev use only From 34d3b9f5096d82184615e7842a4716d4df54fda7 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Tue, 10 Mar 2026 16:00:07 -0700 Subject: [PATCH 06/25] fix loss calc --- ScaFFold/utils/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index cdb6526..2b2a6d2 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -639,7 +639,7 @@ def train(self): loss_dice = sharded_dice_loss(local_preds_softmax, local_labels_one_hot, spatial_mesh) # 3. Combine Loss - loss = loss_ce + + loss = loss_ce + loss_dice train_dice_total += (1.0 - loss_dice.detach().item()) # Note that this is dice *score*, hence the 1 - loss end_code_region("calculate_loss") From f0382567dfbc00e43deef4e0c5d14f5286598bb9 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Tue, 10 Mar 2026 20:43:39 -0700 Subject: [PATCH 07/25] add sharded dice loss calculation to dice score util --- ScaFFold/utils/dice_score.py | 49 ++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/ScaFFold/utils/dice_score.py b/ScaFFold/utils/dice_score.py index 6536345..88d4579 100644 --- a/ScaFFold/utils/dice_score.py +++ b/ScaFFold/utils/dice_score.py @@ -59,3 +59,52 @@ def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): # Dice loss (objective to minimize) between 0 and 1 fn = multiclass_dice_coeff if multiclass else dice_coeff return 1 - fn(input, target, reduce_batch_first=True) + + +class SpatialAllReduce(torch.autograd.Function): + @staticmethod + def forward(ctx, input, spatial_mesh): + output = input.clone() + for mesh_dim in 
range(spatial_mesh.ndim): + pg = spatial_mesh.get_group(mesh_dim) + dist.all_reduce(output, op=dist.ReduceOp.SUM, group=pg) + return output + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +@annotate() +def compute_sharded_dice( + preds: torch.Tensor, + targets: torch.Tensor, + spatial_mesh, + epsilon: float = 1e-6, +): + """ + Computes the globally sharded Dice score. + Returns the raw score tensor of shape [Batch, Channels]. + """ + assert preds.size() == targets.size(), f"Shape mismatch: {preds.size()} vs {targets.size()}" + assert preds.dim() == 5, f"Expected 5D tensor, got {preds.dim()}D" + + sum_dim = (-1, -2, -3) # D, H, W + + local_inter = 2.0 * (preds * targets).sum(dim=sum_dim) + local_sets_sum_raw = preds.sum(dim=sum_dim) + targets.sum(dim=sum_dim) + + packed = torch.stack([local_inter, local_sets_sum_raw]) + + # Global reduce across spatial mesh + packed_global = SpatialAllReduce.apply(packed, spatial_mesh) + + global_inter = packed_global[0] + global_sets_sum_raw = packed_global[1] + + global_sets_sum = torch.where(global_sets_sum_raw == 0, global_inter, global_sets_sum_raw) + + # Calculate score + dice_score = (global_inter + epsilon) / (global_sets_sum + epsilon) + + return dice_score From 713ed64712ac512a0447b53668783d2e6681d9a7 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Tue, 10 Mar 2026 20:43:58 -0700 Subject: [PATCH 08/25] update evaluate to use sharded dice loss calc --- ScaFFold/utils/evaluate.py | 116 ++++++++++++++++--------------------- 1 file changed, 49 insertions(+), 67 deletions(-) diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index c2d0672..ab9194a 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -18,10 +18,9 @@ from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor from tqdm import tqdm -from ScaFFold.utils.dice_score import dice_coeff, dice_loss, multiclass_dice_coeff +from ScaFFold.utils.dice_score import dice_coeff, 
dice_loss, multiclass_dice_coeff, SpatialAllReduce, compute_sharded_dice from ScaFFold.utils.perf_measure import annotate - @annotate() @torch.inference_mode() def evaluate( @@ -29,10 +28,11 @@ def evaluate( ): net.eval() num_val_batches = len(dataloader) - dice_score = 0.0 + total_dice_score = 0.0 processed_batches = 0 - # For reference, dc sharding happens on this spatial dim: 2=D, 3=H, 4=W + spatial_mesh = parallel_strategy.device_mesh[parallel_strategy.distconv_dim_names] + if primary: print( f"[eval] ps.shard_dim={parallel_strategy.shard_dim} num_shards={parallel_strategy.num_shards}" @@ -50,92 +50,74 @@ def evaluate( ): image, mask_true = batch["image"], batch["mask"] - # move images and labels to correct device and type image = image.to( device=device, dtype=torch.float32, - memory_format=torch.channels_last_3d, # NDHWC (channels last) vs NCDHW (channels first) + memory_format=torch.channels_last_3d, ) mask_true = mask_true.to( device=device, dtype=torch.long - ).contiguous() # masks no channels NDHW, but ensure cotinuity. 
+ ).contiguous() + + # Dummy channel dimension [B, 1, D, H, W] + mask_true = mask_true.unsqueeze(1) - # Shard batch across ddp mesh, replicate across dc mesh - image_dp = distribute_tensor( - image, parallel_strategy.device_mesh, placements=[Shard(0), Replicate()] + # DDP Sharding + ddp_placements = [Shard(0)] + [Replicate()] * len(parallel_strategy.shard_dim) + image_dp = DTensor.from_local( + image, parallel_strategy.device_mesh, placements=ddp_placements ).to_local() - mask_true_dp = distribute_tensor( - mask_true, - parallel_strategy.device_mesh, - placements=[Shard(0), Replicate()], + mask_true_dp = DTensor.from_local( + mask_true, parallel_strategy.device_mesh, placements=ddp_placements ).to_local() - # Spatially shard images along the dc mesh and run the model + # DistConv Spatial Sharding dcx = DCTensor.distribute(image_dp, parallel_strategy) + mask_true_dc = DCTensor.distribute(mask_true_dp, parallel_strategy) + + # Forward pass on sharded data dcy = net(dcx) - # Replicate predictions across dc to get full spatial result on each dc rank - mask_pred = dcy.to_replicate() + # Extract underlying local tensors (STAY SHARDED) + local_preds = dcy + local_labels_5d = mask_true_dc + local_labels = local_labels_5d.squeeze(1) - # Use labels that are replicated across dc and sharded across ddp, like predictions - mask_true_ddp = mask_true_dp - - # Skip if this ddp rank has an empty local batch - if mask_pred.size(0) == 0 or mask_true_ddp.size(0) == 0: + # Skip empty batches + if local_preds.size(0) == 0 or local_labels.size(0) == 0: continue - # Loss - CE_loss = criterion(mask_pred, mask_true_ddp) - - # Dice loss - mask_pred_softmax = F.softmax(mask_pred, dim=1).float() - mask_true_onehot = ( - F.one_hot(mask_true_ddp, n_categories + 1) - .permute(0, 4, 1, 2, 3) - .float() - ) - dice_loss_curr = dice_loss( - mask_pred_softmax, - mask_true_onehot, - multiclass=True, - ) - - # Combined validation loss + # --- 1. 
Sharded CE Loss --- + local_ce_sum = F.cross_entropy(local_preds, local_labels, reduction='sum') + global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) + + # Divide by total global voxels to get the mean CE Loss + global_total_voxels = local_labels.numel() * math.prod(parallel_strategy.num_shards) + CE_loss = global_ce_sum / global_total_voxels + + # --- 2. Format Predictions & Labels (Strictly Multiclass) --- + mask_pred_probs = F.softmax(local_preds, dim=1).float() + mask_true_onehot = F.one_hot(local_labels, n_categories + 1).permute(0, 4, 1, 2, 3).float() + + # Dice loss uses probabilities + dice_score_probs = compute_sharded_dice(mask_pred_probs, mask_true_onehot, spatial_mesh) + dice_loss_curr = 1.0 - dice_score_probs.mean() + + # Eval metric (excluding background class 0) + # dice_score_probs shape is [Batch, Channels]. We slice [:, 1:] to drop background + batch_dice_score = dice_score_probs[:, 1:].mean() + + # --- Combine and Accumulate --- loss = CE_loss + dice_loss_curr val_loss_epoch += loss.item() + total_dice_score += batch_dice_score.item() processed_batches += 1 - # Dice score - if net.module.n_classes == 1: - assert mask_true_ddp.min() >= 0 and mask_true_ddp.max() <= 1, ( - "True mask indices should be in [0, 1]" - ) - mask_pred_bin = (F.sigmoid(mask_pred) > 0.5).float() - dice_score += dice_coeff( - mask_pred_bin, mask_true_ddp, reduce_batch_first=False - ) - else: - assert ( - mask_true_ddp.min() >= 0 - and mask_true_ddp.max() < net.module.n_classes - ), "True mask indices should be in [0, n_classes]" - mask_pred_processed = F.softmax(mask_pred, dim=1).float() - mask_true_onehot_mc = ( - F.one_hot(mask_true_ddp, net.module.n_classes) - .permute(0, 4, 1, 2, 3) - .float() - ) - dice_score += multiclass_dice_coeff( - mask_pred_processed[:, 1:], - mask_true_onehot_mc[:, 1:], - reduce_batch_first=True, - ) - net.train() val_loss_avg = val_loss_epoch / max(processed_batches, 1) if primary: print( - f"evaluate.py: dice_score={dice_score}, 
val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}" + f"evaluate.py: dice_score={total_dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}" ) - return dice_score, val_loss_epoch, val_loss_avg, processed_batches + return total_dice_score, val_loss_epoch, val_loss_avg, processed_batches \ No newline at end of file From 692f1a61e0a681c28751536cb6408751470985ee Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Tue, 10 Mar 2026 20:44:31 -0700 Subject: [PATCH 09/25] update trainer to use new evaluate; other small fixes/tweaks --- ScaFFold/utils/trainer.py | 93 +++++++++------------------------------ 1 file changed, 21 insertions(+), 72 deletions(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 2b2a6d2..0a53a44 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -35,7 +35,7 @@ from ScaFFold.utils.checkpointing import CheckpointManager from ScaFFold.utils.data_loading import FractalDataset -from ScaFFold.utils.dice_score import dice_loss +from ScaFFold.utils.dice_score import dice_loss, SpatialAllReduce, compute_sharded_dice from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size # Local @@ -44,57 +44,6 @@ from ScaFFold.utils.utils import gather_and_print_mem -class SpatialAllReduce(torch.autograd.Function): - """ - Custom Autograd function to sum partial loss components sequentially across - an N-dimensional spatial device mesh. 
- """ - @staticmethod - def forward(ctx, input, spatial_mesh): - output = input.clone() - - # Sequentially reduce across every spatial dimension in the mesh - for mesh_dim in range(spatial_mesh.ndim): - # Get the 1D ProcessGroup for this specific dimension - pg = spatial_mesh.get_group(mesh_dim) - # Sum the tensor across this dimension - dist.all_reduce(output, op=dist.ReduceOp.SUM, group=pg) - - return output - - @staticmethod - def backward(ctx, grad_output): - # Pass the identical gradient straight through. Zero communication. - return grad_output, None - -def sharded_dice_loss( - input: torch.Tensor, - target: torch.Tensor, - spatial_mesh, # <-- Update type signature here - epsilon: float = 1e-6, -): - assert input.size() == target.size(), f"Shape mismatch: {input.size()} vs {target.size()}" - assert input.dim() == 5, f"Expected 5D tensor, got {input.dim()}D" - - sum_dim = (-1, -2, -3) - - local_inter = 2.0 * (input * target).sum(dim=sum_dim) - local_sets_sum_raw = input.sum(dim=sum_dim) + target.sum(dim=sum_dim) - - packed = torch.stack([local_inter, local_sets_sum_raw]) - - # Pass the entire mesh into our updated custom autograd function - packed_global = SpatialAllReduce.apply(packed, spatial_mesh) - - global_inter = packed_global[0] - global_sets_sum_raw = packed_global[1] - - global_sets_sum = torch.where(global_sets_sum_raw == 0, global_inter, global_sets_sum_raw) - dice_score = (global_inter + epsilon) / (global_sets_sum + epsilon) - - return 1.0 - dice_score.mean() - - class BaseTrainer: """ A class that encapsulates some basic functionality for training our model. @@ -390,6 +339,12 @@ def train(self): raise RuntimeError( "ParallelStrategy not found in config. Set config._parallel_strategy when wrapping model with DistConvDDP." 
) + # Get the process group for spatial sharding mesh + spatial_mesh = ps.device_mesh[ps.distconv_dim_names] + + # Get placements for DDP sharding + num_spatial_dims = len(ps.shard_dim) + ddp_placements = [Shard(0)] + [Replicate()] * num_spatial_dims warmup_epochs = self.config.warmup_epochs if warmup_epochs > 0: @@ -399,13 +354,6 @@ def train(self): start_warmup = time.time() self.log.info(f"Running {warmup_epochs} warmup epoch(s)") - # Get the process group for spatial sharding mesh - spatial_mesh = ps.device_mesh[ps.distconv_dim_names] - - # Get placements for DDP sharding - num_spatial_dims = len(ps.shard_dim) - ddp_placements = [Shard(0)] + [Replicate()] * num_spatial_dims - for _ in range(warmup_epochs): for batch in self.train_loader: # Load initial samples and labels @@ -418,13 +366,13 @@ def train(self): memory_format=torch.channels_last_3d, non_blocking=True ) - self._get_memsize(images, "Original image", config.verbosity) + self._get_memsize(images, "Original image", self.config.verbose) true_masks = true_masks.to( device=self.device, dtype=torch.long, non_blocking=True ) - self._get_memsize(images, "Original label", config.verbosity) + self._get_memsize(images, "Original label", self.config.verbose) # Add a dummy channel dimension to get 5D [B, 1, D, H, W] true_masks = true_masks.unsqueeze(1) @@ -444,7 +392,7 @@ def train(self): # Spatial sharding via DistConv images_dc = DCTensor.distribute(images_dp, ps) true_masks_dc = DCTensor.distribute(true_masks_dp, ps) - self._get_memsize(images_dc, "Sharded image", config.verbosity) + self._get_memsize(images_dc, "Sharded image", self.config.verbose) with torch.autocast( self.device.type if self.device.type != "mps" else "cpu", @@ -481,11 +429,10 @@ def train(self): loss_ce = global_ce_sum / global_total_voxels # 2. 
Sharded Dice Loss - local_preds_softmax = F.softmax(local_preds, dim=1) + local_preds_softmax = F.softmax(local_preds, dim=1).float() local_labels_one_hot = F.one_hot(local_labels, num_classes=self.config.n_categories+1).permute(0, 4, 1, 2, 3).float() - - # Pass the spatial_mesh directly - loss_dice = sharded_dice_loss(local_preds_softmax, local_labels_one_hot, spatial_mesh) + dice_scores = compute_sharded_dice(local_preds_softmax, local_labels_one_hot, spatial_mesh) + loss_dice = 1.0 - dice_scores.mean() # 3. Combine Loss loss = loss_ce + loss_dice @@ -591,7 +538,7 @@ def train(self): # Spatial sharding via DistConv images_dc = DCTensor.distribute(images_dp, ps) true_masks_dc = DCTensor.distribute(true_masks_dp, ps) - self._get_memsize(images_dc, "Sharded image", config.verbosity) + self._get_memsize(images_dc, "Sharded image", self.config.verbose) with torch.autocast( self.device.type if self.device.type != "mps" else "cpu", @@ -632,15 +579,16 @@ def train(self): loss_ce = global_ce_sum / global_total_voxels # 2. Sharded Dice Loss - local_preds_softmax = F.softmax(local_preds, dim=1) + local_preds_softmax = F.softmax(local_preds, dim=1).float() local_labels_one_hot = F.one_hot(local_labels, num_classes=self.config.n_categories+1).permute(0, 4, 1, 2, 3).float() - # Pass the spatial_mesh directly - loss_dice = sharded_dice_loss(local_preds_softmax, local_labels_one_hot, spatial_mesh) + # Compute sharded dice using new function + dice_scores = compute_sharded_dice(local_preds_softmax, local_labels_one_hot, spatial_mesh) + loss_dice = 1.0 - dice_scores.mean() # 3. 
Combine Loss loss = loss_ce + loss_dice - train_dice_total += (1.0 - loss_dice.detach().item()) # Note that this is dice *score*, hence the 1 - loss + train_dice_total += dice_scores[:, 1:].mean().item() end_code_region("calculate_loss") @@ -652,6 +600,7 @@ def train(self): begin_code_region("step_and_update") if batch_step + 1 == len(self.train_loader): + self.grad_scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_( self.model.parameters(), max_norm=1.0 ) @@ -757,4 +706,4 @@ def train(self): dice_score_train = val_score epoch += 1 - adiak_value("final_epochs", epoch) + adiak_value("final_epochs", epoch) \ No newline at end of file From 50a565ba01d0d6a8188bfa8f9add574c10704167 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Tue, 10 Mar 2026 20:44:46 -0700 Subject: [PATCH 10/25] fix assert --- ScaFFold/worker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index d007970..1d60c62 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -176,9 +176,9 @@ def main(kwargs_dict: dict = {}): ) if config.dist: # DDP + DistConv setup - # Ensure world_size is divisible by dc_num_shards - assert dist.get_world_size() % config.dc_num_shards == 0, ( - f"world_size={dist.get_world_size()} must be divisible by dc_num_shards={config.dc_num_shards}" + # Ensure world_size is divisible by total distconv shards + assert dist.get_world_size() % math.prod(config.dc_num_shards) == 0, ( + f"world_size={dist.get_world_size()} must be divisible by total number of distconv shards = {math.prod(config.dc_num_shards)}" ) ps = ParallelStrategy( From 49fffe5007b6d799136a5609323c70d109d77cb4 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 11 Mar 2026 09:03:11 -0700 Subject: [PATCH 11/25] fix naming --- ScaFFold/worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 1d60c62..faf6536 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py 
@@ -161,8 +161,8 @@ def main(kwargs_dict: dict = {}): # Initialize model begin_code_region("init_model") - config.dc_num_shards = getattr(config, "dc_num_shards", config.num_shards) - config.dc_shard_dim = getattr(config, "dc_shard_dim", config.shard_dim) + config.dc_num_shards = getattr(config, "dc_num_shards", config.dc_num_shards) + config.dc_shard_dim = getattr(config, "dc_shard_dim", config.dc_shard_dim) log.info( f"DistConv num_shards={config.dc_num_shards}, shard_dim={config.dc_shard_dim}" ) From 338d03497efd1304c77a8788224bc8463a51b6e6 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 11 Mar 2026 09:04:56 -0700 Subject: [PATCH 12/25] fix naming --- ScaFFold/worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index faf6536..0e8ad3e 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -162,9 +162,9 @@ def main(kwargs_dict: dict = {}): # Initialize model begin_code_region("init_model") config.dc_num_shards = getattr(config, "dc_num_shards", config.dc_num_shards) - config.dc_shard_dim = getattr(config, "dc_shard_dim", config.dc_shard_dim) + config.dc_shard_dims = getattr(config, "dc_shard_dims", config.dc_shard_dims) log.info( - f"DistConv num_shards={config.dc_num_shards}, shard_dim={config.dc_shard_dim}" + f"DistConv num_shards={config.dc_num_shards}, shard_dim={config.dc_shard_dims}" ) device = get_device() log.info(f"Using device: {device}") From 9d23fdb80edc571752c9e2ecd1100994b033c3ff Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 11 Mar 2026 11:30:58 -0700 Subject: [PATCH 13/25] better default values --- ScaFFold/utils/config_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index b66659e..d5fc990 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -73,8 +73,8 @@ def __init__(self, config_dict): self.target_dice = config_dict["target_dice"] 
self.checkpoint_interval = config_dict["checkpoint_interval"] - self.dc_num_shards = _ensure_tuple(config_dict.get("num_shards", 1)) - self.dc_shard_dims = _ensure_tuple(config_dict.get("shard_dim", 2)) + self.dc_num_shards = _ensure_tuple(config_dict.get("dc_num_shards", (1, 1, 1))) + self.dc_shard_dims = _ensure_tuple(config_dict.get("dc_shard_dims", (2, 3, 4))) self.dc_total_shards = math.prod(self.dc_num_shards) # Safety Check: Length mismatch if len(self.dc_num_shards) != len(self.dc_shard_dims): @@ -128,4 +128,4 @@ def _ensure_tuple(val): val = val.strip("()[]").split(",") return tuple(int(i.strip()) for i in val if i.strip()) # Fallback for single integer - return (int(val),) \ No newline at end of file + return (1, 1, int(val),) \ No newline at end of file From 4667f5efc4821b5596dfa259662010393428c7ca Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 11 Mar 2026 11:31:11 -0700 Subject: [PATCH 14/25] missing import --- ScaFFold/utils/dice_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ScaFFold/utils/dice_score.py b/ScaFFold/utils/dice_score.py index 88d4579..45f84f6 100644 --- a/ScaFFold/utils/dice_score.py +++ b/ScaFFold/utils/dice_score.py @@ -14,7 +14,7 @@ import torch from torch import Tensor - +import torch.distributed as dist from ScaFFold.utils.perf_measure import annotate From 6f21c3570bec85d8fde25ca112dec40c49205eed Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 11 Mar 2026 11:32:10 -0700 Subject: [PATCH 15/25] update distconv param names, default vals --- ScaFFold/configs/benchmark_default.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index 96977dd..2f889a0 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -4,13 +4,13 @@ dataset_dir: "datasets" # Directory in which to store and query for d fract_base_dir: "fractals" # Base directory for fractal IFS 
and instances. n_categories: 5 # Number of fractal categories present in the dataset. n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances. -problem_scale: 7 # Determines dataset resolution and number of unet layers. Default is 6. +problem_scale: 8 # Determines dataset resolution and number of unet layers. Default is 6. unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. -dc_num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum -dc_shard_dim: [2, 3, 4] # DistConv param: dimension on which to shard +dc_num_shards: [1, 1, 2] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum +dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. 
# Internal/dev use only From 61030b923db3b5b21816d1545caa6385905646fc Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 11 Mar 2026 11:32:27 -0700 Subject: [PATCH 16/25] use np.prod instead of math.prod --- ScaFFold/utils/evaluate.py | 2 +- ScaFFold/worker.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index ab9194a..fe47c2e 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -92,7 +92,7 @@ def evaluate( global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) # Divide by total global voxels to get the mean CE Loss - global_total_voxels = local_labels.numel() * math.prod(parallel_strategy.num_shards) + global_total_voxels = local_labels.numel() * np.prod(parallel_strategy.num_shards) CE_loss = global_ce_sum / global_total_voxels # --- 2. Format Predictions & Labels (Strictly Multiclass) --- diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 0e8ad3e..8b76879 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -177,7 +177,7 @@ def main(kwargs_dict: dict = {}): if config.dist: # DDP + DistConv setup # Ensure world_size is divisible by total distconv shards - assert dist.get_world_size() % math.prod(config.dc_num_shards) == 0, ( + assert dist.get_world_size() % np.prod(config.dc_num_shards) == 0, ( f"world_size={dist.get_world_size()} must be divisible by total number of distconv shards = {math.prod(config.dc_num_shards)}" ) From 7775742fc78ffb764b4138c08e0873d30a41095e Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 11 Mar 2026 11:34:12 -0700 Subject: [PATCH 17/25] ruff --- ScaFFold/utils/config_utils.py | 9 ++- ScaFFold/utils/dice_score.py | 21 +++-- ScaFFold/utils/evaluate.py | 41 ++++++---- ScaFFold/utils/trainer.py | 140 ++++++++++++++++++++++----------- 4 files changed, 140 insertions(+), 71 deletions(-) diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index d5fc990..b6d77d2 100644 --- 
a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -114,13 +114,14 @@ def load_config(file_path: str, config_type: str): f"Invalid config type specified: {type}. Must be either 'sweep' or 'run'" ) + def _ensure_tuple(val): """ Ensures the input value is converted to a tuple of integers. Handles: int, list, tuple, and string representations like "[2,2]" or "2,2". """ if val is None: - return (1,) # Default safety + return (1,) # Default safety if isinstance(val, (list, tuple)): return tuple(int(i) for i in val) if isinstance(val, str): @@ -128,4 +129,8 @@ def _ensure_tuple(val): val = val.strip("()[]").split(",") return tuple(int(i.strip()) for i in val if i.strip()) # Fallback for single integer - return (1, 1, int(val),) \ No newline at end of file + return ( + 1, + 1, + int(val), + ) diff --git a/ScaFFold/utils/dice_score.py b/ScaFFold/utils/dice_score.py index 45f84f6..ed60fd4 100644 --- a/ScaFFold/utils/dice_score.py +++ b/ScaFFold/utils/dice_score.py @@ -13,8 +13,9 @@ # SPDX-License-Identifier: (Apache-2.0) import torch -from torch import Tensor import torch.distributed as dist +from torch import Tensor + from ScaFFold.utils.perf_measure import annotate @@ -86,25 +87,29 @@ def compute_sharded_dice( Computes the globally sharded Dice score. Returns the raw score tensor of shape [Batch, Channels]. 
""" - assert preds.size() == targets.size(), f"Shape mismatch: {preds.size()} vs {targets.size()}" + assert preds.size() == targets.size(), ( + f"Shape mismatch: {preds.size()} vs {targets.size()}" + ) assert preds.dim() == 5, f"Expected 5D tensor, got {preds.dim()}D" - sum_dim = (-1, -2, -3) # D, H, W + sum_dim = (-1, -2, -3) # D, H, W local_inter = 2.0 * (preds * targets).sum(dim=sum_dim) local_sets_sum_raw = preds.sum(dim=sum_dim) + targets.sum(dim=sum_dim) packed = torch.stack([local_inter, local_sets_sum_raw]) - + # Global reduce across spatial mesh packed_global = SpatialAllReduce.apply(packed, spatial_mesh) - + global_inter = packed_global[0] global_sets_sum_raw = packed_global[1] - global_sets_sum = torch.where(global_sets_sum_raw == 0, global_inter, global_sets_sum_raw) - + global_sets_sum = torch.where( + global_sets_sum_raw == 0, global_inter, global_sets_sum_raw + ) + # Calculate score dice_score = (global_inter + epsilon) / (global_sets_sum + epsilon) - + return dice_score diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index fe47c2e..6f8da8a 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -18,9 +18,16 @@ from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor from tqdm import tqdm -from ScaFFold.utils.dice_score import dice_coeff, dice_loss, multiclass_dice_coeff, SpatialAllReduce, compute_sharded_dice +from ScaFFold.utils.dice_score import ( + SpatialAllReduce, + compute_sharded_dice, + dice_coeff, + dice_loss, + multiclass_dice_coeff, +) from ScaFFold.utils.perf_measure import annotate + @annotate() @torch.inference_mode() def evaluate( @@ -55,15 +62,15 @@ def evaluate( dtype=torch.float32, memory_format=torch.channels_last_3d, ) - mask_true = mask_true.to( - device=device, dtype=torch.long - ).contiguous() + mask_true = mask_true.to(device=device, dtype=torch.long).contiguous() # Dummy channel dimension [B, 1, D, H, W] mask_true = mask_true.unsqueeze(1) # DDP Sharding - 
ddp_placements = [Shard(0)] + [Replicate()] * len(parallel_strategy.shard_dim) + ddp_placements = [Shard(0)] + [Replicate()] * len( + parallel_strategy.shard_dim + ) image_dp = DTensor.from_local( image, parallel_strategy.device_mesh, placements=ddp_placements ).to_local() @@ -74,7 +81,7 @@ def evaluate( # DistConv Spatial Sharding dcx = DCTensor.distribute(image_dp, parallel_strategy) mask_true_dc = DCTensor.distribute(mask_true_dp, parallel_strategy) - + # Forward pass on sharded data dcy = net(dcx) @@ -88,21 +95,27 @@ def evaluate( continue # --- 1. Sharded CE Loss --- - local_ce_sum = F.cross_entropy(local_preds, local_labels, reduction='sum') + local_ce_sum = F.cross_entropy(local_preds, local_labels, reduction="sum") global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) - + # Divide by total global voxels to get the mean CE Loss - global_total_voxels = local_labels.numel() * np.prod(parallel_strategy.num_shards) + global_total_voxels = local_labels.numel() * np.prod( + parallel_strategy.num_shards + ) CE_loss = global_ce_sum / global_total_voxels # --- 2. Format Predictions & Labels (Strictly Multiclass) --- mask_pred_probs = F.softmax(local_preds, dim=1).float() - mask_true_onehot = F.one_hot(local_labels, n_categories + 1).permute(0, 4, 1, 2, 3).float() - + mask_true_onehot = ( + F.one_hot(local_labels, n_categories + 1).permute(0, 4, 1, 2, 3).float() + ) + # Dice loss uses probabilities - dice_score_probs = compute_sharded_dice(mask_pred_probs, mask_true_onehot, spatial_mesh) + dice_score_probs = compute_sharded_dice( + mask_pred_probs, mask_true_onehot, spatial_mesh + ) dice_loss_curr = 1.0 - dice_score_probs.mean() - + # Eval metric (excluding background class 0) # dice_score_probs shape is [Batch, Channels]. 
We slice [:, 1:] to drop background batch_dice_score = dice_score_probs[:, 1:].mean() @@ -120,4 +133,4 @@ def evaluate( print( f"evaluate.py: dice_score={total_dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}" ) - return total_dice_score, val_loss_epoch, val_loss_avg, processed_batches \ No newline at end of file + return total_dice_score, val_loss_epoch, val_loss_avg, processed_batches diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 0a53a44..10344bc 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -35,7 +35,7 @@ from ScaFFold.utils.checkpointing import CheckpointManager from ScaFFold.utils.data_loading import FractalDataset -from ScaFFold.utils.dice_score import dice_loss, SpatialAllReduce, compute_sharded_dice +from ScaFFold.utils.dice_score import SpatialAllReduce, compute_sharded_dice, dice_loss from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size # Local @@ -318,13 +318,13 @@ def _truncate_stats_file(self, start_epoch): self.log.warning(f"Failed to truncate stats file: {e}") def _get_memsize(self, tensor, tensor_label: str, verbosity: int = 0): - """Log size of tensor in memory""" + """Log size of tensor in memory""" - if verbosity < 2: - return - tensor_memory_bytes = tensor[0].element_size() * tensor[0].nelement() - tensor_memory_gb = tensor_memory_bytes / (1024 ** 3) - self.log.info(f"{tensor_label} size on GPU: {tensor_memory_gb:.2f} GB") + if verbosity < 2: + return + tensor_memory_bytes = tensor[0].element_size() * tensor[0].nelement() + tensor_memory_gb = tensor_memory_bytes / (1024**3) + self.log.info(f"{tensor_label} size on GPU: {tensor_memory_gb:.2f} GB") def train(self): """ @@ -364,13 +364,11 @@ def train(self): device=self.device, dtype=torch.float32, memory_format=torch.channels_last_3d, - non_blocking=True + non_blocking=True, ) self._get_memsize(images, "Original image", self.config.verbose) 
true_masks = true_masks.to( - device=self.device, - dtype=torch.long, - non_blocking=True + device=self.device, dtype=torch.long, non_blocking=True ) self._get_memsize(images, "Original label", self.config.verbose) @@ -381,7 +379,7 @@ def train(self): images_dp = DTensor.from_local( images, ps.device_mesh, placements=ddp_placements ).to_local() - + true_masks_dp = DTensor.from_local( true_masks, ps.device_mesh, placements=ddp_placements ).to_local() @@ -398,11 +396,10 @@ def train(self): self.device.type if self.device.type != "mps" else "cpu", enabled=self.config.torch_amp, ): - # Forward on DCTensor masks_pred_dc = self.model(images_dc) self.log.debug(f"forward pass complete") - + # Extract the underlying PyTorch local tensors local_preds = masks_pred_dc local_labels_5d = true_masks_dc @@ -410,34 +407,52 @@ def train(self): # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] local_labels = local_labels_5d.squeeze(1) if self.world_rank == 0: - self.log.debug(f"Local Preds Shape: {local_preds.shape}") + self.log.debug(f"Local Preds Shape: {local_preds.shape}") # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 self.log.debug(f"Local Labels Shape: {local_labels.shape}") # Should be something like [1, 128, 128, 64] # --- SHARDED LOSS CALCULATION --- - current_mem = torch.cuda.memory_allocated() / (1024 ** 3) - self.log.debug(f"Calculating sharded loss. Mem: {current_mem:.2f} GB.") - + current_mem = torch.cuda.memory_allocated() / (1024**3) + self.log.debug( + f"Calculating sharded loss. Mem: {current_mem:.2f} GB." + ) + # 1. 
Sharded Cross Entropy - local_ce_sum = F.cross_entropy(local_preds, local_labels, reduction='sum') - + local_ce_sum = F.cross_entropy( + local_preds, local_labels, reduction="sum" + ) + # Pass the spatial_mesh directly - global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) - - global_total_voxels = local_labels.numel() * math.prod(self.config.dc_num_shards) + global_ce_sum = SpatialAllReduce.apply( + local_ce_sum, spatial_mesh + ) + + global_total_voxels = local_labels.numel() * math.prod( + self.config.dc_num_shards + ) loss_ce = global_ce_sum / global_total_voxels # 2. Sharded Dice Loss - local_preds_softmax = F.softmax(local_preds, dim=1).float() - local_labels_one_hot = F.one_hot(local_labels, num_classes=self.config.n_categories+1).permute(0, 4, 1, 2, 3).float() - dice_scores = compute_sharded_dice(local_preds_softmax, local_labels_one_hot, spatial_mesh) + local_preds_softmax = F.softmax(local_preds, dim=1).float() + local_labels_one_hot = ( + F.one_hot( + local_labels, num_classes=self.config.n_categories + 1 + ) + .permute(0, 4, 1, 2, 3) + .float() + ) + dice_scores = compute_sharded_dice( + local_preds_softmax, local_labels_one_hot, spatial_mesh + ) loss_dice = 1.0 - dice_scores.mean() # 3. Combine Loss loss = loss_ce + loss_dice - self.log.debug(f"loss calculation complete. Proceeding to backward pass") + self.log.debug( + f"loss calculation complete. 
Proceeding to backward pass" + ) # Backward pass self.grad_scaler.scale(loss).backward() @@ -448,14 +463,22 @@ def train(self): self.optimizer.zero_grad(set_to_none=True) # Free memory aggressively - del images_dc, true_masks_dc, masks_pred_dc - del local_preds, local_labels, local_preds_softmax, local_labels_one_hot + del images_dc, true_masks_dc, masks_pred_dc + del ( + local_preds, + local_labels, + local_preds_softmax, + local_labels_one_hot, + ) del loss_ce, loss_dice, loss, images_dp, true_masks_dp if self.world_rank == 0: peak_alloc = torch.cuda.max_memory_allocated() / (1024**3) peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) - self.log.debug(f"[MEM-PEAK] Peak alloc: {peak_alloc:.2f} GiB | Peak reserved: {peak_reserved:.2f} GiB", flush=True) + self.log.debug( + f"[MEM-PEAK] Peak alloc: {peak_alloc:.2f} GiB | Peak reserved: {peak_reserved:.2f} GiB", + flush=True, + ) # Nuke any accumulated grads so the first real step starts clean for p in self.model.parameters(): @@ -527,7 +550,7 @@ def train(self): images_dp = DTensor.from_local( images, ps.device_mesh, placements=ddp_placements ).to_local() - + true_masks_dp = DTensor.from_local( true_masks, ps.device_mesh, placements=ddp_placements ).to_local() @@ -538,7 +561,9 @@ def train(self): # Spatial sharding via DistConv images_dc = DCTensor.distribute(images_dp, ps) true_masks_dc = DCTensor.distribute(true_masks_dp, ps) - self._get_memsize(images_dc, "Sharded image", self.config.verbose) + self._get_memsize( + images_dc, "Sharded image", self.config.verbose + ) with torch.autocast( self.device.type if self.device.type != "mps" else "cpu", @@ -559,37 +584,58 @@ def train(self): # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] local_labels = local_labels_5d.squeeze(1) if self.world_rank == 0: - self.log.debug(f"Local Preds Shape: {local_preds.shape}") + self.log.debug( + f"Local Preds Shape: {local_preds.shape}" + ) # Should be something like [1, 6, 128, 128, 64] if sharding Width 
by 2 - self.log.debug(f"Local Labels Shape: {local_labels.shape}") + self.log.debug( + f"Local Labels Shape: {local_labels.shape}" + ) # Should be something like [1, 128, 128, 64] begin_code_region("calculate_loss") # --- SHARDED LOSS CALCULATION --- - current_mem = torch.cuda.memory_allocated() / (1024 ** 3) - self.log.debug(f"Calculating sharded loss. Mem: {current_mem:.2f} GB.") - + current_mem = torch.cuda.memory_allocated() / (1024**3) + self.log.debug( + f"Calculating sharded loss. Mem: {current_mem:.2f} GB." + ) + # 1. Sharded Cross Entropy - local_ce_sum = F.cross_entropy(local_preds, local_labels, reduction='sum') - + local_ce_sum = F.cross_entropy( + local_preds, local_labels, reduction="sum" + ) + # Pass the spatial_mesh directly - global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) - - global_total_voxels = local_labels.numel() * math.prod(self.config.dc_num_shards) + global_ce_sum = SpatialAllReduce.apply( + local_ce_sum, spatial_mesh + ) + + global_total_voxels = local_labels.numel() * math.prod( + self.config.dc_num_shards + ) loss_ce = global_ce_sum / global_total_voxels # 2. Sharded Dice Loss - local_preds_softmax = F.softmax(local_preds, dim=1).float() - local_labels_one_hot = F.one_hot(local_labels, num_classes=self.config.n_categories+1).permute(0, 4, 1, 2, 3).float() + local_preds_softmax = F.softmax(local_preds, dim=1).float() + local_labels_one_hot = ( + F.one_hot( + local_labels, + num_classes=self.config.n_categories + 1, + ) + .permute(0, 4, 1, 2, 3) + .float() + ) # Compute sharded dice using new function - dice_scores = compute_sharded_dice(local_preds_softmax, local_labels_one_hot, spatial_mesh) + dice_scores = compute_sharded_dice( + local_preds_softmax, local_labels_one_hot, spatial_mesh + ) loss_dice = 1.0 - dice_scores.mean() # 3. 
Combine Loss loss = loss_ce + loss_dice train_dice_total += dice_scores[:, 1:].mean().item() - + end_code_region("calculate_loss") gather_and_print_mem(self.log, "pre_backward") @@ -706,4 +752,4 @@ def train(self): dice_score_train = val_score epoch += 1 - adiak_value("final_epochs", epoch) \ No newline at end of file + adiak_value("final_epochs", epoch) From a8a940d38462c0927a9df2bd2db9a28156a5d99a Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 18 Mar 2026 08:45:01 -0700 Subject: [PATCH 18/25] import math --- ScaFFold/utils/evaluate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index 6f8da8a..6c74edb 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -12,6 +12,7 @@ # # SPDX-License-Identifier: (Apache-2.0) +import math import torch import torch.nn.functional as F from distconv import DCTensor @@ -99,7 +100,7 @@ def evaluate( global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) # Divide by total global voxels to get the mean CE Loss - global_total_voxels = local_labels.numel() * np.prod( + global_total_voxels = local_labels.numel() * math.prod( parallel_strategy.num_shards ) CE_loss = global_ce_sum / global_total_voxels From a0e13ad3c45b90024d59c5539438d7bda68346a4 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 18 Mar 2026 08:50:41 -0700 Subject: [PATCH 19/25] warmup logging and timing --- ScaFFold/utils/trainer.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 10344bc..2c3257b 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -355,7 +355,9 @@ def train(self): self.log.info(f"Running {warmup_epochs} warmup epoch(s)") for _ in range(warmup_epochs): - for batch in self.train_loader: + for i, batch in enumerate(self.train_loader): + self.log.debug(f" warmup: batch {i} / {len(self.train_loader)}") + 
batch_t_start = time.time() # Load initial samples and labels images, true_masks = batch["image"], batch["mask"] @@ -397,8 +399,9 @@ def train(self): enabled=self.config.torch_amp, ): # Forward on DCTensor + self.log.debug(f" warmup: running forward pass") masks_pred_dc = self.model(images_dc) - self.log.debug(f"forward pass complete") + self.log.debug(f" warmup: forward pass complete") # Extract the underlying PyTorch local tensors local_preds = masks_pred_dc @@ -407,15 +410,15 @@ def train(self): # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] local_labels = local_labels_5d.squeeze(1) if self.world_rank == 0: - self.log.debug(f"Local Preds Shape: {local_preds.shape}") + self.log.debug(f" warmup: Local Preds Shape: {local_preds.shape}") # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 - self.log.debug(f"Local Labels Shape: {local_labels.shape}") + self.log.debug(f" warmup: Local Labels Shape: {local_labels.shape}") # Should be something like [1, 128, 128, 64] # --- SHARDED LOSS CALCULATION --- current_mem = torch.cuda.memory_allocated() / (1024**3) self.log.debug( - f"Calculating sharded loss. Mem: {current_mem:.2f} GB." + f" warmup: Calculating sharded loss. Mem: {current_mem:.2f} GB." ) # 1. Sharded Cross Entropy @@ -451,16 +454,15 @@ def train(self): loss = loss_ce + loss_dice self.log.debug( - f"loss calculation complete. Proceeding to backward pass" + f" warmup: loss calculation complete. Proceeding to backward pass" ) # Backward pass self.grad_scaler.scale(loss).backward() - self.log.debug(f"backward pass complete. Stepping optimizer") + self.log.debug(f" warmup: backward pass complete. 
Stepping optimizer") self.grad_scaler.step(self.optimizer) self.grad_scaler.update() - self.optimizer.zero_grad(set_to_none=True) # Free memory aggressively del images_dc, true_masks_dc, masks_pred_dc @@ -477,12 +479,14 @@ def train(self): peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) self.log.debug( f"[MEM-PEAK] Peak alloc: {peak_alloc:.2f} GiB | Peak reserved: {peak_reserved:.2f} GiB", - flush=True, ) + batch_t_end = time.time() + self.log.debug(f" warmup: batch {i} completed in {batch_t_end - batch_t_start} seconds") # Nuke any accumulated grads so the first real step starts clean for p in self.model.parameters(): p.grad = None + self.optimizer.zero_grad(set_to_none=True) torch.distributed.barrier() end_code_region("warmup") self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s") From 8636ade7011947156ab49927b3aa38f78f8ea5b5 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 18 Mar 2026 08:51:07 -0700 Subject: [PATCH 20/25] import math for prod --- ScaFFold/worker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 8b76879..516b391 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -16,6 +16,7 @@ import socket import sys import time +import math from argparse import Namespace import numpy as np @@ -177,7 +178,7 @@ def main(kwargs_dict: dict = {}): if config.dist: # DDP + DistConv setup # Ensure world_size is divisible by total distconv shards - assert dist.get_world_size() % np.prod(config.dc_num_shards) == 0, ( + assert dist.get_world_size() % math.prod(config.dc_num_shards) == 0, ( f"world_size={dist.get_world_size()} must be divisible by total number of distconv shards = {math.prod(config.dc_num_shards)}" ) From 24b4662b01c6bf46bfd9c3a5495efe468663d457 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 26 Mar 2026 14:24:38 -0700 Subject: [PATCH 21/25] Add missing import --- ScaFFold/utils/evaluate.py | 2 ++ 1 file changed, 2 insertions(+) diff 
--git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index 6c74edb..fbecf90 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -13,6 +13,8 @@ # SPDX-License-Identifier: (Apache-2.0) import math + +import numpy as np import torch import torch.nn.functional as F from distconv import DCTensor From d4bbd4d51098a0d84336d4ce807745896c323450 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 26 Mar 2026 14:24:53 -0700 Subject: [PATCH 22/25] Remove extra func make cli arg tuple --- ScaFFold/cli.py | 3 ++- ScaFFold/utils/config_utils.py | 25 ++----------------------- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 840a7e6..3c73f40 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -157,8 +157,9 @@ def main(): help="Resume execution in this specific directory. Overrides --base-run-dir.", ) benchmark_parser.add_argument( - "--num-shards", + "--dc-num-shards", type=int, + nargs=3, help="DistConv param: number of shards to divide the tensor into. 
It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum", ) benchmark_parser.add_argument( diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index b6d77d2..378dc51 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -73,8 +73,8 @@ def __init__(self, config_dict): self.target_dice = config_dict["target_dice"] self.checkpoint_interval = config_dict["checkpoint_interval"] - self.dc_num_shards = _ensure_tuple(config_dict.get("dc_num_shards", (1, 1, 1))) - self.dc_shard_dims = _ensure_tuple(config_dict.get("dc_shard_dims", (2, 3, 4))) + self.dc_num_shards = config_dict["dc_num_shards"] + self.dc_shard_dims = config_dict["dc_shard_dims"] self.dc_total_shards = math.prod(self.dc_num_shards) # Safety Check: Length mismatch if len(self.dc_num_shards) != len(self.dc_shard_dims): @@ -113,24 +113,3 @@ def load_config(file_path: str, config_type: str): raise ValueError( f"Invalid config type specified: {type}. Must be either 'sweep' or 'run'" ) - - -def _ensure_tuple(val): - """ - Ensures the input value is converted to a tuple of integers. - Handles: int, list, tuple, and string representations like "[2,2]" or "2,2". 
- """ - if val is None: - return (1,) # Default safety - if isinstance(val, (list, tuple)): - return tuple(int(i) for i in val) - if isinstance(val, str): - # Handle cases where user might type literal "(2, 2, 2)" in YAML or "2,2" in CLI - val = val.strip("()[]").split(",") - return tuple(int(i.strip()) for i in val if i.strip()) - # Fallback for single integer - return ( - 1, - 1, - int(val), - ) From a78cf58c677ab06f9ba4c6c7db8a2e6ca738ff44 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 26 Mar 2026 14:25:26 -0700 Subject: [PATCH 23/25] lint --- ScaFFold/utils/trainer.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 2c3257b..20f9251 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -410,9 +410,13 @@ def train(self): # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] local_labels = local_labels_5d.squeeze(1) if self.world_rank == 0: - self.log.debug(f" warmup: Local Preds Shape: {local_preds.shape}") + self.log.debug( + f" warmup: Local Preds Shape: {local_preds.shape}" + ) # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 - self.log.debug(f" warmup: Local Labels Shape: {local_labels.shape}") + self.log.debug( + f" warmup: Local Labels Shape: {local_labels.shape}" + ) # Should be something like [1, 128, 128, 64] # --- SHARDED LOSS CALCULATION --- @@ -459,7 +463,9 @@ def train(self): # Backward pass self.grad_scaler.scale(loss).backward() - self.log.debug(f" warmup: backward pass complete. Stepping optimizer") + self.log.debug( + f" warmup: backward pass complete. 
Stepping optimizer" + ) self.grad_scaler.step(self.optimizer) self.grad_scaler.update() @@ -481,7 +487,9 @@ def train(self): f"[MEM-PEAK] Peak alloc: {peak_alloc:.2f} GiB | Peak reserved: {peak_reserved:.2f} GiB", ) batch_t_end = time.time() - self.log.debug(f" warmup: batch {i} completed in {batch_t_end - batch_t_start} seconds") + self.log.debug( + f" warmup: batch {i} completed in {batch_t_end - batch_t_start} seconds" + ) # Nuke any accumulated grads so the first real step starts clean for p in self.model.parameters(): From 1761289f4f797c372d7d21b935c76b3fa9456dc6 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 26 Mar 2026 14:26:37 -0700 Subject: [PATCH 24/25] lint --- ScaFFold/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 516b391..8bad2ea 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -12,11 +12,11 @@ # # SPDX-License-Identifier: (Apache-2.0) +import math import os import socket import sys import time -import math from argparse import Namespace import numpy as np From eb434bce09b4b127b4581e1664cb563acc93f81b Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 26 Mar 2026 14:29:07 -0700 Subject: [PATCH 25/25] Update configs --- ScaFFold/configs/benchmark_default.yml | 6 +++--- ScaFFold/configs/benchmark_testing.yml | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index 2f889a0..e96b103 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -9,9 +9,9 @@ unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dim seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. -dc_num_shards: [1, 1, 2] # DistConv param: number of shards to divide the tensor into. 
It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum -dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard -checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. +dc_num_shards: [1, 1, 2] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum +dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard +checkpoint_interval: 100 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. # Internal/dev use only variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml index aa97106..5fea4d0 100644 --- a/ScaFFold/configs/benchmark_testing.yml +++ b/ScaFFold/configs/benchmark_testing.yml @@ -9,9 +9,9 @@ unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dim seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. -num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum -shard_dim: 2 # DistConv param: dimension on which to shard -checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. +dc_num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. 
It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum +dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard +checkpoint_interval: 100 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. # Internal/dev use only variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15.