From e646908ad333c5c3a26179575be939a6ca42c294 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 19 Mar 2026 16:30:16 -0700 Subject: [PATCH 1/3] Annotate warmup for pytorch profiler --- ScaFFold/utils/trainer.py | 139 +++++++++++++++++++------------------- 1 file changed, 70 insertions(+), 69 deletions(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 29d6807..098892b 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -326,82 +326,83 @@ def train(self): warmup_epochs = self.config.warmup_epochs if warmup_epochs > 0: - begin_code_region("warmup") - # Keep BN/Dropout from changing behavior/statistics - self.model.eval() - start_warmup = time.time() - self.log.info(f"Running {warmup_epochs} warmup epoch(s)") - - ps = getattr(self.config, "_parallel_strategy", None) - - for _ in range(warmup_epochs): - for batch in self.train_loader: - images, true_masks = batch["image"], batch["mask"] - - images = images.to( - device=self.device, - dtype=torch.float32, - memory_format=torch.channels_last_3d, - non_blocking=False, - ) - images_dc = DCTensor.distribute(images, ps) - - true_masks = true_masks.to( - device=self.device, dtype=torch.long, non_blocking=True - ) + with torch.profiler.record_function("warmup"): + begin_code_region("warmup") + # Keep BN/Dropout from changing behavior/statistics + self.model.eval() + start_warmup = time.time() + self.log.info(f"Running {warmup_epochs} warmup epoch(s)") - with torch.autocast( - self.device.type if self.device.type != "mps" else "cpu", - enabled=self.config.torch_amp, - ): - # Forward on DCTensor - masks_pred_dc = self.model(images_dc) - - # Convert predictions for loss - if isinstance(ps.num_shards, tuple) and len(ps.num_shards) == 1: - n_shards = ps.num_shards[0] - else: - n_shards = ps.num_shards - if images.size(0) < n_shards: - # For small batches (e.g., N=1 with dc_num_shards=2), replicate outputs - masks_pred = masks_pred_dc.to_replicate() - labels_for_loss = true_masks - else: - # Otherwise, shard labels across batch dim to match to_ddp layout - masks_pred = masks_pred_dc.to_ddp() - dt_labels = distribute_tensor( - true_masks, - device_mesh=ps.device_mesh[ - f"dc{self.config.shard_dim + 2}" - ], - placements=[Shard(0)], - ) - labels_for_loss = dt_labels.to_local() + ps = getattr(self.config, "_parallel_strategy", None) - CE_loss = self.criterion(masks_pred, labels_for_loss) + for _ in range(warmup_epochs): + for batch in self.train_loader: + images, true_masks = batch["image"], batch["mask"] - # Calculate the train dice loss - masks_pred_softmax = F.softmax(masks_pred, dim=1).float() - true_masks_onehot = ( - F.one_hot(labels_for_loss, self.config.n_categories + 1) - .permute(0, 4, 1, 2, 3) - .float() + images = images.to( + device=self.device, + dtype=torch.float32, + memory_format=torch.channels_last_3d, + non_blocking=False, ) - train_dice_curr = dice_loss( - masks_pred_softmax, - true_masks_onehot, - multiclass=True, + images_dc = DCTensor.distribute(images, ps) + + true_masks = true_masks.to( + device=self.device, dtype=torch.long, non_blocking=True ) - loss = CE_loss + train_dice_curr - # Fine as long as we don't step/update - self.grad_scaler.scale(loss).backward() + with torch.autocast( + self.device.type if self.device.type != "mps" else "cpu", + enabled=self.config.torch_amp, + ): + # Forward on DCTensor + masks_pred_dc = self.model(images_dc) + + # Convert predictions for loss + if isinstance(ps.num_shards, tuple) and len(ps.num_shards) == 1: + n_shards = ps.num_shards[0] + else: + n_shards = ps.num_shards + if images.size(0) < n_shards: + # For small batches (e.g., N=1 with dc_num_shards=2), replicate outputs + masks_pred = masks_pred_dc.to_replicate() + labels_for_loss = true_masks + else: + # Otherwise, shard labels across batch dim to match to_ddp layout + masks_pred = masks_pred_dc.to_ddp() + dt_labels = distribute_tensor( + true_masks, + device_mesh=ps.device_mesh[ + f"dc{self.config.shard_dim + 2}" + ], + placements=[Shard(0)], + ) + labels_for_loss = dt_labels.to_local() + + CE_loss = self.criterion(masks_pred, labels_for_loss) + + # Calculate the train dice loss + masks_pred_softmax = F.softmax(masks_pred, dim=1).float() + true_masks_onehot = ( + F.one_hot(labels_for_loss, self.config.n_categories + 1) + .permute(0, 4, 1, 2, 3) + .float() + ) + train_dice_curr = dice_loss( + masks_pred_softmax, + true_masks_onehot, + multiclass=True, + ) + loss = CE_loss + train_dice_curr + + # Fine as long as we don't step/update + self.grad_scaler.scale(loss).backward() - # Nuke any accumulated grads so the first real step starts clean - for p in self.model.parameters(): - p.grad = None - torch.distributed.barrier() - end_code_region("warmup") + # Nuke any accumulated grads so the first real step starts clean + for p in self.model.parameters(): + p.grad = None + torch.distributed.barrier() + end_code_region("warmup") self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s") epoch = 1 From a0bbe7dc71b5ae7471fa61a6b9ae01867cef5998 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 19 Mar 2026 16:33:04 -0700 Subject: [PATCH 2/3] Turn off stack --- ScaFFold/utils/perf_measure.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ScaFFold/utils/perf_measure.py b/ScaFFold/utils/perf_measure.py index 5af8d5b..efa6452 100644 --- a/ScaFFold/utils/perf_measure.py +++ b/ScaFFold/utils/perf_measure.py @@ -97,7 +97,7 @@ def get_torch_context(ranks_per_node, rank): torchprofile( activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], record_shapes=True, - with_stack=True, + with_stack=False, ) if TORCH_PERF_LOCAL else nullcontext() From 19c4b4a4acc511a3ea04613999abcff62da698fb Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 19 Mar 2026 16:37:53 -0700 Subject: [PATCH 3/3] lint --- ScaFFold/utils/trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 098892b..7fc597c 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -359,7 +359,10 @@ def train(self): masks_pred_dc = self.model(images_dc) # Convert predictions for loss - if isinstance(ps.num_shards, tuple) and len(ps.num_shards) == 1: + if ( + isinstance(ps.num_shards, tuple) + and len(ps.num_shards) == 1 + ): n_shards = ps.num_shards[0] else: n_shards = ps.num_shards