Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ScaFFold/utils/perf_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def get_torch_context(ranks_per_node, rank):
torchprofile(
activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
record_shapes=True,
with_stack=True,
with_stack=False,
)
if TORCH_PERF_LOCAL
else nullcontext()
Expand Down
142 changes: 73 additions & 69 deletions ScaFFold/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,82 +326,86 @@ def train(self):

warmup_epochs = self.config.warmup_epochs
if warmup_epochs > 0:
begin_code_region("warmup")
# Keep BN/Dropout from changing behavior/statistics
self.model.eval()
start_warmup = time.time()
self.log.info(f"Running {warmup_epochs} warmup epoch(s)")

ps = getattr(self.config, "_parallel_strategy", None)

for _ in range(warmup_epochs):
for batch in self.train_loader:
images, true_masks = batch["image"], batch["mask"]

images = images.to(
device=self.device,
dtype=torch.float32,
memory_format=torch.channels_last_3d,
non_blocking=False,
)
images_dc = DCTensor.distribute(images, ps)

true_masks = true_masks.to(
device=self.device, dtype=torch.long, non_blocking=True
)
with torch.profiler.record_function("warmup"):
begin_code_region("warmup")
# Keep BN/Dropout from changing behavior/statistics
self.model.eval()
start_warmup = time.time()
self.log.info(f"Running {warmup_epochs} warmup epoch(s)")

with torch.autocast(
self.device.type if self.device.type != "mps" else "cpu",
enabled=self.config.torch_amp,
):
# Forward on DCTensor
masks_pred_dc = self.model(images_dc)

# Convert predictions for loss
if isinstance(ps.num_shards, tuple) and len(ps.num_shards) == 1:
n_shards = ps.num_shards[0]
else:
n_shards = ps.num_shards
if images.size(0) < n_shards:
# For small batches (e.g., N=1 with dc_num_shards=2), replicate outputs
masks_pred = masks_pred_dc.to_replicate()
labels_for_loss = true_masks
else:
# Otherwise, shard labels across batch dim to match to_ddp layout
masks_pred = masks_pred_dc.to_ddp()
dt_labels = distribute_tensor(
true_masks,
device_mesh=ps.device_mesh[
f"dc{self.config.shard_dim + 2}"
],
placements=[Shard(0)],
)
labels_for_loss = dt_labels.to_local()
ps = getattr(self.config, "_parallel_strategy", None)

CE_loss = self.criterion(masks_pred, labels_for_loss)
for _ in range(warmup_epochs):
for batch in self.train_loader:
images, true_masks = batch["image"], batch["mask"]

# Calculate the train dice loss
masks_pred_softmax = F.softmax(masks_pred, dim=1).float()
true_masks_onehot = (
F.one_hot(labels_for_loss, self.config.n_categories + 1)
.permute(0, 4, 1, 2, 3)
.float()
images = images.to(
device=self.device,
dtype=torch.float32,
memory_format=torch.channels_last_3d,
non_blocking=False,
)
train_dice_curr = dice_loss(
masks_pred_softmax,
true_masks_onehot,
multiclass=True,
images_dc = DCTensor.distribute(images, ps)

true_masks = true_masks.to(
device=self.device, dtype=torch.long, non_blocking=True
)
loss = CE_loss + train_dice_curr

# Fine as long as we don't step/update
self.grad_scaler.scale(loss).backward()
with torch.autocast(
self.device.type if self.device.type != "mps" else "cpu",
enabled=self.config.torch_amp,
):
# Forward on DCTensor
masks_pred_dc = self.model(images_dc)

# Convert predictions for loss
if (
isinstance(ps.num_shards, tuple)
and len(ps.num_shards) == 1
):
n_shards = ps.num_shards[0]
else:
n_shards = ps.num_shards
if images.size(0) < n_shards:
# For small batches (e.g., N=1 with dc_num_shards=2), replicate outputs
masks_pred = masks_pred_dc.to_replicate()
labels_for_loss = true_masks
else:
# Otherwise, shard labels across batch dim to match to_ddp layout
masks_pred = masks_pred_dc.to_ddp()
dt_labels = distribute_tensor(
true_masks,
device_mesh=ps.device_mesh[
f"dc{self.config.shard_dim + 2}"
],
placements=[Shard(0)],
)
labels_for_loss = dt_labels.to_local()

CE_loss = self.criterion(masks_pred, labels_for_loss)

# Calculate the train dice loss
masks_pred_softmax = F.softmax(masks_pred, dim=1).float()
true_masks_onehot = (
F.one_hot(labels_for_loss, self.config.n_categories + 1)
.permute(0, 4, 1, 2, 3)
.float()
)
train_dice_curr = dice_loss(
masks_pred_softmax,
true_masks_onehot,
multiclass=True,
)
loss = CE_loss + train_dice_curr

# Fine as long as we don't step/update
self.grad_scaler.scale(loss).backward()

# Nuke any accumulated grads so the first real step starts clean
for p in self.model.parameters():
p.grad = None
torch.distributed.barrier()
end_code_region("warmup")
# Nuke any accumulated grads so the first real step starts clean
for p in self.model.parameters():
p.grad = None
torch.distributed.barrier()
end_code_region("warmup")
self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s")

epoch = 1
Expand Down
Loading