Skip to content

Feat/cross frame fibers#825

Draft
giorgioangel wants to merge 26 commits into
mainfrom
feat/cross-frame-fibers
Draft

Feat/cross frame fibers#825
giorgioangel wants to merge 26 commits into
mainfrom
feat/cross-frame-fibers

Conversation

@giorgioangel

Copy link
Copy Markdown
Member

No description provided.

Claude added 6 commits April 20, 2026 19:22
Empty modules + import-only tests for the cross-frame fiber segmentation
dataset. Subsequent commits fill in the affine utilities, the dataset
class, and the ps128_fibers training config.
- read_transform_json: loads schema v1.0.0 JSON from local/http/s3
  without pulling jsonschema as a new dep.
- invert_affine_matrix, get_swap_matrix, matrix_swap_xyz_zyx:
  minimal numpy subset ported from foundation/volume-registration
  (no SimpleITK/scipy.optimize/requests).
- label_to_image_zyx_matrix: compose inverse + XYZ<->ZYX swap once.
- label_patch_image_aabb: project the 8 corners of a label patch
  through the label->image affine, pad by a margin, clip to image shape.
- resample_image_to_label_grid: trilinear resample of the image slab
  onto the label patch grid (order=1 map_coordinates).
- matrix_checksum: short hex digest for cache keys.

Tests cover round-trip invert, swap self-inverse, swap equivalence with
XYZ point reordering, identity/translation/scale resamples, AABB under
shear, clipping, out-of-bounds safety, and a 128^3 resample microbench
(~270 ms on this machine; well under the 500 ms budget).
Samples binary fiber patches from two zarr volumes that live in different
coordinate frames, bridged by a transform.json sibling of the image:

- Opens image + labels zarrs via vesuvius.data.utils.open_zarr, so local,
  HTTPS, and anonymous-S3 URLs all work out of the box.
- Scans the labels volume on a grid of patch-sized cells, keeping
  positions that satisfy min_labeled_ratio + min_bbox_percent and whose
  image-frame AABB falls fully inside the image bounds.
- Caches the enumerated positions keyed by (URLs, patch size, thresholds,
  transform checksum) so subsequent runs skip the scan.
- __getitem__ reads the label patch natively, resamples the image region
  trilinearly through inv(M) using scipy map_coordinates, and emits the
  same sample dict shape as ZarrDataset so the existing augmentation
  pipeline and loss registry run unchanged.

Tests cover FG enumeration, identity/translation image reconstruction,
cache round-trip (labels overwritten, index still loads from cache),
empty-labels error, and out-of-bounds filtering.
…rs config

- BaseTrainer._build_dataset_for_mgr now dispatches on
  dataset_config.dataset_type. Default remains "zarr" (ZarrDataset);
  "cross_frame" constructs CrossFrameZarrDataset. Import is lazy so the
  cross-frame path doesn't pay for scipy.ndimage at module load time.
- ConfigManager records dataset_type in _init_attributes and raises
  early if dataset_type is unknown or if cross_frame is missing any of
  image_zarr_url / labels_zarr_url / transform_json_url.
- New single-task config ps128_fibers.yaml pointing at the PHercParis4
  S3 image zarr, the HTTPS fibers labels zarr, and the S3 transform.json.
  Uses anon S3 access via storage_options.anon=true.

Tests: ConfigManager gains three cases covering load-from-YAML, missing
URLs, and unknown dataset_type. All 67 training + 71 dataset/utility
tests remain green.
Writes a small image + binary label pair plus an identity transform.json
to tmp_path, builds a ConfigManager-compatible SimpleNamespace, runs 3
training steps through BaseTrainer, and asserts the loss is finite and
strictly decreases. Validates that dataset dispatch, model build, loss,
and optimizer all wire up end-to-end for the cross_frame path.
Three tests, marked slow+network and deselected by default via
pytest.ini, that exercise the real endpoints:

- reachability: opens the S3 image zarr, HTTPS labels zarr, and S3
  transform.json anonymously; asserts the expected shapes and prints
  the XYZ matrix for eyeballing.
- foreground scan: runs the dataset's FG scan against a ~2-row slab of
  the labels volume (~1% of the volume) and reports counts + wall time.
- single patch fetch: finds one FG label patch and resamples the
  matching image region end-to-end; reports the resample time.

Run with:
    uv run pytest -m 'slow and network' \
      tests/models/datasets/test_cross_frame_realdata.py -v -s

pytest.ini now registers the `slow` and `network` markers explicitly
and deselects them by default to keep the regular test run hermetic.
@vercel

vercel Bot commented Apr 20, 2026

Copy link
Copy Markdown

The latest updates on your projects. Learn more about Vercel for GitHub.

1 Skipped Deployment
Project Deployment Actions Updated (UTC)
scrollprize-org Ignored Ignored Preview Apr 21, 2026 9:17am

Request Review

@giorgioangel

Copy link
Copy Markdown
Member Author

@codex review

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 392623fd56

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +317 to +320
def valid_patches(self):
return [
{"position": pos, "patch_size": self.patch_size, "volume_name": "fibers"}
for pos in self._patches

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Return trainer-compatible patch metadata objects

The trainer’s validation-split logic expects each valid_patches entry to expose attributes (vp.volume_name, vp.position), but this dataset returns plain dicts. In the default training flow, a separate validation dataset instance is always created, so cross-frame data enters that code path and raises AttributeError during dataloader setup instead of starting training.

Useful? React with 👍 / 👎.

Comment on lines +401 to +402
dataset_type = (getattr(mgr, "dataset_config", {}) or {}).get("dataset_type", "zarr")
if dataset_type == "cross_frame":

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Normalize dataset_type before dispatching dataset builder

This dispatch reads dataset_config.dataset_type verbatim and compares it case-sensitively, even though ConfigManager already validates a lowercased form. As a result, configs such as dataset_type: Cross_Frame can pass config validation but fail later here with Unknown dataset_type, causing avoidable runtime errors.

Useful? React with 👍 / 👎.

Claude and others added 20 commits April 20, 2026 19:49
…spatch

Two issues flagged by Codex on the initial cross-frame fibers PR:

P1 - valid_patches returned plain dicts, but BaseTrainer's
    _configure_dataloaders leakage-prevention branch iterates valid_patches
    and reads vp.volume_name / vp.position as attributes. Return
    PatchInfo dataclass instances instead, matching ZarrDataset. Also
    expose self.data_path so the trainer's same_source check can route
    train+val through the random-split path when both datasets are built
    from the same config.

P2 - BaseTrainer._build_dataset_for_mgr compared dataset_type
    case-sensitively, so a config with dataset_type: Cross_Frame would
    pass ConfigManager validation (which lowercases internally) and then
    fail dispatch. Normalize with strip().lower() at the dispatch site.

Adds three regression tests: valid_patches attribute access, data_path
exposure, and case-insensitive dispatch.
CrossFrameZarrDataset now accepts dataset_config.labels_scan_level. When
set, the FG enumeration opens that OME-Zarr level under the labels URL
parent group (strip trailing /<digit>), loads it once into RAM, and
walks a downscaled patch grid to find candidate positions. Full-res
positions = coarse_position * per-axis factor. Each candidate is still
bounds-checked against the image AABB before acceptance. Cache key now
includes scan_level so changing it forces a rebuild.

New config ps192_fibers_dicece.yaml mirrors the srf_2um/ps192_medial
shape (patch 192, batch 2, 14 workers, SGD 0.01, val split 0.85) but:
  - dataset_type: cross_frame with the real S3 image, HTTPS labels, and
    S3 transform.json
  - labels_scan_level: 3 so the initial scan stays in minutes
  - nnUNet_DC_and_CE_loss (the repo does not expose a standalone
    SkeletonRecall; MedialSurfaceRecall is the combined variant, so per
    user instruction we fall back to Dice+CE here)
  - wandb_project: fibers / entity: vesuvius-challenge

Tests: test_coarse_scan_level asserts the level-1 scan path finds the
expected FG patch. 250 tests green (was 249).
update_config_from_args auto-detected the data format from --input, which
always fails for cross-frame training because --input is just a cache
directory and the real data lives at configured URLs. Short-circuit with
mgr.data_format = 'zarr' when dataset_type == 'cross_frame' so the
launcher does not need --format and does not crash on an empty cache dir.
Applying dataset_config.storage_options={'anon': true} blindly to the
labels zarr broke HTTPS access because aiohttp has no `anon` kwarg. New
helper _storage_options_for drops S3-only keys for http(s) URLs and
leaves local paths alone; explicit storage_options_image /
storage_options_labels continue to override the shared mapping
verbatim. Two regression tests cover the filter and the explicit
override.
Under the auto 'ddp_wrapper' compile policy the 8-GPU launch deadlocked:
rank 0 sat at 98% CPU on a Python trace for 5+ minutes while ranks 1-7
parked at the initial DDP barrier (each emitting the Profiler / reducer
warnings that fire during compile). Disable torch.compile entirely
(compile_policy: off) and pin ddp_find_unused_parameters: false so the
reducer stops walking the autograd graph each step. Eager mode is fast
enough on H100 for this model size and the run progresses past optimizer
init immediately.
Labels now read from /ephemeral/fibers_labels/s1a-fibers-230125-ome.zarr
(populated once with rsync -aW from
rsync://dl.ash2txt.org/data/other/dev/meshes/s1a-fibers-230125-ome.zarr).
Avoids 200ms-per-chunk HTTPS round trips during both the coarse scan and
per-patch label reads, and lets the 8 DDP workers sustain a realistic
data-loading rate. Image zarr and transform.json remain on S3; only the
labels (2.5 GB compressed) live locally.
zarr 2.x's zarr.open rejects storage_options on non-fsspec paths even
when the mapping is empty, so pointing labels_zarr_url at a local
/ephemeral/... directory blew up with "storage_options passed with
non-fsspec path". New _open_zarr_any wrapper dispatches: http(s)/s3
URLs go through vesuvius.data.utils.open_zarr; local paths call
zarr.open directly. Image, labels, and coarse-scan opens now all use
the wrapper, so any mix of local + remote URLs works.
fsspec-backed zarr handles (S3, HTTPS) don't survive DataLoader worker
pickling: aiohttp sessions and the asyncio loop are bound to the parent
process. Adopt the same pattern dinovol uses in
/home/ubuntu/dino_pretraining/dinovol/dinovol_2/dataset/ssl_zarr_dataset.py:

- __init__ probes the zarr once to record shapes and the coarse-scan
  factor, then discards the handles. No live handles are pickled.
- __getstate__ strips _image_array / _labels_array / _scan_array and the
  PID trackers so cross-process pickling always starts from nothing.
- _ensure_process_local_handles() reopens all three arrays lazily on
  first access, guarded by os.getpid(). Called from _build_patch_index
  during the main-process FG scan and from every __getitem__ in workers.
- After the main-process scan, _close_handles() releases the parent's
  handles so a fresh fsspec session is created in each worker.

Also: stdbuf -oL on both sides of the tee so the training log actually
flushes to disk instead of sitting in the pipe buffer. All 12 dataset
tests remain green.
Each phase of CrossFrameZarrDataset now prints a [CrossFrameZarrDataset]
line with flush=True so operators can see where a launch is blocked
(remote probe, transform fetch, coarse scan, cache load, patch build).

__getitem__ also prints the first two per-worker calls with per-step
timings: handle reopen, label slab read+binarize, affine image resample
(AABB and values included), normalizer, and augmentation transforms.
Controlled by _DEBUG_CALLS_PER_PID = 2, and the counter resets on
__getstate__ so every DataLoader worker reports from its first calls.
Cleaner launcher header, plus PYTHONUNBUFFERED=1 for flush reliability.
The old design sampled 256^3 label patches and resampled a huge image
AABB (~850^3 = 1.2 GB per sample from S3). At 12.6 s per sample that
bottlenecked every launch and would have blown out S3 bandwidth under
8-GPU DDP.

Flip the direction:
- Enumeration still starts from label FG (via the coarse OME-Zarr level)
  so patches are only produced where we know fibers exist. Each coarse
  FG voxel is forward-mapped through inv(M) to an image voxel, snapped
  to a patch-stride grid in image coords, deduplicated, and bounds-
  checked against the image shape. Positions stored in the cache are
  now image-space starts.
- __getitem__ reads the 256^3 image patch natively from the image zarr
  (no interpolation) and resamples the corresponding label region into
  the image grid with nearest-neighbor map_coordinates (preserves
  binary labels). The label AABB fetched from disk is ~det(M)^(1/3)
  smaller than the image patch, so label resample is tens of ms
  instead of multi-second.

New affine helpers mirror the previous ones in the other direction:
image_to_label_zyx_matrix, image_patch_label_aabb,
resample_label_to_image_grid. Cache key bumped with a sampling_frame
discriminator so old caches are invalidated.

27 tests green including translated-transform and coarse-scan cases.
The previous snap 'floor(c - ps/2) // st * st' pushed the resulting
patch away from the label-FG voxel, producing spurious patches whose
forward-mapped label AABB fell outside the label volume (size=(0, N, M)
in one of the debug traces) or missed the FG entirely.

Two fixes:
- snap each image center 'c' to 'floor(c) // st * st' so the patch
  [start, start+ps) contains 'c' directly under stride=ps.
- after snapping and dedupe, forward-map each candidate's 8 corners
  through image->label, require a non-degenerate in-bounds label AABB.
  Filters out the residual rotation/shear edge cases.

27 tests green, including identity coarse-scan-level (now exactly 1
patch, as the label->image mapping dictates).
Two fixes for the hang observed after patch-index build:

- valid_patches is built once and cached. Callers like
  save_train_val_filenames iterate (train_indices + val_indices) and do
  valid_patches[idx] inside the loop -- a fresh list build per access
  cost O(N) each and 724k * 724k PatchInfo constructions wedged the
  main process at 99% CPU for hours.
- The corner-verification step in the coarse scan is vectorised across
  all N candidates in one apply_affine_zyx call instead of a Python
  loop (same reason).

Also: num_dataloader_workers lowered from 14 to 6 (plenty for 1 GPU
at ~2.4 s/sample, far less RAM + spawn overhead on 8-GPU DDP). Cache
is stripped on __getstate__ so workers rebuild lazily only if needed.

12 tests green.
Mirrors ps256_guided_medial.yaml (the srf_2um guided surface config) but
targets fibers on the cross_frame pipeline:
- guide_backbone = /ephemeral/outputs_v2_shift005_jitter105_r342500_paris4_20260416/checkpoint_step_352500.pt
- guide_supervision_target: fibers, guide_loss_weight: 0.25
- same image/labels/transform URLs as ps256_fibers_dicece
- compile_policy: off + ddp_find_unused_parameters: false

Runs on GPU 1 in a parallel tmux session while the plain variant
continues on GPU 0.
Mirrors ps256_pretrained_dino_pixelshuffle_medial.yaml -- frozen dinovol
encoder + pixelshuffle_conv decoder head -- but for the cross_frame
fibers pipeline and Dice+CE instead of MedialSurfaceRecall:
- pretrained_backbone: /ephemeral/outputs_v2_shift005_jitter105_r342500_paris4_20260416/checkpoint_step_352500.pt
- pretrained_decoder_type: pixelshuffle_conv
- freeze_encoder: true
- dataset_type: cross_frame, labels_scan_level: 3
- loss: nnUNet_DC_and_CE_loss (no medial/skeleton-recall)

Replaces the short-lived ps256_guided_fibers_dicece.yaml variant.
Replaced by ps256_pretrained_dino_pixelshuffle_fibers.yaml (frozen dino
encoder + pixelshuffle conv head) per user request.
…figs

Extend CrossFrameZarrDataset with valid_label_values (list) and
binarize_labels (default False when valid_label_values is set). The FG
mask used during coarse scan / per-patch verification picks one of three
strategies in priority order:

  1. scalar valid_patch_value  -> coarse == valid_patch_value
  2. list valid_label_values    -> np.isin(coarse, valid_label_values)
  3. default                    -> coarse > 0

__getitem__ delivers labels as float32 int-valued tensors when
binarize_labels is False, so multi-class targets like rectoverso
{recto=1, verso=2, intersection=3} can be masked in the loss via
ignore_label instead of being rejected at enumeration time.

New configs:
  * ps256_rectoverso_msr.yaml (from-scratch UNet, 2 GPUs, compile on)
  * ps256_pretrained_dino_pixelshuffle_rectoverso_msr.yaml (1 GPU)
  * ps256_hzvt_msr.yaml (from-scratch UNet, 2 GPUs, compile on)
  * ps256_pretrained_dino_pixelshuffle_hzvt_msr.yaml (1 GPU)

All four use MedialSurfaceRecall (DC + SoftSkeletonRecall + CE) with
out_channels=3 and ignore_label=3. Labels rsync'd to /ephemeral/.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…aset

MedialSurfaceRecall (= DC_SkelREC_and_CE_loss) requires a precomputed
<target>_skel tensor alongside the label. ZarrDataset already handles
this via _get_skeleton_targets + MedialSurfaceTransform; CrossFrameZarr
wasn't doing it, so the loss crashed with:

  TypeError: DC_SkelREC_and_CE_loss.forward() missing 1 required positional argument: 'skel'

Detect skeleton-aware losses on the target, pass skeleton_targets and
skeleton_ignore_values through to create_training_transforms /
create_validation_transforms so the augmentation pipeline produces the
skel tensor.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
convert_slice_to_bgr + _render_3d_volume_to_bgr now detect multiclass
segmentation targets (out_channels > 2, activation softmax-ish) and
render them with a stable categorical BGR palette instead of:
  - misusing the 3-channel RGB path (raw logits → R/G/B mix)
  - falling through to channel[0] grayscale

Ground truth: integer class indices → palette lookup.
Prediction: argmax across channel 0 → palette lookup.

Palette: 0=black (bg), 1=red, 2=green, 3=grey (conventional ignore slot),
classes 4+ cycle through blue/yellow/magenta/cyan/orange/white.

Binary out_channels=2 targets unchanged (still use the foreground
channel as grayscale).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…bs16

Visualization:
Prediction tensors for multiclass targets no longer go through argmax.
Instead, apply a numerically stable softmax across channels and lay the
foreground probabilities into BGR channels (class 1->R, class 2->G,
class 3->B) so uncertain regions show up as dimmer pixels / color
blends -- the probability "cloud" is preserved. Background (channel 0)
stays black.

Ground-truth palette matches: class 1=red, class 2=green, class 3=blue,
so GT and predicted colors line up for direct comparison.

Configs:
Scaled up all three pretrained-dino + pixelshuffle configs that were
wasting H100 memory at 256^3 / bs=1:

  * patch_size 256^3 -> 128^3
  * batch_size 1 -> 16
  * num_dataloader_workers 6 -> 14
  * input_shape updated, model_name renamed to ps128_bs16_*

GPU memory went from ~16 GB/run to 32-41 GB/run -- still headroom for
more if needed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@giorgioangel

Copy link
Copy Markdown
Member Author

@codex review

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 1bc3287a36

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +559 to +563
mask = np.isin(slab, self.valid_label_values)
else:
mask = slab > 0
if not mask.any():
continue

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Enforce min_labeled_ratio and min_bbox_percent in scanners

This scan path accepts a candidate as soon as any foreground voxel is present, but it never applies self.min_labeled_ratio or self.min_bbox_percent (and the coarse scanner follows the same pattern). As a result, configs that set stricter thresholds (for example the new ps256_* configs) silently train on patches that should be rejected, which changes class balance and can materially skew experiment outcomes.

Useful? React with 👍 / 👎.

Comment on lines +571 to +572
start = np.floor(center_image - ps_arr / 2.0).astype(np.int64)
start = (start // st_arr) * st_arr

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Align full-scan patch starts to contain mapped FG center

In the full-resolution fallback scanner, start positions are computed with floor(center - patch/2) and then snapped to stride, which can move the mapped foreground center outside the selected image patch when the affine is not stride-aligned. That causes valid foreground candidates to be mislocalized or dropped by the subsequent bounds check, so runs without labels_scan_level can miss substantial training signal.

Useful? React with 👍 / 👎.

Comment on lines +260 to +264
out_channels = int(task_cfg.get("out_channels", 0) or 0)
if out_channels <= 2:
return False
activation = str(task_cfg.get("activation", "none") or "none").lower()
return activation in {"none", "softmax", "identity"}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Gate multiclass rendering by task type, not channel count alone

The multiclass detector now classifies any target with out_channels > 2 and activation none/softmax/identity as segmentation, which also matches non-segmentation tasks like normals (out_channels: 3, activation: none). Those tasks are then routed through the softmax-cloud renderer instead of the 3-channel/vector path, producing misleading debug visualizations that hide real model behavior.

Useful? React with 👍 / 👎.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant