diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 840a7e6..3c73f40 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -157,8 +157,9 @@ def main(): help="Resume execution in this specific directory. Overrides --base-run-dir.", ) benchmark_parser.add_argument( - "--num-shards", + "--dc-num-shards", type=int, + nargs=3, help="DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum", ) benchmark_parser.add_argument( diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index fce1042..e96b103 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -4,14 +4,14 @@ dataset_dir: "datasets" # Directory in which to store and query for d fract_base_dir: "fractals" # Base directory for fractal IFS and instances. n_categories: 5 # Number of fractal categories present in the dataset. n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances. -problem_scale: 6 # Determines dataset resolution and number of unet layers. Default is 6. +problem_scale: 8 # Determines dataset resolution and number of unet layers. Default is 6. unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. -num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum -shard_dim: 2 # DistConv param: dimension on which to shard -checkpoint_interval: 10 # Checkpoint every C epochs. 
More frequent checkpointing can be very expensive on slow filesystems. +dc_num_shards: [1, 1, 2] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum +dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard +checkpoint_interval: 100 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. # Internal/dev use only variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml index aa97106..5fea4d0 100644 --- a/ScaFFold/configs/benchmark_testing.yml +++ b/ScaFFold/configs/benchmark_testing.yml @@ -9,9 +9,9 @@ unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dim seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. -num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum -shard_dim: 2 # DistConv param: dimension on which to shard -checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. +dc_num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum +dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard +checkpoint_interval: 100 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. # Internal/dev use only variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. 
diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 08cb481..378dc51 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -67,14 +67,23 @@ def __init__(self, config_dict): self.checkpoint_dir = config_dict["checkpoint_dir"] self.normalize = config_dict["normalize"] self.warmup_epochs = config_dict["warmup_epochs"] - self.num_shards = config_dict["num_shards"] - self.shard_dim = config_dict["shard_dim"] self.dataset_reuse_enforce_commit_id = config_dict[ "dataset_reuse_enforce_commit_id" ] self.target_dice = config_dict["target_dice"] self.checkpoint_interval = config_dict["checkpoint_interval"] + self.dc_num_shards = config_dict["dc_num_shards"] + self.dc_shard_dims = config_dict["dc_shard_dims"] + import math # NOTE(review): needed for math.prod below; this patch adds no top-level "import math" to config_utils.py + self.dc_total_shards = math.prod(self.dc_num_shards) + # Safety Check: Length mismatch + if len(self.dc_num_shards) != len(self.dc_shard_dims): + raise ValueError( + f"Configuration Mismatch: num_shards {self.dc_num_shards} " + f"must have same length as shard_dim {self.dc_shard_dims}" + ) + class RunConfig(Config): def __init__(self, config_dict): diff --git a/ScaFFold/utils/dice_score.py b/ScaFFold/utils/dice_score.py index 6536345..ed60fd4 100644 --- a/ScaFFold/utils/dice_score.py +++ b/ScaFFold/utils/dice_score.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: (Apache-2.0) import torch +import torch.distributed as dist from torch import Tensor from ScaFFold.utils.perf_measure import annotate @@ -59,3 +60,56 @@ def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): # Dice loss (objective to minimize) between 0 and 1 fn = multiclass_dice_coeff if multiclass else dice_coeff return 1 - fn(input, target, reduce_batch_first=True) + + +class SpatialAllReduce(torch.autograd.Function): + @staticmethod + def forward(ctx, input, spatial_mesh): + output = input.clone() + for mesh_dim in range(spatial_mesh.ndim): + pg = spatial_mesh.get_group(mesh_dim) + dist.all_reduce(output, op=dist.ReduceOp.SUM, 
group=pg) + return output + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +@annotate() +def compute_sharded_dice( + preds: torch.Tensor, + targets: torch.Tensor, + spatial_mesh, + epsilon: float = 1e-6, +): + """ + Computes the globally sharded Dice score. + Returns the raw score tensor of shape [Batch, Channels]. + """ + assert preds.size() == targets.size(), ( + f"Shape mismatch: {preds.size()} vs {targets.size()}" + ) + assert preds.dim() == 5, f"Expected 5D tensor, got {preds.dim()}D" + + sum_dim = (-1, -2, -3) # D, H, W + + local_inter = 2.0 * (preds * targets).sum(dim=sum_dim) + local_sets_sum_raw = preds.sum(dim=sum_dim) + targets.sum(dim=sum_dim) + + packed = torch.stack([local_inter, local_sets_sum_raw]) + + # Global reduce across spatial mesh + packed_global = SpatialAllReduce.apply(packed, spatial_mesh) + + global_inter = packed_global[0] + global_sets_sum_raw = packed_global[1] + + global_sets_sum = torch.where( + global_sets_sum_raw == 0, global_inter, global_sets_sum_raw + ) + + # Calculate score + dice_score = (global_inter + epsilon) / (global_sets_sum + epsilon) + + return dice_score diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index c2d0672..fbecf90 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -12,13 +12,22 @@ # # SPDX-License-Identifier: (Apache-2.0) +import math + +import numpy as np import torch import torch.nn.functional as F from distconv import DCTensor from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor from tqdm import tqdm -from ScaFFold.utils.dice_score import dice_coeff, dice_loss, multiclass_dice_coeff +from ScaFFold.utils.dice_score import ( + SpatialAllReduce, + compute_sharded_dice, + dice_coeff, + dice_loss, + multiclass_dice_coeff, +) from ScaFFold.utils.perf_measure import annotate @@ -29,10 +38,11 @@ def evaluate( ): net.eval() num_val_batches = len(dataloader) - dice_score = 0.0 + total_dice_score = 0.0 
processed_batches = 0 - # For reference, dc sharding happens on this spatial dim: 2=D, 3=H, 4=W + spatial_mesh = parallel_strategy.device_mesh[parallel_strategy.distconv_dim_names] + if primary: print( f"[eval] ps.shard_dim={parallel_strategy.shard_dim} num_shards={parallel_strategy.num_shards}" @@ -50,92 +60,80 @@ def evaluate( ): image, mask_true = batch["image"], batch["mask"] - # move images and labels to correct device and type image = image.to( device=device, dtype=torch.float32, - memory_format=torch.channels_last_3d, # NDHWC (channels last) vs NCDHW (channels first) + memory_format=torch.channels_last_3d, ) - mask_true = mask_true.to( - device=device, dtype=torch.long - ).contiguous() # masks no channels NDHW, but ensure cotinuity. + mask_true = mask_true.to(device=device, dtype=torch.long).contiguous() + + # Dummy channel dimension [B, 1, D, H, W] + mask_true = mask_true.unsqueeze(1) - # Shard batch across ddp mesh, replicate across dc mesh - image_dp = distribute_tensor( - image, parallel_strategy.device_mesh, placements=[Shard(0), Replicate()] + # DDP Sharding + ddp_placements = [Shard(0)] + [Replicate()] * len( + parallel_strategy.shard_dim + ) + image_dp = DTensor.from_local( + image, parallel_strategy.device_mesh, placements=ddp_placements ).to_local() - mask_true_dp = distribute_tensor( - mask_true, - parallel_strategy.device_mesh, - placements=[Shard(0), Replicate()], + mask_true_dp = DTensor.from_local( + mask_true, parallel_strategy.device_mesh, placements=ddp_placements ).to_local() - # Spatially shard images along the dc mesh and run the model + # DistConv Spatial Sharding dcx = DCTensor.distribute(image_dp, parallel_strategy) - dcy = net(dcx) + mask_true_dc = DCTensor.distribute(mask_true_dp, parallel_strategy) - # Replicate predictions across dc to get full spatial result on each dc rank - mask_pred = dcy.to_replicate() + # Forward pass on sharded data + dcy = net(dcx) - # Use labels that are replicated across dc and sharded across ddp, like 
predictions - mask_true_ddp = mask_true_dp + # Extract underlying local tensors (STAY SHARDED) + local_preds = dcy + local_labels_5d = mask_true_dc + local_labels = local_labels_5d.squeeze(1) - # Skip if this ddp rank has an empty local batch - if mask_pred.size(0) == 0 or mask_true_ddp.size(0) == 0: + # Skip empty batches + if local_preds.size(0) == 0 or local_labels.size(0) == 0: continue - # Loss - CE_loss = criterion(mask_pred, mask_true_ddp) + # --- 1. Sharded CE Loss --- + local_ce_sum = F.cross_entropy(local_preds, local_labels, reduction="sum") + global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) - # Dice loss - mask_pred_softmax = F.softmax(mask_pred, dim=1).float() + # Divide by total global voxels to get the mean CE Loss + global_total_voxels = local_labels.numel() * math.prod( + parallel_strategy.num_shards + ) + CE_loss = global_ce_sum / global_total_voxels + + # --- 2. Format Predictions & Labels (Strictly Multiclass) --- + mask_pred_probs = F.softmax(local_preds, dim=1).float() mask_true_onehot = ( - F.one_hot(mask_true_ddp, n_categories + 1) - .permute(0, 4, 1, 2, 3) - .float() + F.one_hot(local_labels, n_categories + 1).permute(0, 4, 1, 2, 3).float() ) - dice_loss_curr = dice_loss( - mask_pred_softmax, - mask_true_onehot, - multiclass=True, + + # Dice loss uses probabilities + dice_score_probs = compute_sharded_dice( + mask_pred_probs, mask_true_onehot, spatial_mesh ) + dice_loss_curr = 1.0 - dice_score_probs.mean() - # Combined validation loss + # Eval metric (excluding background class 0) + # dice_score_probs shape is [Batch, Channels]. 
We slice [:, 1:] to drop background + batch_dice_score = dice_score_probs[:, 1:].mean() + + # --- Combine and Accumulate --- loss = CE_loss + dice_loss_curr val_loss_epoch += loss.item() + total_dice_score += batch_dice_score.item() processed_batches += 1 - # Dice score - if net.module.n_classes == 1: - assert mask_true_ddp.min() >= 0 and mask_true_ddp.max() <= 1, ( - "True mask indices should be in [0, 1]" - ) - mask_pred_bin = (F.sigmoid(mask_pred) > 0.5).float() - dice_score += dice_coeff( - mask_pred_bin, mask_true_ddp, reduce_batch_first=False - ) - else: - assert ( - mask_true_ddp.min() >= 0 - and mask_true_ddp.max() < net.module.n_classes - ), "True mask indices should be in [0, n_classes]" - mask_pred_processed = F.softmax(mask_pred, dim=1).float() - mask_true_onehot_mc = ( - F.one_hot(mask_true_ddp, net.module.n_classes) - .permute(0, 4, 1, 2, 3) - .float() - ) - dice_score += multiclass_dice_coeff( - mask_pred_processed[:, 1:], - mask_true_onehot_mc[:, 1:], - reduce_batch_first=True, - ) - net.train() val_loss_avg = val_loss_epoch / max(processed_batches, 1) if primary: print( - f"evaluate.py: dice_score={dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}" + f"evaluate.py: dice_score={total_dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}" ) - return dice_score, val_loss_epoch, val_loss_avg, processed_batches + return total_dice_score, val_loss_epoch, val_loss_avg, processed_batches diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 29d6807..20f9251 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -35,7 +35,7 @@ from ScaFFold.utils.checkpointing import CheckpointManager from ScaFFold.utils.data_loading import FractalDataset -from ScaFFold.utils.dice_score import dice_loss +from ScaFFold.utils.dice_score import SpatialAllReduce, compute_sharded_dice, dice_loss from 
ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size # Local @@ -317,6 +317,15 @@ def _truncate_stats_file(self, start_epoch): except Exception as e: self.log.warning(f"Failed to truncate stats file: {e}") + def _get_memsize(self, tensor, tensor_label: str, verbosity: int = 0): + """Log size of tensor in memory""" + + if verbosity < 2: + return + tensor_memory_bytes = tensor[0].element_size() * tensor[0].nelement() + tensor_memory_gb = tensor_memory_bytes / (1024**3) + self.log.info(f"{tensor_label} size on GPU: {tensor_memory_gb:.2f} GB") + def train(self): """ Execute model training @@ -324,82 +333,168 @@ def train(self): self.cleanup_or_resume() + # DistConv ParallelStrategy + ps = getattr(self.config, "_parallel_strategy", None) + if ps is None: + raise RuntimeError( + "ParallelStrategy not found in config. Set config._parallel_strategy when wrapping model with DistConvDDP." + ) + # Get the process group for spatial sharding mesh + spatial_mesh = ps.device_mesh[ps.distconv_dim_names] + + # Get placements for DDP sharding + num_spatial_dims = len(ps.shard_dim) + ddp_placements = [Shard(0)] + [Replicate()] * num_spatial_dims + warmup_epochs = self.config.warmup_epochs if warmup_epochs > 0: begin_code_region("warmup") # Keep BN/Dropout from changing behavior/statistics - self.model.eval() + self.model.train() start_warmup = time.time() self.log.info(f"Running {warmup_epochs} warmup epoch(s)") - ps = getattr(self.config, "_parallel_strategy", None) - for _ in range(warmup_epochs): - for batch in self.train_loader: + for i, batch in enumerate(self.train_loader): + self.log.debug(f" warmup: batch {i} / {len(self.train_loader)}") + batch_t_start = time.time() + # Load initial samples and labels images, true_masks = batch["image"], batch["mask"] + # Move samples and labels to GPU images = images.to( device=self.device, dtype=torch.float32, memory_format=torch.channels_last_3d, - non_blocking=False, + non_blocking=True, ) - images_dc = 
DCTensor.distribute(images, ps) - + self._get_memsize(images, "Original image", self.config.verbose) true_masks = true_masks.to( device=self.device, dtype=torch.long, non_blocking=True ) + self._get_memsize(images, "Original label", self.config.verbose) + + # Add a dummy channel dimension to get 5D [B, 1, D, H, W] + true_masks = true_masks.unsqueeze(1) + + # Data parallel sharding + images_dp = DTensor.from_local( + images, ps.device_mesh, placements=ddp_placements + ).to_local() + + true_masks_dp = DTensor.from_local( + true_masks, ps.device_mesh, placements=ddp_placements + ).to_local() + + # Delete source tensors immediately after use to keep memory down + del images, true_masks + + # Spatial sharding via DistConv + images_dc = DCTensor.distribute(images_dp, ps) + true_masks_dc = DCTensor.distribute(true_masks_dp, ps) + self._get_memsize(images_dc, "Sharded image", self.config.verbose) with torch.autocast( self.device.type if self.device.type != "mps" else "cpu", enabled=self.config.torch_amp, ): # Forward on DCTensor + self.log.debug(f" warmup: running forward pass") masks_pred_dc = self.model(images_dc) + self.log.debug(f" warmup: forward pass complete") - # Convert predictions for loss - if isinstance(ps.num_shards, tuple) and len(ps.num_shards) == 1: - n_shards = ps.num_shards[0] - else: - n_shards = ps.num_shards - if images.size(0) < n_shards: - # For small batches (e.g., N=1 with dc_num_shards=2), replicate outputs - masks_pred = masks_pred_dc.to_replicate() - labels_for_loss = true_masks - else: - # Otherwise, shard labels across batch dim to match to_ddp layout - masks_pred = masks_pred_dc.to_ddp() - dt_labels = distribute_tensor( - true_masks, - device_mesh=ps.device_mesh[ - f"dc{self.config.shard_dim + 2}" - ], - placements=[Shard(0)], + # Extract the underlying PyTorch local tensors + local_preds = masks_pred_dc + local_labels_5d = true_masks_dc + + # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] + local_labels = 
local_labels_5d.squeeze(1) + if self.world_rank == 0: + self.log.debug( + f" warmup: Local Preds Shape: {local_preds.shape}" + ) + # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 + self.log.debug( + f" warmup: Local Labels Shape: {local_labels.shape}" ) - labels_for_loss = dt_labels.to_local() + # Should be something like [1, 128, 128, 64] + + # --- SHARDED LOSS CALCULATION --- + current_mem = torch.cuda.memory_allocated() / (1024**3) + self.log.debug( + f" warmup: Calculating sharded loss. Mem: {current_mem:.2f} GB." + ) + + # 1. Sharded Cross Entropy + local_ce_sum = F.cross_entropy( + local_preds, local_labels, reduction="sum" + ) - CE_loss = self.criterion(masks_pred, labels_for_loss) + # Pass the spatial_mesh directly + global_ce_sum = SpatialAllReduce.apply( + local_ce_sum, spatial_mesh + ) + + global_total_voxels = local_labels.numel() * math.prod( + self.config.dc_num_shards + ) + loss_ce = global_ce_sum / global_total_voxels - # Calculate the train dice loss - masks_pred_softmax = F.softmax(masks_pred, dim=1).float() - true_masks_onehot = ( - F.one_hot(labels_for_loss, self.config.n_categories + 1) + # 2. Sharded Dice Loss + local_preds_softmax = F.softmax(local_preds, dim=1).float() + local_labels_one_hot = ( + F.one_hot( + local_labels, num_classes=self.config.n_categories + 1 + ) .permute(0, 4, 1, 2, 3) .float() ) - train_dice_curr = dice_loss( - masks_pred_softmax, - true_masks_onehot, - multiclass=True, + dice_scores = compute_sharded_dice( + local_preds_softmax, local_labels_one_hot, spatial_mesh ) - loss = CE_loss + train_dice_curr + loss_dice = 1.0 - dice_scores.mean() - # Fine as long as we don't step/update - self.grad_scaler.scale(loss).backward() + # 3. Combine Loss + loss = loss_ce + loss_dice + + self.log.debug( + f" warmup: loss calculation complete. Proceeding to backward pass" + ) + + # Backward pass + self.grad_scaler.scale(loss).backward() + self.log.debug( + f" warmup: backward pass complete. 
Stepping optimizer" + ) + + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + + # Free memory aggressively + del images_dc, true_masks_dc, masks_pred_dc + del ( + local_preds, + local_labels, + local_preds_softmax, + local_labels_one_hot, + ) + del loss_ce, loss_dice, loss, images_dp, true_masks_dp + + if self.world_rank == 0: + peak_alloc = torch.cuda.max_memory_allocated() / (1024**3) + peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) + self.log.debug( + f"[MEM-PEAK] Peak alloc: {peak_alloc:.2f} GiB | Peak reserved: {peak_reserved:.2f} GiB", + ) + batch_t_end = time.time() + self.log.debug( + f" warmup: batch {i} completed in {batch_t_end - batch_t_start} seconds" + ) # Nuke any accumulated grads so the first real step starts clean for p in self.model.parameters(): p.grad = None + self.optimizer.zero_grad(set_to_none=True) torch.distributed.barrier() end_code_region("warmup") self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s") @@ -415,13 +510,6 @@ def train(self): ) break - # DistConv ParallelStrategy - ps = getattr(self.config, "_parallel_strategy", None) - if ps is None: - raise RuntimeError( - "ParallelStrategy not found in config. Set config._parallel_strategy when wrapping model with DistConvDDP." - ) - # Timer and tracking variables epoch_start_time = time.time() train_dice_curr = 0 @@ -451,6 +539,7 @@ def train(self): begin_code_region("batch_loop") for batch in self.train_loader: + # Load initial samples and labels images, true_masks = batch["image"], batch["mask"] begin_code_region("image_to_device") @@ -466,18 +555,28 @@ def train(self): end_code_region("image_to_device") gather_and_print_mem(self.log, "after_batch_to_device") - # Replicate batch across dc mesh, shard batch across ddp mesh. - # This ensures all dc ranks in the same ddp group see the same samples, - # and ddp ranks see disjoint samples. 
- images_dp = distribute_tensor( - images, ps.device_mesh, placements=[Shard(0), Replicate()] + # Add a dummy channel dimension to get 5D [B, 1, D, H, W] + true_masks = true_masks.unsqueeze(1) + + # Data parallel sharding + images_dp = DTensor.from_local( + images, ps.device_mesh, placements=ddp_placements ).to_local() - true_masks_dp = distribute_tensor( - true_masks, - ps.device_mesh, - placements=[Shard(0), Replicate()], + + true_masks_dp = DTensor.from_local( + true_masks, ps.device_mesh, placements=ddp_placements ).to_local() + # Delete source tensors immediately after use to keep memory down + del images, true_masks + + # Spatial sharding via DistConv + images_dc = DCTensor.distribute(images_dp, ps) + true_masks_dc = DCTensor.distribute(true_masks_dp, ps) + self._get_memsize( + images_dc, "Sharded image", self.config.verbose + ) + with torch.autocast( self.device.type if self.device.type != "mps" else "cpu", enabled=self.config.torch_amp, @@ -486,72 +585,69 @@ def train(self): torch.cuda.reset_peak_memory_stats() gather_and_print_mem(self.log, "pre_forward") begin_code_region("predict") - - # Spatially shard the chosen dimension across dc mesh - dcx = DCTensor.distribute(images_dp, ps) - dcy = self.model(dcx) - # Convert back to batch-sharded layout across the dc mesh - masks_pred = dcy.to_ddp() - + masks_pred_dc = self.model(images_dc) end_code_region("predict") gather_and_print_mem(self.log, "post_forward") - # Reshard labels across dc mesh to match masks_pred's batch partition - # Start from dc-replicated labels, then shard batch across dc - true_masks_ddp = ( - DTensor.from_local( - true_masks_dp, - device_mesh=ps.device_mesh[ - f"dc{self.config.shard_dim + 2}" - ], - placements=[Replicate()], + # Extract the underlying PyTorch local tensors + local_preds = masks_pred_dc + local_labels_5d = true_masks_dc + + # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] + local_labels = local_labels_5d.squeeze(1) + if self.world_rank == 0: + 
self.log.debug( + f"Local Preds Shape: {local_preds.shape}" ) - .redistribute( - device_mesh=ps.device_mesh[ - f"dc{self.config.shard_dim + 2}" - ], - placements=[Shard(0)], + # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 + self.log.debug( + f"Local Labels Shape: {local_labels.shape}" ) - .to_local() - ) + # Should be something like [1, 128, 128, 64] begin_code_region("calculate_loss") - # Calculate the loss - if self.config.n_categories + 1 == 1: - loss = self.criterion( - masks_pred.squeeze(1), true_masks_ddp.float() - ) - loss += dice_loss( - F.sigmoid(masks_pred.squeeze(1)), - true_masks_ddp.float(), - multiclass=False, - ) - else: - # Calculate the CrossEntropy loss - CE_loss = self.criterion(masks_pred, true_masks_ddp) - - # Calculate the train dice loss - masks_pred_softmax = F.softmax( - masks_pred, dim=1 - ).float() - true_masks_onehot = ( - F.one_hot( - true_masks_ddp, self.config.n_categories + 1 - ) - .permute(0, 4, 1, 2, 3) - .float() - ) - train_dice_curr = dice_loss( - masks_pred_softmax, - true_masks_onehot, - multiclass=True, + # --- SHARDED LOSS CALCULATION --- + current_mem = torch.cuda.memory_allocated() / (1024**3) + self.log.debug( + f"Calculating sharded loss. Mem: {current_mem:.2f} GB." + ) + + # 1. Sharded Cross Entropy + local_ce_sum = F.cross_entropy( + local_preds, local_labels, reduction="sum" + ) + + # Pass the spatial_mesh directly + global_ce_sum = SpatialAllReduce.apply( + local_ce_sum, spatial_mesh + ) + + global_total_voxels = local_labels.numel() * math.prod( + self.config.dc_num_shards + ) + loss_ce = global_ce_sum / global_total_voxels + + # 2. 
Sharded Dice Loss + local_preds_softmax = F.softmax(local_preds, dim=1).float() + local_labels_one_hot = ( + F.one_hot( + local_labels, + num_classes=self.config.n_categories + 1, ) + .permute(0, 4, 1, 2, 3) + .float() + ) + + # Compute sharded dice using new function + dice_scores = compute_sharded_dice( + local_preds_softmax, local_labels_one_hot, spatial_mesh + ) + loss_dice = 1.0 - dice_scores.mean() - # Our loss function is CE loss + dice loss - loss = CE_loss + train_dice_curr + # 3. Combine Loss + loss = loss_ce + loss_dice + train_dice_total += dice_scores[:, 1:].mean().item() - # Track the train dice loss separately for debugging - train_dice_total += train_dice_curr end_code_region("calculate_loss") gather_and_print_mem(self.log, "pre_backward") @@ -562,6 +658,7 @@ def train(self): begin_code_region("step_and_update") if batch_step + 1 == len(self.train_loader): + self.grad_scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_( self.model.parameters(), max_norm=1.0 ) @@ -574,7 +671,7 @@ def train(self): # Update the loss begin_code_region("update_loss") - pbar.update(images_dp.shape[0]) + pbar.update(images_dc.shape[0]) self.global_step += 1 batch_step += 1 # Stay on GPU diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 431c2eb..8bad2ea 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -12,6 +12,7 @@ # # SPDX-License-Identifier: (Apache-2.0) +import math import os import socket import sys @@ -161,10 +162,10 @@ def main(kwargs_dict: dict = {}): # Initialize model begin_code_region("init_model") - config.dc_num_shards = getattr(config, "dc_num_shards", config.num_shards) - config.dc_shard_dim = getattr(config, "dc_shard_dim", config.shard_dim) + config.dc_num_shards = getattr(config, "dc_num_shards", config.dc_num_shards) + config.dc_shard_dims = getattr(config, "dc_shard_dims", config.dc_shard_dims) log.info( - f"DistConv num_shards={config.dc_num_shards}, shard_dim={config.dc_shard_dim}" + f"DistConv 
num_shards={config.dc_num_shards}, shard_dim={config.dc_shard_dims}" ) device = get_device() log.info(f"Using device: {device}") @@ -176,17 +177,17 @@ def main(kwargs_dict: dict = {}): ) if config.dist: # DDP + DistConv setup - # Ensure world_size is divisible by dc_num_shards - assert dist.get_world_size() % config.dc_num_shards == 0, ( - f"world_size={dist.get_world_size()} must be divisible by dc_num_shards={config.dc_num_shards}" + # Ensure world_size is divisible by total distconv shards + assert dist.get_world_size() % math.prod(config.dc_num_shards) == 0, ( + f"world_size={dist.get_world_size()} must be divisible by total number of distconv shards = {math.prod(config.dc_num_shards)}" ) - # Select which full-tensor dim to shard: 2 + dc_shard_dim - shard_dim = 2 + int(config.dc_shard_dim) + ps = ParallelStrategy( - num_shards=int(config.dc_num_shards), - shard_dim=shard_dim, + num_shards=config.dc_num_shards, + shard_dim=config.dc_shard_dims, device_type=device.type, ) + model = model.to(device, memory_format=torch.channels_last_3d) # Wrap with DistConvDDP that corrects gradient scaling for dc submesh model = DistConvDDP(