
Add secondary network support to SingleModuleStep and register local network types#1002

Open
mcgibbon wants to merge 40 commits into main from feature/secondary_stepping

Conversation


@mcgibbon mcgibbon commented Mar 24, 2026

Summary

  • Extend SingleModuleStepConfig with optional secondary_builder, secondary_out_names, and secondary_residual_names to support a secondary network that receives the same input as the backbone. The secondary network can predict diagnostic-only variables and/or add residual corrections to the backbone's prognostic outputs.

  • Add AnkurLocalNet (simple MLP) and LocalNet (DISCO/conv1x1 local neural operator) model implementations, with full test coverage and regression baselines.

  • Register both as ModuleSelector types ("AnkurLocalNet" and "LocalNet"), usable as either builder or secondary_builder in config YAML. LocalNet supports noise and label conditioning via NoiseConditionedModule.

  • Extract shared NoiseConditionedModule base used by both LocalNet and NoiseConditionedSFNO, using composition (noise generator callback) rather than inheritance.

  • Remove HybridNet, which is superseded by the secondary network support in SingleModuleStep.

  • Tests added
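The routing described in the summary can be sketched with plain Python dicts standing in for tensor mappings. This is illustrative only: `run_step` is a hypothetical helper, scalars stand in for tensors, and the field names follow the summary bullets.

```python
# Illustrative sketch of the secondary-network routing: both networks see
# the same input; secondary_out_names are diagnostic-only outputs, while
# secondary_residual_names are added as corrections to the backbone output.

def run_step(backbone, secondary, in_fields,
             out_names, secondary_out_names, secondary_residual_names):
    backbone_out = secondary_out = None
    backbone_out = backbone(in_fields)    # dict: name -> value
    secondary_out = secondary(in_fields)  # dict: name -> value

    result = {name: backbone_out[name] for name in out_names}
    # Diagnostic-only variables come straight from the secondary network.
    for name in secondary_out_names:
        result[name] = secondary_out[name]
    # Residual names are added as corrections to the backbone's outputs.
    for name in secondary_residual_names:
        result[name] = result[name] + secondary_out[name]
    return result

backbone = lambda x: {"prog_a": 1.0, "prog_b": 2.0}
secondary = lambda x: {"diag_a": 5.0, "prog_a": 0.1}
out = run_step(backbone, secondary, {"t": 0.0},
               out_names=["prog_a", "prog_b"],
               secondary_out_names=["diag_a"],
               secondary_residual_names=["prog_a"])
# out == {"prog_a": 1.1, "prog_b": 2.0, "diag_a": 5.0}
```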

mcgibbon and others added 13 commits March 23, 2026 15:10
LocalNet is structurally similar to SphericalFourierNeuralOperatorNet but
uses only local operations (DISCO convolutions and 1x1 convolutions) without
any spectral transforms or SHT round-tripping. Intended for use as a
component in a future hybrid network alongside SFNONet.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…es list

Replace filter_type, num_layers, and conv1x1_blocks with a single
block_types list that determines both the number of blocks and each
block's filter type. Also remove concat_skip, drop_rate, drop_path_rate,
and checkpointing which were carried over from SFNO but unnecessary here.
Validate block_types in __post_init__.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
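The validation this commit adds can be sketched as a `__post_init__` check. All names here are assumptions for illustration: the literal values `"disco"` and `"conv1x1"` are inferred from the surrounding description, and the default list is only a placeholder.

```python
import dataclasses

# Hypothetical sketch: a single block_types list determines both the number
# of blocks and each block's filter type, validated at construction time.
VALID_BLOCK_TYPES = ("disco", "conv1x1")  # assumed literal values

@dataclasses.dataclass
class LocalNetConfigSketch:
    block_types: list = dataclasses.field(
        default_factory=lambda: ["disco"] * 4 + ["conv1x1"] * 4
    )

    def __post_init__(self):
        for i, block in enumerate(self.block_types):
            if block not in VALID_BLOCK_TYPES:
                raise ValueError(
                    f"block_types[{i}] is {block!r}; "
                    f"expected one of {VALID_BLOCK_TYPES}"
                )

cfg = LocalNetConfigSketch()  # 8 blocks by default in this sketch
```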
A simple sequential network (3 hidden layers + projection) that serves as
a drop-in replacement for LocalNet. Optionally uses a DISCO convolution
with gcd-grouped channels for the first layer and a learned positional
embedding, matching the diagnostic_mlp from Ankur's
ColumnDiagnosticSphericalFourierNeuralOperatorNet.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
data_grid is determined automatically from the data, not configured in
YAML, so it should not be an attribute on the config dataclasses. Moved
it to an explicit parameter on get_lat_lon_sfnonet, get_lat_lon_localnet,
and get_lat_lon_ankur_localnet factory functions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Introduces a HybridNet module that compositionally pairs an SFNO backbone
(for prognostic output) with a local network (for diagnostic output),
analogous to Ankur's ColumnDiagnosticSFNO but using the separated
forcing/prognostic/diagnostic interface. Adds type discriminator tags
to AnkurLocalNetConfig and LocalNetConfig for union deserialization.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Extend SingleModuleStepConfig with secondary_builder, secondary_out_names,
and secondary_residual_names fields to support a secondary network that
receives the same input as the backbone. The secondary network can predict
diagnostic-only variables and/or add residual corrections to the backbone's
prognostic outputs. This subsumes the functionality of HybridNet, which is
deleted along with its tests and test data.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Register AnkurLocalNet (simple context-wrapped MLP) and LocalNet
(noise-conditioned local neural operator) as ModuleSelector builders.
LocalNet uses a NoiseConditionedModule wrapper with gaussian noise and
optional positional/label conditioning, and is added to CONDITIONAL_BUILDERS
to support label-conditioned training.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… test

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Move the noise conditioning, positional embedding, and label-position
interaction logic into a shared NoiseConditionedModule in
noise_conditioned.py. NoiseConditionedSFNO now subclasses it, overriding
_generate_noise to support isotropic noise via SHT. LocalNet uses the
base class directly with gaussian noise.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace NoiseConditionedSFNO subclass with a factory function that passes
a noise_generator callable to NoiseConditionedModule. The base class
accepts a NoiseGenerator callback (defaults to gaussian_noise), and the
SFNO factory creates an isotropic noise generator that captures the
model's inverse SHT.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
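The composition-over-inheritance change can be sketched as below. NumPy stands in for torch, and every name except the `noise_generator` callback idea is illustrative; the band-limited generator is only a flat-grid stand-in for the real isotropic-via-inverse-SHT generator.

```python
import numpy as np

# Sketch: the wrapper takes a noise_generator callback (gaussian by default)
# instead of subclasses overriding a _generate_noise method.

def gaussian_noise(shape, rng):
    return rng.standard_normal(shape)

def make_band_limited_noise(cutoff):
    # Stand-in for the SFNO factory's isotropic generator, which in the real
    # code captures the model's inverse SHT; here we just low-pass via FFT.
    def generate(shape, rng):
        coeffs = rng.standard_normal(shape)
        freqs = np.fft.fft(coeffs, axis=-1)
        freqs[..., cutoff:] = 0.0
        return np.fft.ifft(freqs, axis=-1).real
    return generate

class NoiseConditionedSketch:
    def __init__(self, module, noise_generator=gaussian_noise):
        self.module = module
        self.noise_generator = noise_generator

    def __call__(self, x, rng):
        noise = self.noise_generator(x.shape, rng)
        return self.module(x + noise)

rng = np.random.default_rng(0)
plain = NoiseConditionedSketch(lambda x: 2.0 * x)  # gaussian default
iso = NoiseConditionedSketch(lambda x: 2.0 * x,
                             make_band_limited_noise(cutoff=4))
y = plain(np.zeros((2, 16)), rng)
z = iso(np.zeros((2, 16)), rng)
```

Swapping the noise shape becomes a constructor argument rather than a subclass, which is the type-safety win the later review comments push toward.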
Replace the regression coverage lost when HybridNet was removed.
Each test uses validate_tensor to compare output against a saved
baseline, following the same pattern as the SFNO regression tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@climate-ci-github climate-ci-github changed the title Feature/secondary stepping Add secondary network support to SingleModuleStep and register local network types Mar 24, 2026
torch.nn.init.trunc_normal_(self.label_pos_embed, std=0.02)
else:
self.pos_embed = None
def _make_sfno_noise_generator(
mcgibbon (author):

We want to allow isotropic noise for the local net as well. Refactor so it can share this code.

out_names=["prog_a", "prog_b"],
normalization=normalization,
secondary_builder=ModuleSelector(type="MLP", config={}),
secondary_out_names=["prog_a", "diag_a"],
mcgibbon (author):

Oh, hm. I had meant for it to be that "out" names are full field outputs only, and the residual names exist only in that list. Please rename the second attribute to secondary_residual_out_names and give it that behavior.

mcgibbon and others added 2 commits March 24, 2026 18:38
Replace the single secondary_residual_out_names with two lists:
- secondary_out_names: full-field outputs used directly (must not
  overlap with out_names)
- secondary_residual_out_names: residual corrections added to the
  backbone output (if name is in out_names) or to the normalized
  input (if name is only in in_names)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Move isotropic_noise and noise generator factory into
noise_conditioned.py as shared infrastructure. LocalNetBuilder gains a
noise_type field ('gaussian' or 'isotropic'); when isotropic, it creates
a standalone InverseRealSHT for noise generation. The SFNO builder uses
the same make_noise_generator factory, extracting the iSHT from its
model.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
"legendre-gauss": "legendre-gauss",
"equiangular": "equiangular",
}
isht = InverseRealSHT(*img_shape, grid=grid_mapping[self.data_grid])
mcgibbon (author):

Use the dist getter method for this rather than retrieving it directly.

if self.noise_type == "isotropic":
from torch_harmonics import InverseRealSHT

grid_mapping = {
mcgibbon (author):

This is an identity mapping, right? I think we should remove it.

return gaussian_noise
elif noise_type == "isotropic":
if isht is None:
raise ValueError("isht is required for isotropic noise")
mcgibbon (author):

I don't really like the lack of type safety that leads to the need for this ValueError, is there a way to avoid it? For example, perhaps this helper should be deleted and we should just directly implement this chain one level higher, where we construct the isht?

mcgibbon and others added 2 commits March 24, 2026 18:58
…ic_sfno

Address review feedback:
- Rename NoiseConditionedSFNO to NoiseConditionedModel with an optional
  inverse_sht parameter (None = gaussian, provided = isotropic noise).
  Keep NoiseConditionedSFNO as a backward-compatible alias.
- Delete noise_conditioned.py; local_net.py imports from stochastic_sfno.
- Remove identity grid_mapping dict in LocalNet builder.
- Use dist.get_isht() for creating the inverse SHT in LocalNet builder.
- Delete make_noise_generator helper to avoid type-unsafe ValueError.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…endre-gauss

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Document that passing inverse_sht switches noise from gaussian to
isotropic, and that lmax/mmax are required when inverse_sht is provided.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@mcgibbon mcgibbon marked this pull request as ready for review March 24, 2026 19:07
mcgibbon and others added 4 commits March 24, 2026 15:22
Replace the regression coverage lost when HybridNet was removed.
Each test uses validate_tensor to compare output against a saved
baseline, following the same pattern as the SFNO regression tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This functionality is superseded by the secondary network support
in SingleModuleStep on the follow-on branch.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Arcomano1234 (Contributor) left a comment:

I think this approach is better than with a HybridNet. Left a few questions / comments but this is looking good.

noise_embed_dim: int = 256
noise_type: Literal["gaussian", "isotropic"] = "gaussian"
context_pos_embed_dim: int = 0
block_types: list[BlockType] = dataclasses.field(
Arcomano1234 (Contributor):

Question: Is this what we want for default? It seems a little odd to default to a rather deep (12) layer local model.

mcgibbon (author):

Agreed. I think we will want to tune this and I have no idea what it will look like in the end, but for now I think I want to start at 4 disco blocks followed by 4 conv1x1 blocks. This gives similar depth to what we have in the SFNO case, while biasing it towards more local analysis. I do think we should also test with as low as 2 blocks (and of course Ankur's case which only uses 1 DISCO layer) and compare. For what it's worth, this is analogous to the SFNO class where the default config uses 12 blocks but we always override it to 8.

Unfortunately to some degree this will be resolution-dependent as well, so we may have to tune both on 4 and 1-degree data.

Base automatically changed from feature/separate_input_types to main March 25, 2026 16:08
mcgibbon and others added 17 commits March 25, 2026 17:08
…ng, add tests

- Change LocalNetBuilder default block_types to 4 disco + 4 conv1x1 blocks
- Remove duplicate __init__ docstring from NoiseConditionedModel (keep class-level)
- Add tests for secondary config output_names property permutations
- Add test for secondary residual with residual_prediction=True scenario

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The rename from NoiseConditionedSFNO to NoiseConditionedModel also changed
the wrapped module attribute from conditional_model to module, breaking
frozen checkpoint loading. Restore the original attribute name.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…basis

Expose kernel_shape and basis_type as configurable options on LocalNetConfig
and LocalNetBuilder, replacing the hardcoded (3, 3) morlet defaults in DISCO
convolution layers. Add an isotropic Morlet filter basis (radial-only modes)
via monkey-patch of torch-harmonics get_filter_basis, available as the
"isotropic morlet" basis type.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Verifies that the isotropic Morlet basis produces truly isotropic filters
by checking that flip(conv(x)) == conv(flip(x)) under latitude reflection,
an isometry that preserves geodesic distance but reverses local azimuthal
orientation. Also verifies the standard Morlet basis does NOT commute,
confirming the test discriminates isotropic from anisotropic filters.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
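The flip-commutation check can be illustrated on a flat grid with a hand-rolled convolution. NumPy stands in for the spherical setting, and the two kernels are illustrative: a filter commutes with a latitude (row) flip exactly when it is symmetric under that flip.

```python
import numpy as np

# flip(conv(x, k)) == conv(flip(x), k) holds for all x iff flipud(k) == k,
# which is the flat-grid analogue of the isotropy test described above.

def conv2d(x, k):
    kh, kw = k.shape
    out = np.zeros((x.shape[0] - kh + 1, x.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * k)
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8))

k_iso = np.array([[0., 1., 0.],
                  [1., 4., 1.],
                  [0., 1., 0.]])   # symmetric under row flip
k_aniso = np.array([[0., 2., 0.],
                    [1., 4., 1.],
                    [0., 0., 0.]])  # not symmetric under row flip

commutes_iso = np.allclose(np.flipud(conv2d(x, k_iso)),
                           conv2d(np.flipud(x), k_iso))
commutes_aniso = np.allclose(np.flipud(conv2d(x, k_aniso)),
                             conv2d(np.flipud(x), k_aniso))
# commutes_iso is True; commutes_aniso is False
```

As in the commit message, checking that the anisotropic kernel does NOT commute is what confirms the test actually discriminates the two cases.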
Fork DiscreteContinuousConvS2 from torch-harmonics (fft-disco-contraction
branch) into fme/core/disco/, using FFT-based contraction as the sole
implementation path. This replaces the O(nlon) loop of sparse BMMs with
a single FFT → banded pointwise multiply → inverse FFT pipeline, giving
a significant speedup for LocalNet.

The non-distributed and data-parallel backends now use the forked class;
the spatial-parallel backend (model_torch_distributed) continues to use
the torch-harmonics distributed version unchanged.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Fork the remaining torch-harmonics utilities (filter_basis, quadrature,
cache, fft) into fme/core/disco/ so the DISCO convolution module is fully
self-contained. torch-harmonics is still used in tests for reference
comparison.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Move IsotropicMorletFilterBasis from the torch-harmonics monkey-patch
module into fme/core/disco/_filter_basis.py and register it directly in
get_filter_basis. Remove the now-unnecessary isotropic_filter.py and its
import from localnet.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Remove use_fft_contraction kwarg that only exists in torch-harmonics
0.8.1+, while CI pins 0.8.0. Omitting it gives identical behavior on
both versions (sparse-matrix path is the default).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Reference: https://arxiv.org/html/2602.16090v1

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
DDP's _broadcast_coalesced cannot handle non-float dtypes. Store
integer index tensors (psi_ker_idx, psi_row_idx, psi_col_idx,
psi_gather_idx) as plain attributes instead of registered buffers,
with an _apply override to ensure they follow device transfers.
These tensors are deterministic across ranks so sync is unnecessary.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
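The plain-attribute pattern in the commit above can be sketched as follows. The class, tensor names, and forward pass are illustrative; the point is keeping an integer index tensor out of the buffer registry while an `_apply` override keeps it following device/dtype transfers.

```python
import torch

# Sketch: store integer index tensors as plain attributes (not registered
# buffers) so DDP's _broadcast_coalesced never sees non-float dtypes, and
# override _apply so they still follow .to()/.cuda() device transfers.
class IndexedConvSketch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))  # synced by DDP
        # Deterministic across ranks, so cross-rank sync is unnecessary.
        self.psi_row_idx = torch.arange(4, dtype=torch.int64)

    def _apply(self, fn, *args, **kwargs):
        module = super()._apply(fn, *args, **kwargs)
        # Module.to's conversion fn moves device for all tensors but only
        # changes dtype for floating-point ones, so the index stays int64.
        module.psi_row_idx = fn(module.psi_row_idx)
        return module

    def forward(self, x):
        return x[..., self.psi_row_idx] @ self.weight

m = IndexedConvSketch().to(torch.float64)
# The parameter followed the dtype change; the index tensor stayed integer
# and never appears in named_buffers().
```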
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>