Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ScaFFold/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,9 @@ def main():
help="Resume execution in this specific directory. Overrides --base-run-dir.",
)
benchmark_parser.add_argument(
"--num-shards",
"--dc-num-shards",
type=int,
nargs=3,
help="DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum",
)
benchmark_parser.add_argument(
Expand Down
8 changes: 4 additions & 4 deletions ScaFFold/configs/benchmark_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ dataset_dir: "datasets" # Directory in which to store and query for d
fract_base_dir: "fractals" # Base directory for fractal IFS and instances.
n_categories: 5 # Number of fractal categories present in the dataset.
n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances.
problem_scale: 6 # Determines dataset resolution and number of unet layers. Default is 6.
problem_scale: 8 # Determines dataset resolution and number of unet layers. Default is 6.
unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8.
seed: 42 # Random seed.
batch_size: 1 # Batch sizes for each vol size.
optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defaults to RMSProp.
num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum
shard_dim: 2 # DistConv param: dimension on which to shard
checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems.
dc_num_shards: [1, 1, 2] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum
dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard
checkpoint_interval: 100 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems.

# Internal/dev use only
variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15.
Expand Down
6 changes: 3 additions & 3 deletions ScaFFold/configs/benchmark_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dim
seed: 42 # Random seed.
batch_size: 1 # Batch sizes for each vol size.
optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defaults to RMSProp.
num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum
shard_dim: 2 # DistConv param: dimension on which to shard
checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems.
dc_num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum
dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard
checkpoint_interval: 100 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems.

# Internal/dev use only
variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15.
Expand Down
12 changes: 10 additions & 2 deletions ScaFFold/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,22 @@ def __init__(self, config_dict):
self.checkpoint_dir = config_dict["checkpoint_dir"]
self.normalize = config_dict["normalize"]
self.warmup_epochs = config_dict["warmup_epochs"]
self.num_shards = config_dict["num_shards"]
self.shard_dim = config_dict["shard_dim"]
self.dataset_reuse_enforce_commit_id = config_dict[
"dataset_reuse_enforce_commit_id"
]
self.target_dice = config_dict["target_dice"]
self.checkpoint_interval = config_dict["checkpoint_interval"]

self.dc_num_shards = config_dict["dc_num_shards"]
self.dc_shard_dims = config_dict["dc_shard_dims"]
self.dc_total_shards = math.prod(self.dc_num_shards)
# Safety Check: Length mismatch
if len(self.dc_num_shards) != len(self.dc_shard_dims):
raise ValueError(
f"Configuration Mismatch: num_shards {self.dc_num_shards} "
f"must have same length as shard_dim {self.dc_shard_dims}"
)


class RunConfig(Config):
def __init__(self, config_dict):
Expand Down
54 changes: 54 additions & 0 deletions ScaFFold/utils/dice_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# SPDX-License-Identifier: (Apache-2.0)

import torch
import torch.distributed as dist
from torch import Tensor

from ScaFFold.utils.perf_measure import annotate
Expand Down Expand Up @@ -59,3 +60,56 @@ def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
# Dice loss (objective to minimize) between 0 and 1
fn = multiclass_dice_coeff if multiclass else dice_coeff
return 1 - fn(input, target, reduce_batch_first=True)


class SpatialAllReduce(torch.autograd.Function):
    """SUM all-reduce across every axis of a spatial device mesh.

    The forward pass reduces a copy of the input over each mesh axis in
    turn (the composition of per-axis SUM all-reduces is a SUM over the
    full mesh). The backward pass is the identity on the tensor argument
    and ``None`` for the non-tensor mesh argument.
    """

    @staticmethod
    def forward(ctx, input, spatial_mesh):
        # Reduce into a copy so the caller's tensor is left untouched.
        reduced = input.clone()
        for axis in range(spatial_mesh.ndim):
            group = spatial_mesh.get_group(axis)
            dist.all_reduce(reduced, op=dist.ReduceOp.SUM, group=group)
        return reduced

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient flows through unchanged; the mesh takes no gradient.
        return grad_output, None


@annotate()
def compute_sharded_dice(
    preds: torch.Tensor,
    targets: torch.Tensor,
    spatial_mesh,
    epsilon: float = 1e-6,
):
    """
    Compute the Dice score for spatially sharded 5D tensors.

    Per-rank intersection and set-sum terms are computed over the local
    D/H/W shard, packed into one tensor, and summed globally with a
    single all-reduce over ``spatial_mesh``; the score is then formed
    from the global sums.

    Returns the raw score tensor of shape [Batch, Channels].
    """
    assert preds.size() == targets.size(), (
        f"Shape mismatch: {preds.size()} vs {targets.size()}"
    )
    assert preds.dim() == 5, f"Expected 5D tensor, got {preds.dim()}D"

    spatial_axes = (-1, -2, -3)  # D, H, W

    # Local (per-shard) terms of the Dice formula.
    intersection_local = 2.0 * (preds * targets).sum(dim=spatial_axes)
    set_sum_local = preds.sum(dim=spatial_axes) + targets.sum(dim=spatial_axes)

    # Pack both terms so a single collective covers them.
    stacked = torch.stack([intersection_local, set_sum_local])
    stacked_global = SpatialAllReduce.apply(stacked, spatial_mesh)
    intersection, set_sum_raw = stacked_global[0], stacked_global[1]

    # Empty-mask guard: when both sets are empty the set sum is replaced
    # by the (zero) intersection, so the score resolves to eps/eps == 1.
    set_sum = torch.where(set_sum_raw == 0, intersection, set_sum_raw)

    return (intersection + epsilon) / (set_sum + epsilon)
124 changes: 61 additions & 63 deletions ScaFFold/utils/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,22 @@
#
# SPDX-License-Identifier: (Apache-2.0)

import math

import numpy as np
import torch
import torch.nn.functional as F
from distconv import DCTensor
from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor
from tqdm import tqdm

from ScaFFold.utils.dice_score import dice_coeff, dice_loss, multiclass_dice_coeff
from ScaFFold.utils.dice_score import (
SpatialAllReduce,
compute_sharded_dice,
dice_coeff,
dice_loss,
multiclass_dice_coeff,
)
from ScaFFold.utils.perf_measure import annotate


Expand All @@ -29,10 +38,11 @@ def evaluate(
):
net.eval()
num_val_batches = len(dataloader)
dice_score = 0.0
total_dice_score = 0.0
processed_batches = 0

# For reference, dc sharding happens on this spatial dim: 2=D, 3=H, 4=W
spatial_mesh = parallel_strategy.device_mesh[parallel_strategy.distconv_dim_names]

if primary:
print(
f"[eval] ps.shard_dim={parallel_strategy.shard_dim} num_shards={parallel_strategy.num_shards}"
Expand All @@ -50,92 +60,80 @@ def evaluate(
):
image, mask_true = batch["image"], batch["mask"]

# move images and labels to correct device and type
image = image.to(
device=device,
dtype=torch.float32,
memory_format=torch.channels_last_3d, # NDHWC (channels last) vs NCDHW (channels first)
memory_format=torch.channels_last_3d,
)
mask_true = mask_true.to(
device=device, dtype=torch.long
).contiguous() # masks no channels NDHW, but ensure cotinuity.
mask_true = mask_true.to(device=device, dtype=torch.long).contiguous()

# Dummy channel dimension [B, 1, D, H, W]
mask_true = mask_true.unsqueeze(1)

# Shard batch across ddp mesh, replicate across dc mesh
image_dp = distribute_tensor(
image, parallel_strategy.device_mesh, placements=[Shard(0), Replicate()]
# DDP Sharding
ddp_placements = [Shard(0)] + [Replicate()] * len(
parallel_strategy.shard_dim
)
image_dp = DTensor.from_local(
image, parallel_strategy.device_mesh, placements=ddp_placements
).to_local()
mask_true_dp = distribute_tensor(
mask_true,
parallel_strategy.device_mesh,
placements=[Shard(0), Replicate()],
mask_true_dp = DTensor.from_local(
mask_true, parallel_strategy.device_mesh, placements=ddp_placements
).to_local()

# Spatially shard images along the dc mesh and run the model
# DistConv Spatial Sharding
dcx = DCTensor.distribute(image_dp, parallel_strategy)
dcy = net(dcx)
mask_true_dc = DCTensor.distribute(mask_true_dp, parallel_strategy)

# Replicate predictions across dc to get full spatial result on each dc rank
mask_pred = dcy.to_replicate()
# Forward pass on sharded data
dcy = net(dcx)

# Use labels that are replicated across dc and sharded across ddp, like predictions
mask_true_ddp = mask_true_dp
# Extract underlying local tensors (STAY SHARDED)
local_preds = dcy
local_labels_5d = mask_true_dc
local_labels = local_labels_5d.squeeze(1)

# Skip if this ddp rank has an empty local batch
if mask_pred.size(0) == 0 or mask_true_ddp.size(0) == 0:
# Skip empty batches
if local_preds.size(0) == 0 or local_labels.size(0) == 0:
continue

# Loss
CE_loss = criterion(mask_pred, mask_true_ddp)
# --- 1. Sharded CE Loss ---
local_ce_sum = F.cross_entropy(local_preds, local_labels, reduction="sum")
global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh)

# Dice loss
mask_pred_softmax = F.softmax(mask_pred, dim=1).float()
# Divide by total global voxels to get the mean CE Loss
global_total_voxels = local_labels.numel() * math.prod(
parallel_strategy.num_shards
)
CE_loss = global_ce_sum / global_total_voxels

# --- 2. Format Predictions & Labels (Strictly Multiclass) ---
mask_pred_probs = F.softmax(local_preds, dim=1).float()
mask_true_onehot = (
F.one_hot(mask_true_ddp, n_categories + 1)
.permute(0, 4, 1, 2, 3)
.float()
F.one_hot(local_labels, n_categories + 1).permute(0, 4, 1, 2, 3).float()
)
dice_loss_curr = dice_loss(
mask_pred_softmax,
mask_true_onehot,
multiclass=True,

# Dice loss uses probabilities
dice_score_probs = compute_sharded_dice(
mask_pred_probs, mask_true_onehot, spatial_mesh
)
dice_loss_curr = 1.0 - dice_score_probs.mean()

# Combined validation loss
# Eval metric (excluding background class 0)
# dice_score_probs shape is [Batch, Channels]. We slice [:, 1:] to drop background
batch_dice_score = dice_score_probs[:, 1:].mean()

# --- Combine and Accumulate ---
loss = CE_loss + dice_loss_curr
val_loss_epoch += loss.item()
total_dice_score += batch_dice_score.item()
processed_batches += 1

# Dice score
if net.module.n_classes == 1:
assert mask_true_ddp.min() >= 0 and mask_true_ddp.max() <= 1, (
"True mask indices should be in [0, 1]"
)
mask_pred_bin = (F.sigmoid(mask_pred) > 0.5).float()
dice_score += dice_coeff(
mask_pred_bin, mask_true_ddp, reduce_batch_first=False
)
else:
assert (
mask_true_ddp.min() >= 0
and mask_true_ddp.max() < net.module.n_classes
), "True mask indices should be in [0, n_classes]"
mask_pred_processed = F.softmax(mask_pred, dim=1).float()
mask_true_onehot_mc = (
F.one_hot(mask_true_ddp, net.module.n_classes)
.permute(0, 4, 1, 2, 3)
.float()
)
dice_score += multiclass_dice_coeff(
mask_pred_processed[:, 1:],
mask_true_onehot_mc[:, 1:],
reduce_batch_first=True,
)

net.train()

val_loss_avg = val_loss_epoch / max(processed_batches, 1)
if primary:
print(
f"evaluate.py: dice_score={dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}"
f"evaluate.py: dice_score={total_dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}"
)
return dice_score, val_loss_epoch, val_loss_avg, processed_batches
return total_dice_score, val_loss_epoch, val_loss_avg, processed_batches
Loading
Loading