Feat/cross frame fibers #825
Empty modules + import-only tests for the cross-frame fiber segmentation dataset. Subsequent commits fill in the affine utilities, the dataset class, and the ps128_fibers training config.
- read_transform_json: loads schema v1.0.0 JSON from local/http/s3 without pulling jsonschema as a new dep.
- invert_affine_matrix, get_swap_matrix, matrix_swap_xyz_zyx: minimal numpy subset ported from foundation/volume-registration (no SimpleITK/scipy.optimize/requests).
- label_to_image_zyx_matrix: compose inverse + XYZ<->ZYX swap once.
- label_patch_image_aabb: project the 8 corners of a label patch through the label->image affine, pad by a margin, clip to image shape.
- resample_image_to_label_grid: trilinear resample of the image slab onto the label patch grid (order=1 map_coordinates).
- matrix_checksum: short hex digest for cache keys.
Tests cover round-trip invert, swap self-inverse, swap equivalence with XYZ point reordering, identity/translation/scale resamples, AABB under shear, clipping, out-of-bounds safety, and a 128^3 resample microbench (~270 ms on this machine; well under the 500 ms budget).
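The corner-projection step described for label_patch_image_aabb can be sketched as follows. This is a minimal illustration and not the code in the PR: the function name, argument order, and the convention that the bounds are half-open and clipped to [0, shape] are assumptions.

```python
import itertools

import numpy as np


def patch_aabb_through_affine(start_zyx, size_zyx, matrix_zyx, target_shape, margin=1):
    """Hypothetical stand-in for label_patch_image_aabb.

    Maps the 8 corners of the box [start, start + size] through a 4x4 ZYX
    affine and returns integer (lo, hi) bounds in the target volume.
    """
    start = np.asarray(start_zyx, dtype=np.float64)
    size = np.asarray(size_zyx, dtype=np.float64)
    # Enumerate the 8 corners as start + size * (0|1, 0|1, 0|1).
    corners = np.array(
        [start + size * np.array(bits) for bits in itertools.product((0, 1), repeat=3)]
    )
    homog = np.concatenate([corners, np.ones((8, 1))], axis=1)
    mapped = (homog @ np.asarray(matrix_zyx, dtype=np.float64).T)[:, :3]
    # Pad by the margin, then clip to the target volume.
    lo = np.floor(mapped.min(axis=0)).astype(np.int64) - margin
    hi = np.ceil(mapped.max(axis=0)).astype(np.int64) + margin
    shape = np.asarray(target_shape, dtype=np.int64)
    return np.clip(lo, 0, shape), np.clip(hi, 0, shape)
```

Taking the min and max over all 8 mapped corners is what keeps the box valid under rotation or shear, where the image of the patch is no longer axis-aligned.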
Samples binary fiber patches from two zarr volumes that live in different coordinate frames, bridged by a transform.json sibling of the image:
- Opens image + labels zarrs via vesuvius.data.utils.open_zarr, so local, HTTPS, and anonymous-S3 URLs all work out of the box.
- Scans the labels volume on a grid of patch-sized cells, keeping positions that satisfy min_labeled_ratio + min_bbox_percent and whose image-frame AABB falls fully inside the image bounds.
- Caches the enumerated positions keyed by (URLs, patch size, thresholds, transform checksum) so subsequent runs skip the scan.
- __getitem__ reads the label patch natively, resamples the image region trilinearly through inv(M) using scipy map_coordinates, and emits the same sample dict shape as ZarrDataset so the existing augmentation pipeline and loss registry run unchanged.
Tests cover FG enumeration, identity/translation image reconstruction, cache round-trip (labels overwritten, index still loads from cache), empty-labels error, and out-of-bounds filtering.
…rs config
- BaseTrainer._build_dataset_for_mgr now dispatches on dataset_config.dataset_type. Default remains "zarr" (ZarrDataset); "cross_frame" constructs CrossFrameZarrDataset. Import is lazy so the cross-frame path doesn't pay for scipy.ndimage at module load time.
- ConfigManager records dataset_type in _init_attributes and raises early if dataset_type is unknown or if cross_frame is missing any of image_zarr_url / labels_zarr_url / transform_json_url.
- New single-task config ps128_fibers.yaml pointing at the PHercParis4 S3 image zarr, the HTTPS fibers labels zarr, and the S3 transform.json. Uses anon S3 access via storage_options.anon=true.
Tests: ConfigManager gains three cases covering load-from-YAML, missing URLs, and unknown dataset_type. All 67 training + 71 dataset/utility tests remain green.
Writes a small image + binary label pair plus an identity transform.json to tmp_path, builds a ConfigManager-compatible SimpleNamespace, runs 3 training steps through BaseTrainer, and asserts the loss is finite and strictly decreases. Validates that dataset dispatch, model build, loss, and optimizer all wire up end-to-end for the cross_frame path.
Three tests, marked slow+network and deselected by default via
pytest.ini, that exercise the real endpoints:
- reachability: opens the S3 image zarr, HTTPS labels zarr, and S3
transform.json anonymously; asserts the expected shapes and prints
the XYZ matrix for eyeballing.
- foreground scan: runs the dataset's FG scan against a ~2-row slab of
the labels volume (~1% of the volume) and reports counts + wall time.
- single patch fetch: finds one FG label patch and resamples the
matching image region end-to-end; reports the resample time.
Run with:
uv run pytest -m 'slow and network' \
tests/models/datasets/test_cross_frame_realdata.py -v -s
pytest.ini now registers the `slow` and `network` markers explicitly
and deselects them by default to keep the regular test run hermetic.
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 392623fd56
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
def valid_patches(self):
    return [
        {"position": pos, "patch_size": self.patch_size, "volume_name": "fibers"}
        for pos in self._patches
Return trainer-compatible patch metadata objects
The trainer’s validation-split logic expects each valid_patches entry to expose attributes (vp.volume_name, vp.position), but this dataset returns plain dicts. In the default training flow, a separate validation dataset instance is always created, so cross-frame data enters that code path and raises AttributeError during dataloader setup instead of starting training.
dataset_type = (getattr(mgr, "dataset_config", {}) or {}).get("dataset_type", "zarr")
if dataset_type == "cross_frame":
Normalize dataset_type before dispatching dataset builder
This dispatch reads dataset_config.dataset_type verbatim and compares it case-sensitively, even though ConfigManager already validates a lowercased form. As a result, configs such as dataset_type: Cross_Frame can pass config validation but fail later here with Unknown dataset_type, causing avoidable runtime errors.
…spatch
Two issues flagged by Codex on the initial cross-frame fibers PR:
P1 - valid_patches returned plain dicts, but BaseTrainer's
_configure_dataloaders leakage-prevention branch iterates valid_patches
and reads vp.volume_name / vp.position as attributes. Return
PatchInfo dataclass instances instead, matching ZarrDataset. Also
expose self.data_path so the trainer's same_source check can route
train+val through the random-split path when both datasets are built
from the same config.
P2 - BaseTrainer._build_dataset_for_mgr compared dataset_type
case-sensitively, so a config with dataset_type: Cross_Frame would
pass ConfigManager validation (which lowercases internally) and then
fail dispatch. Normalize with strip().lower() at the dispatch site.
Adds three regression tests: valid_patches attribute access, data_path
exposure, and case-insensitive dispatch.
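Both fixes are small enough to sketch together. The field set on PatchInfo and the helper name normalize_dataset_type are illustrative assumptions; the real PatchInfo is the dataclass ZarrDataset already returns and may carry more fields.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class PatchInfo:
    """Hypothetical minimal field set; attribute access is what the trainer needs."""
    volume_name: str
    position: Tuple[int, int, int]
    patch_size: Tuple[int, int, int]


KNOWN_DATASET_TYPES = {"zarr", "cross_frame"}


def normalize_dataset_type(dataset_config):
    """Lowercase and strip dataset_type before dispatch, defaulting to 'zarr'."""
    raw = (dataset_config or {}).get("dataset_type", "zarr")
    value = str(raw).strip().lower()
    if value not in KNOWN_DATASET_TYPES:
        raise ValueError(f"Unknown dataset_type: {raw!r}")
    return value
```

Normalizing at the dispatch site keeps it in step with the validation in ConfigManager, so a value that passes validation cannot fail later on casing alone.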
CrossFrameZarrDataset now accepts dataset_config.labels_scan_level. When
set, the FG enumeration opens that OME-Zarr level under the labels URL
parent group (strip trailing /<digit>), loads it once into RAM, and
walks a downscaled patch grid to find candidate positions. Full-res
positions = coarse_position * per-axis factor. Each candidate is still
bounds-checked against the image AABB before acceptance. Cache key now
includes scan_level so changing it forces a rebuild.
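The coarse-to-full mapping amounts to a per-axis multiply. In this sketch the factor is derived from the ratio of the two array shapes, which is an assumption; the dataset may read it from OME-Zarr metadata instead.

```python
import numpy as np


def coarse_to_full_positions(coarse_positions, full_shape, coarse_shape):
    """Scale coarse-level ZYX positions to full resolution.

    Assumption: the per-axis factor is round(full / coarse), at least 1.
    """
    full = np.asarray(full_shape, dtype=np.int64)
    coarse = np.asarray(coarse_shape, dtype=np.int64)
    factor = np.maximum(1, np.round(full / coarse).astype(np.int64))
    positions = np.asarray(coarse_positions, dtype=np.int64) * factor
    return positions, factor
```

Each returned position still needs the image-AABB bounds check described above before it is accepted.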
New config ps192_fibers_dicece.yaml mirrors the srf_2um/ps192_medial
shape (patch 192, batch 2, 14 workers, SGD 0.01, val split 0.85) but:
- dataset_type: cross_frame with the real S3 image, HTTPS labels, and
S3 transform.json
- labels_scan_level: 3 so the initial scan stays in minutes
- nnUNet_DC_and_CE_loss (the repo does not expose a standalone
SkeletonRecall; MedialSurfaceRecall is the combined variant, so per
user instruction we fall back to Dice+CE here)
- wandb_project: fibers / entity: vesuvius-challenge
Tests: test_coarse_scan_level asserts the level-1 scan path finds the
expected FG patch. 250 tests green (was 249).
update_config_from_args auto-detected the data format from --input, which always fails for cross-frame training because --input is just a cache directory and the real data lives at configured URLs. Short-circuit with mgr.data_format = 'zarr' when dataset_type == 'cross_frame' so the launcher does not need --format and does not crash on an empty cache dir.
Applying dataset_config.storage_options={'anon': true} blindly to the
labels zarr broke HTTPS access because aiohttp has no `anon` kwarg. New
helper _storage_options_for drops S3-only keys for http(s) URLs and
leaves local paths alone; explicit storage_options_image /
storage_options_labels continue to override the shared mapping
verbatim. Two regression tests cover the filter and the explicit
override.
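A minimal sketch of the filtering rule, assuming `anon` is the only S3-only key (the only one named here) and that "leaves local paths alone" means the shared mapping passes through unchanged. The function signature is hypothetical.

```python
from urllib.parse import urlparse

# Assumption: `anon` is the only key that must not reach the HTTP filesystem.
S3_ONLY_KEYS = frozenset({"anon"})


def storage_options_for(url, shared=None, explicit=None):
    """Pick storage options for one URL.

    An explicit per-array mapping wins verbatim; otherwise S3-only keys are
    dropped from the shared mapping for http(s) URLs.
    """
    if explicit is not None:
        return dict(explicit)
    options = dict(shared or {})
    scheme = urlparse(str(url)).scheme.lower()
    if scheme in ("http", "https"):
        return {k: v for k, v in options.items() if k not in S3_ONLY_KEYS}
    return options
```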
Under the auto 'ddp_wrapper' compile policy the 8-GPU launch deadlocked: rank 0 sat at 98% CPU on a Python trace for 5+ minutes while ranks 1-7 parked at the initial DDP barrier (each emitting the Profiler / reducer warnings that fire during compile). Disable torch.compile entirely (compile_policy: off) and pin ddp_find_unused_parameters: false so the reducer stops walking the autograd graph each step. Eager mode is fast enough on H100 for this model size and the run progresses past optimizer init immediately.
Labels now read from /ephemeral/fibers_labels/s1a-fibers-230125-ome.zarr (populated once with rsync -aW from rsync://dl.ash2txt.org/data/other/dev/meshes/s1a-fibers-230125-ome.zarr). Avoids 200ms-per-chunk HTTPS round trips during both the coarse scan and per-patch label reads, and lets the 8 DDP workers sustain a realistic data-loading rate. Image zarr and transform.json remain on S3; only the labels (2.5 GB compressed) live locally.
zarr 2.x's zarr.open rejects storage_options on non-fsspec paths even when the mapping is empty, so pointing labels_zarr_url at a local /ephemeral/... directory blew up with "storage_options passed with non-fsspec path". New _open_zarr_any wrapper dispatches: http(s)/s3 URLs go through vesuvius.data.utils.open_zarr; local paths call zarr.open directly. Image, labels, and coarse-scan opens now all use the wrapper, so any mix of local + remote URLs works.
fsspec-backed zarr handles (S3, HTTPS) don't survive DataLoader worker pickling: aiohttp sessions and the asyncio loop are bound to the parent process. Adopt the same pattern dinovol uses in /home/ubuntu/dino_pretraining/dinovol/dinovol_2/dataset/ssl_zarr_dataset.py:
- __init__ probes the zarr once to record shapes and the coarse-scan factor, then discards the handles. No live handles are pickled.
- __getstate__ strips _image_array / _labels_array / _scan_array and the PID trackers so cross-process pickling always starts from nothing.
- _ensure_process_local_handles() reopens all three arrays lazily on first access, guarded by os.getpid(). Called from _build_patch_index during the main-process FG scan and from every __getitem__ in workers.
- After the main-process scan, _close_handles() releases the parent's handles so a fresh fsspec session is created in each worker.
Also: stdbuf -oL on both sides of the tee so the training log actually flushes to disk instead of sitting in the pipe buffer. All 12 dataset tests remain green.
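The handle lifecycle above can be sketched without zarr at all. The class name, the `fake_open` opener, and the single `_handle_pid` tracker are stand-ins for illustration; only the three handle attribute names and the method names come from the commit message.

```python
import os


def fake_open(name):
    """Hypothetical opener; a real one would return a zarr array for `name`."""
    return {"name": name, "pid": os.getpid()}


class LazyHandleDataset:
    _HANDLE_ATTRS = ("_image_array", "_labels_array", "_scan_array")

    def __init__(self, opener=fake_open):
        self._opener = opener
        self._handle_pid = None
        for attr in self._HANDLE_ATTRS:
            setattr(self, attr, None)

    def _ensure_process_local_handles(self):
        # Reopen whenever the handles were created in a different process.
        if self._handle_pid == os.getpid():
            return
        for attr in self._HANDLE_ATTRS:
            setattr(self, attr, self._opener(attr))
        self._handle_pid = os.getpid()

    def _close_handles(self):
        for attr in self._HANDLE_ATTRS:
            setattr(self, attr, None)
        self._handle_pid = None

    def __getstate__(self):
        # Never ship live handles or the PID tracker to another process.
        state = self.__dict__.copy()
        for attr in self._HANDLE_ATTRS:
            state[attr] = None
        state["_handle_pid"] = None
        return state

    def __getitem__(self, index):
        self._ensure_process_local_handles()
        return self._image_array["name"], index
```

Because `__getstate__` works on a copy of `__dict__`, pickling for a worker leaves the parent's own handles untouched until `_close_handles()` is called.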
Each phase of CrossFrameZarrDataset now prints a [CrossFrameZarrDataset] line with flush=True so operators can see where a launch is blocked (remote probe, transform fetch, coarse scan, cache load, patch build). __getitem__ also prints the first two per-worker calls with per-step timings: handle reopen, label slab read+binarize, affine image resample (AABB and values included), normalizer, and augmentation transforms. Controlled by _DEBUG_CALLS_PER_PID = 2, and the counter resets on __getstate__ so every DataLoader worker reports from its first calls. Cleaner launcher header, plus PYTHONUNBUFFERED=1 for flush reliability.
The old design sampled 256^3 label patches and resampled a huge image AABB (~850^3 = 1.2 GB per sample from S3). At 12.6 s per sample that bottlenecked every launch and would have blown out S3 bandwidth under 8-GPU DDP. Flip the direction:
- Enumeration still starts from label FG (via the coarse OME-Zarr level) so patches are only produced where we know fibers exist. Each coarse FG voxel is forward-mapped through inv(M) to an image voxel, snapped to a patch-stride grid in image coords, deduplicated, and bounds-checked against the image shape. Positions stored in the cache are now image-space starts.
- __getitem__ reads the 256^3 image patch natively from the image zarr (no interpolation) and resamples the corresponding label region into the image grid with nearest-neighbor map_coordinates (preserves binary labels). The label AABB fetched from disk is ~det(M)^(1/3) smaller than the image patch, so label resample is tens of ms instead of multi-second.
New affine helpers mirror the previous ones in the other direction: image_to_label_zyx_matrix, image_patch_label_aabb, resample_label_to_image_grid. Cache key bumped with a sampling_frame discriminator so old caches are invalidated. 27 tests green including translated-transform and coarse-scan cases.
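To illustrate why nearest-neighbor sampling preserves binary labels, here is a NumPy-only stand-in for the label resample. It is not the PR's resample_label_to_image_grid, which uses scipy's map_coordinates; the function name, the 4x4 ZYX matrix convention, and zero fill for out-of-bounds voxels are assumptions.

```python
import numpy as np


def resample_label_nearest(label, image_to_label_zyx, out_shape, out_start=(0, 0, 0)):
    """Sample `label` at the label-frame location of every image-grid voxel.

    NumPy stand-in for an order-0 resample; out-of-bounds voxels become 0.
    """
    axes = [np.arange(s) + o for s, o in zip(out_shape, out_start)]
    zz, yy, xx = np.meshgrid(*axes, indexing="ij")
    coords = np.stack([zz, yy, xx, np.ones_like(zz)], axis=-1).astype(np.float64)
    mapped = coords @ np.asarray(image_to_label_zyx, dtype=np.float64).T
    # Round to the nearest label voxel; no blending, so values stay discrete.
    idx = np.rint(mapped[..., :3]).astype(np.int64)
    inside = np.all((idx >= 0) & (idx < np.asarray(label.shape)), axis=-1)
    out = np.zeros(out_shape, dtype=label.dtype)
    hit = idx[inside]
    out[inside] = label[hit[:, 0], hit[:, 1], hit[:, 2]]
    return out
```

Every output voxel is a copy of exactly one input voxel, so a {0, 1} label volume can never pick up intermediate values the way a trilinear resample would.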
The previous snap 'floor(c - ps/2) // st * st' pushed the resulting patch away from the label-FG voxel, producing spurious patches whose forward-mapped label AABB fell outside the label volume (size=(0, N, M) in one of the debug traces) or missed the FG entirely. Two fixes:
- snap each image center 'c' to 'floor(c) // st * st' so the patch [start, start+ps) contains 'c' directly under stride=ps.
- after snapping and dedupe, forward-map each candidate's 8 corners through image->label, require a non-degenerate in-bounds label AABB. Filters out the residual rotation/shear edge cases.
27 tests green, including identity coarse-scan-level (now exactly 1 patch, as the label->image mapping dictates).
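The difference between the two snaps is easiest to see on concrete numbers. The helper names below are hypothetical; the two expressions are the ones quoted in the commit message, evaluated with stride equal to patch size.

```python
import numpy as np


def snap_old(center, patch, stride):
    """floor(c - ps/2) // st * st, the snap being replaced."""
    return (np.floor(center - patch / 2.0).astype(np.int64) // stride) * stride


def snap_new(center, stride):
    """floor(c) // st * st, the corrected snap."""
    return (np.floor(center).astype(np.int64) // stride) * stride


def contains(start, patch, center):
    """True if every axis of `center` lies in [start, start + patch)."""
    return bool(np.all((start <= center) & (center < start + patch)))


center = np.array([300.0, 10.0, 700.5])
old_start = snap_old(center, 256, 256)  # [0, -256, 512]
new_start = snap_new(center, 256)       # [256, 0, 512]
```

With the old snap the first axis lands on [0, 256), which excludes 300, and the second axis goes negative. The new snap yields a patch that contains the center on all three axes.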
Two fixes for the hang observed after patch-index build:
- valid_patches is built once and cached. Callers like save_train_val_filenames iterate (train_indices + val_indices) and do valid_patches[idx] inside the loop -- a fresh list build per access cost O(N) each and 724k * 724k PatchInfo constructions wedged the main process at 99% CPU for hours.
- The corner-verification step in the coarse scan is vectorised across all N candidates in one apply_affine_zyx call instead of a Python loop (same reason).
Also: num_dataloader_workers lowered from 14 to 6 (plenty for 1 GPU at ~2.4 s/sample, far less RAM + spawn overhead on 8-GPU DDP). Cache is stripped on __getstate__ so workers rebuild lazily only if needed. 12 tests green.
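A sketch of the build-once pattern, with plain tuples standing in for PatchInfo. The class name and the `build_count` counter are only there to make the behaviour observable in this example.

```python
class PatchIndex:
    """Illustrative holder for enumerated patch positions."""

    def __init__(self, positions):
        self._patches = list(positions)
        self._valid_patches = None
        self.build_count = 0  # instrumentation for this sketch only

    @property
    def valid_patches(self):
        # Build the list on first access, then reuse it for every index lookup.
        if self._valid_patches is None:
            self.build_count += 1
            self._valid_patches = [("fibers", pos) for pos in self._patches]
        return self._valid_patches

    def __getstate__(self):
        # Drop the cache when pickling; a worker rebuilds it only if it asks.
        state = self.__dict__.copy()
        state["_valid_patches"] = None
        return state
```

With the cache in place, a loop doing `valid_patches[idx]` over N indices costs one O(N) build plus N constant-time lookups instead of N full rebuilds.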
Mirrors ps256_guided_medial.yaml (the srf_2um guided surface config) but targets fibers on the cross_frame pipeline:
- guide_backbone = /ephemeral/outputs_v2_shift005_jitter105_r342500_paris4_20260416/checkpoint_step_352500.pt
- guide_supervision_target: fibers, guide_loss_weight: 0.25
- same image/labels/transform URLs as ps256_fibers_dicece
- compile_policy: off + ddp_find_unused_parameters: false
Runs on GPU 1 in a parallel tmux session while the plain variant continues on GPU 0.
Mirrors ps256_pretrained_dino_pixelshuffle_medial.yaml -- frozen dinovol encoder + pixelshuffle_conv decoder head -- but for the cross_frame fibers pipeline and Dice+CE instead of MedialSurfaceRecall:
- pretrained_backbone: /ephemeral/outputs_v2_shift005_jitter105_r342500_paris4_20260416/checkpoint_step_352500.pt
- pretrained_decoder_type: pixelshuffle_conv
- freeze_encoder: true
- dataset_type: cross_frame, labels_scan_level: 3
- loss: nnUNet_DC_and_CE_loss (no medial/skeleton-recall)
Replaces the short-lived ps256_guided_fibers_dicece.yaml variant.
Replaced by ps256_pretrained_dino_pixelshuffle_fibers.yaml (frozen dino encoder + pixelshuffle conv head) per user request.
…figs
Extend CrossFrameZarrDataset with valid_label_values (list) and
binarize_labels (default False when valid_label_values is set). The FG
mask used during coarse scan / per-patch verification picks one of three
strategies in priority order:
1. scalar valid_patch_value -> coarse == valid_patch_value
2. list valid_label_values -> np.isin(coarse, valid_label_values)
3. default -> coarse > 0
__getitem__ delivers labels as float32 int-valued tensors when
binarize_labels is False, so multi-class targets like rectoverso
{recto=1, verso=2, intersection=3} can be masked in the loss via
ignore_label instead of being rejected at enumeration time.
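The three-way priority reads directly as a short function. The name `foreground_mask` and the keyword defaults are assumptions; the three comparisons are the ones listed above.

```python
import numpy as np


def foreground_mask(coarse, valid_patch_value=None, valid_label_values=None):
    """Pick the FG mask strategy in the priority order listed above."""
    if valid_patch_value is not None:
        return coarse == valid_patch_value
    if valid_label_values:
        return np.isin(coarse, valid_label_values)
    return coarse > 0
```

For a rectoverso volume, `valid_label_values=[1, 2]` would treat recto and verso as foreground during enumeration while leaving the intersection class to be handled by ignore_label in the loss.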
New configs:
* ps256_rectoverso_msr.yaml (from-scratch UNet, 2 GPUs, compile on)
* ps256_pretrained_dino_pixelshuffle_rectoverso_msr.yaml (1 GPU)
* ps256_hzvt_msr.yaml (from-scratch UNet, 2 GPUs, compile on)
* ps256_pretrained_dino_pixelshuffle_hzvt_msr.yaml (1 GPU)
All four use MedialSurfaceRecall (DC + SoftSkeletonRecall + CE) with
out_channels=3 and ignore_label=3. Labels rsync'd to /ephemeral/.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…aset
MedialSurfaceRecall (= DC_SkelREC_and_CE_loss) requires a precomputed <target>_skel tensor alongside the label. ZarrDataset already handles this via _get_skeleton_targets + MedialSurfaceTransform; CrossFrameZarr wasn't doing it, so the loss crashed with:
TypeError: DC_SkelREC_and_CE_loss.forward() missing 1 required positional argument: 'skel'
Detect skeleton-aware losses on the target, pass skeleton_targets and skeleton_ignore_values through to create_training_transforms / create_validation_transforms so the augmentation pipeline produces the skel tensor.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
convert_slice_to_bgr + _render_3d_volume_to_bgr now detect multiclass segmentation targets (out_channels > 2, activation softmax-ish) and render them with a stable categorical BGR palette instead of:
- misusing the 3-channel RGB path (raw logits → R/G/B mix)
- falling through to channel[0] grayscale
Ground truth: integer class indices → palette lookup.
Prediction: argmax over the channel axis (axis 0) → palette lookup.
Palette: 0=black (bg), 1=red, 2=green, 3=grey (conventional ignore slot), classes 4+ cycle through blue/yellow/magenta/cyan/orange/white.
Binary out_channels=2 targets unchanged (still use the foreground channel as grayscale).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…bs16
Visualization: Prediction tensors for multiclass targets no longer go through argmax. Instead, apply a numerically stable softmax across channels and lay the foreground probabilities into BGR channels (class 1->R, class 2->G, class 3->B) so uncertain regions show up as dimmer pixels / color blends -- the probability "cloud" is preserved. Background (channel 0) stays black. Ground-truth palette matches: class 1=red, class 2=green, class 3=blue, so GT and predicted colors line up for direct comparison.
Configs: Scaled up all three pretrained-dino + pixelshuffle configs that were wasting H100 memory at 256^3 / bs=1:
* patch_size 256^3 -> 128^3
* batch_size 1 -> 16
* num_dataloader_workers 6 -> 14
* input_shape updated, model_name renamed to ps128_bs16_*
GPU memory went from ~16 GB/run to 32-41 GB/run -- still headroom for more if needed.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
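The softmax "cloud" rendering can be sketched for a single 2D slice as below. The function name and the (C, H, W) input layout are assumptions; the class-to-channel mapping follows the description above, with BGR channel order.

```python
import numpy as np


def softmax_cloud_bgr(logits):
    """Render (C, H, W) logits as a uint8 (H, W, 3) BGR image.

    Foreground probabilities go to class 1 -> R, class 2 -> G, class 3 -> B;
    background (channel 0) contributes nothing, so it stays black.
    """
    x = logits.astype(np.float64)
    # Subtract the per-pixel max before exp so large logits cannot overflow.
    x = x - x.max(axis=0, keepdims=True)
    probs = np.exp(x)
    probs /= probs.sum(axis=0, keepdims=True)
    height, width = logits.shape[1:]
    bgr = np.zeros((height, width, 3), dtype=np.float64)
    # (class index, BGR channel index): red is channel 2, green 1, blue 0.
    for cls, chan in ((1, 2), (2, 1), (3, 0)):
        if cls < probs.shape[0]:
            bgr[..., chan] = probs[cls]
    return np.round(bgr * 255).astype(np.uint8)
```

A pixel split evenly between classes 1 and 2 comes out as a dim yellow rather than a hard red or green, which is the uncertainty signal argmax would have discarded.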
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 1bc3287a36
    mask = np.isin(slab, self.valid_label_values)
else:
    mask = slab > 0
if not mask.any():
    continue
Enforce min_labeled_ratio and min_bbox_percent in scanners
This scan path accepts a candidate as soon as any foreground voxel is present, but it never applies self.min_labeled_ratio or self.min_bbox_percent (and the coarse scanner follows the same pattern). As a result, configs that set stricter thresholds (for example the new ps256_* configs) silently train on patches that should be rejected, which changes class balance and can materially skew experiment outcomes.
start = np.floor(center_image - ps_arr / 2.0).astype(np.int64)
start = (start // st_arr) * st_arr
Align full-scan patch starts to contain mapped FG center
In the full-resolution fallback scanner, start positions are computed with floor(center - patch/2) and then snapped to stride, which can move the mapped foreground center outside the selected image patch when the affine is not stride-aligned. That causes valid foreground candidates to be mislocalized or dropped by the subsequent bounds check, so runs without labels_scan_level can miss substantial training signal.
out_channels = int(task_cfg.get("out_channels", 0) or 0)
if out_channels <= 2:
    return False
activation = str(task_cfg.get("activation", "none") or "none").lower()
return activation in {"none", "softmax", "identity"}
Gate multiclass rendering by task type, not channel count alone
The multiclass detector now classifies any target with out_channels > 2 and activation none/softmax/identity as segmentation, which also matches non-segmentation tasks like normals (out_channels: 3, activation: none). Those tasks are then routed through the softmax-cloud renderer instead of the 3-channel/vector path, producing misleading debug visualizations that hide real model behavior.