Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions ScaFFold/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,16 @@ def main():
benchmark_parser.add_argument(
"--batch-size", type=int, nargs="+", help="Batch sizes for each volume size."
)
benchmark_parser.add_argument(
"--warmup-batches",
type=int,
help="Number of warmup batches to run per rank before training.",
)
benchmark_parser.add_argument(
"--dataloader-num-workers",
type=int,
help="Number of DataLoader worker processes per rank.",
)
benchmark_parser.add_argument(
"--optimizer",
type=str,
Expand Down
5 changes: 3 additions & 2 deletions ScaFFold/configs/benchmark_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ problem_scale: 6 # Determines dataset resolution and number of
unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8.
seed: 42 # Random seed.
batch_size: 1 # Batch sizes for each vol size.
dataloader_num_workers: 4 # Number of DataLoader worker processes per rank.
optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defaults to RMSProp.
num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum
shard_dim: 2 # DistConv param: dimension on which to shard
Expand All @@ -29,6 +30,6 @@ framework: "torch" # The DL framework to train with. Only valid
checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints.
loss_freq: 1 # Number of epochs between logging the overall loss.
normalize: 1 # Category search normalization parameter
warmup_epochs: 1 # How many warmup epochs before training
warmup_batches: 5 # How many warmup batches per rank to run before training.
dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
target_dice: 0.95
target_dice: 0.95
3 changes: 2 additions & 1 deletion ScaFFold/configs/benchmark_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ problem_scale: 6 # Determines dataset resolution and number of
unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8.
seed: 42 # Random seed.
batch_size: 1 # Batch sizes for each vol size.
dataloader_num_workers: 4 # Number of DataLoader worker processes per rank.
optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defaults to RMSProp.
num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum
shard_dim: 2 # DistConv param: dimension on which to shard
Expand All @@ -29,6 +30,6 @@ framework: "torch" # The DL framework to train with. Only valid
checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints.
loss_freq: 1 # Number of epochs between logging the overall loss.
normalize: 1 # Category search normalization parameter
warmup_epochs: 1 # How many warmup epochs before training
warmup_batches: 5 # How many warmup batches per rank to run before training.
dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
target_dice: 0.95
101 changes: 0 additions & 101 deletions ScaFFold/datagen/IteratedFunctionSystem.py

This file was deleted.

9 changes: 8 additions & 1 deletion ScaFFold/datagen/get_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
from ScaFFold.datagen import volumegen

META_FILENAME = "meta.yaml"
DATASET_FORMAT_VERSION = 2
INCLUDE_KEYS = [
"dataset_format_version",
"n_categories",
"n_instances_used_per_fractal",
"problem_scale",
Expand Down Expand Up @@ -116,8 +118,10 @@ def get_dataset(
root.mkdir(exist_ok=True)

# Get dict of required keys and compute config_id
config_dict = vars(config).copy()
config_dict["dataset_format_version"] = DATASET_FORMAT_VERSION
volume_config = _get_required_keys_dict(
config=vars(config), include_keys=INCLUDE_KEYS
config=config_dict, include_keys=INCLUDE_KEYS
)
config_id = _hash_volume_config(volume_config)
commit = _git_commit_short()
Expand All @@ -136,6 +140,8 @@ def get_dataset(
meta = yaml.safe_load(meta_path.read_text())
if meta.get("config_id") != config_id:
continue
if meta.get("dataset_format_version", 1) != DATASET_FORMAT_VERSION:
continue
if require_commit and meta.get("code_commit") != commit:
continue
# If we pass the above checks, this dataset can be reused
Expand Down Expand Up @@ -186,6 +192,7 @@ def get_dataset(
# Write to tmp, then move, so readers never see half-written dataset
meta = {
"config_id": config_id,
"dataset_format_version": DATASET_FORMAT_VERSION,
"config_subset": volume_config,
"include_keys": INCLUDE_KEYS,
"code_commit": commit,
Expand Down
10 changes: 7 additions & 3 deletions ScaFFold/datagen/volumegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def main(config: Dict):
dtype=np.float32,
)
mask = np.full(
(config.vol_size, config.vol_size, config.vol_size), 0, dtype=np.short
(config.vol_size, config.vol_size, config.vol_size), 0, dtype=np.int64
)

global_vol_idx = curr_vol[0]
Expand Down Expand Up @@ -223,14 +223,18 @@ def main(config: Dict):

# Determine destination folder
subdir = "validation" if global_vol_idx in val_indices else "training"
volume_to_save = np.ascontiguousarray(
volume.transpose((3, 0, 1, 2)), dtype=np.float32
)
mask_to_save = np.ascontiguousarray(mask, dtype=np.int64)

vol_file = os.path.join(vol_path, subdir, f"{global_vol_idx}.npy")
with open(vol_file, "wb") as f:
np.save(f, volume)
np.save(f, volume_to_save)

mask_file = os.path.join(mask_path, subdir, f"{global_vol_idx}_mask.npy")
with open(mask_file, "wb") as f:
np.save(f, mask)
np.save(f, mask_to_save)

end_time = time.time()
total_time = end_time - start_time
Expand Down
10 changes: 8 additions & 2 deletions ScaFFold/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self, config_dict):
self.n_instances_used_per_fractal = config_dict["n_instances_used_per_fractal"]
self.scale = 1
self.batch_size = config_dict["batch_size"]
self.dataloader_num_workers = config_dict.get("dataloader_num_workers", 4)
self.epochs = config_dict["epochs"]
self.optimizer = config_dict["optimizer"]
self.disable_scheduler = bool(config_dict["disable_scheduler"])
Expand All @@ -66,7 +67,12 @@ def __init__(self, config_dict):
self.loss_freq = config_dict["loss_freq"]
self.checkpoint_dir = config_dict["checkpoint_dir"]
self.normalize = config_dict["normalize"]
self.warmup_epochs = config_dict["warmup_epochs"]
warmup_batches = config_dict.get("warmup_batches")
if warmup_batches is None:
legacy_warmup_epochs = config_dict.get("warmup_epochs", 0)
warmup_batches = legacy_warmup_epochs if legacy_warmup_epochs > 0 else 1
self.warmup_batches = warmup_batches
self.warmup_epochs = config_dict.get("warmup_epochs", 0)
self.num_shards = config_dict["num_shards"]
self.shard_dim = config_dict["shard_dim"]
self.dataset_reuse_enforce_commit_id = config_dict[
Expand Down Expand Up @@ -103,5 +109,5 @@ def load_config(file_path: str, config_type: str):
return RunConfig(config_dict)
else:
raise ValueError(
f"Invalid config type specified: {type}. Must be either 'sweep' or 'run'"
f"Invalid config_type specified: {type}. Must be either 'sweep' or 'run'"
)
80 changes: 58 additions & 22 deletions ScaFFold/utils/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,14 @@
import numpy as np
import torch
from torch.utils.data import Dataset
import yaml

from ScaFFold.utils.utils import customlog

DATASET_FORMAT_VERSION = 2
LEGACY_DATASET_FORMAT_VERSION = 1
META_FILENAME = "meta.yaml"


class BasicDataset(Dataset):
def __init__(
Expand All @@ -31,6 +36,8 @@ def __init__(
self.images_dir = Path(images_dir)
self.mask_dir = Path(mask_dir)
self.mask_suffix = mask_suffix
self.dataset_root = self.images_dir.parents[1]
self.dataset_format_version = self._load_dataset_format_version()

self.ids = [
splitext(file)[0]
Expand All @@ -49,25 +56,54 @@ def __init__(
data = pickle.load(data_file)
self.mask_values = data["mask_values"]
customlog(f"Unique mask values: {self.mask_values}")
customlog(f"Dataset format version: {self.dataset_format_version}")

def __len__(self):
return len(self.ids)

@staticmethod
def preprocess(mask_values, img, is_mask):
if is_mask:
mask = np.zeros((img.shape[0], img.shape[1], img.shape[2]), dtype=np.short)
for i, v in enumerate(mask_values):
if img.ndim == 3:
mask[img == v] = i
else:
mask[(img == v).all(-1)] = i
def _load_numpy_array(path):
with open(path, "rb") as handle:
return np.load(handle)

def _load_dataset_format_version(self):
    """Read the dataset format version from the dataset root's meta file.

    Returns the legacy version when the meta file is missing, and also
    falls back to it (with a log message) when the file cannot be read
    or parsed. The returned value is always an ``int``.
    """
    meta_path = self.dataset_root / META_FILENAME
    if not meta_path.exists():
        return LEGACY_DATASET_FORMAT_VERSION

    try:
        meta = yaml.safe_load(meta_path.read_text()) or {}
    except Exception as exc:
        # Best-effort: any read/parse failure degrades to the legacy loader.
        customlog(
            f"Failed to read dataset metadata from {meta_path}: {exc}. Falling back to legacy loader."
        )
        return LEGACY_DATASET_FORMAT_VERSION

    return int(meta.get("dataset_format_version", LEGACY_DATASET_FORMAT_VERSION))

else:
img = img.transpose((3, 0, 1, 2))
return img
@staticmethod
def _prepare_legacy_image(img):
return np.ascontiguousarray(img.transpose((3, 0, 1, 2)), dtype=np.float32)

@staticmethod
def _prepare_legacy_mask(mask_values, mask):
remapped = np.zeros((mask.shape[0], mask.shape[1], mask.shape[2]), dtype=np.int64)
for i, value in enumerate(mask_values):
if mask.ndim == 3:
remapped[mask == value] = i
else:
remapped[(mask == value).all(-1)] = i

return remapped

@staticmethod
def _prepare_optimized_image(img):
return np.ascontiguousarray(img, dtype=np.float32)

@staticmethod
def _prepare_optimized_mask(mask):
return np.ascontiguousarray(mask, dtype=np.int64)

def __getitem__(self, idx):
name = self.ids[idx]
Expand All @@ -80,19 +116,19 @@ def __getitem__(self, idx):
assert len(mask_file) == 1, (
f"Either no mask or multiple masks found for the ID {name}: {mask_file}"
)
with open(mask_file[0], "rb") as f:
mask = np.load(f)
f.close()
with open(img_file[0], "rb") as f:
img = np.load(f)
f.close()
mask = self._load_numpy_array(mask_file[0])
img = self._load_numpy_array(img_file[0])

img = self.preprocess(self.mask_values, img, is_mask=False)
mask = self.preprocess(self.mask_values, mask, is_mask=True)
if self.dataset_format_version >= DATASET_FORMAT_VERSION:
img = self._prepare_optimized_image(img)
mask = self._prepare_optimized_mask(mask)
else:
img = self._prepare_legacy_image(img)
mask = self._prepare_legacy_mask(self.mask_values, mask)

return {
"image": torch.as_tensor(img.copy()).float().contiguous(),
"mask": torch.as_tensor(mask.copy()).long().contiguous(),
"image": torch.from_numpy(img),
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need .contiguous() here even though volumegen now saves images/masks in contiguous format?

Copy link
Copy Markdown
Collaborator

@michaelmckinsey1 michaelmckinsey1 Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ndryden Do we still need .float().contiguous()?

Copy link
Copy Markdown
Collaborator

@michaelmckinsey1 michaelmckinsey1 Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need

torch.from_numpy(img).contiguous().float()
torch.from_numpy(mask).contiguous().long()

"mask": torch.from_numpy(mask),
}


Expand Down
Loading
Loading