Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
012b92a
Add LocalNet module using only local (DISCO/conv1x1) filters
mcgibbon Mar 23, 2026
4de511d
Simplify LocalNet config: remove dropout/checkpointing, use block_typ…
mcgibbon Mar 23, 2026
7fadfee
Add AnkurLocalNet replicating Ankur's diagnostic MLP architecture
mcgibbon Mar 23, 2026
b1567ff
Move data_grid from config dataclasses to factory function arguments
mcgibbon Mar 23, 2026
c8950fa
Add HybridNet combining SFNO backbone with local diagnostic network
mcgibbon Mar 23, 2026
d2afaef
Merge branch 'main' into feature/separate_input_types
mcgibbon Mar 23, 2026
e42d56a
Remove superfluous HybridNet tests for sub-network internals
mcgibbon Mar 23, 2026
1c5b252
Add optional secondary network to SingleModuleStep and remove HybridNet
mcgibbon Mar 24, 2026
473a53d
Add AnkurLocalNet and LocalNet to the module registry
mcgibbon Mar 24, 2026
703f138
Use training step instead of parameter surgery in noise stochasticity…
mcgibbon Mar 24, 2026
48c7c5b
Extract shared NoiseConditionedModule base class
mcgibbon Mar 24, 2026
34abe3a
Use composition instead of inheritance for noise conditioning
mcgibbon Mar 24, 2026
b5e451c
Add regression tests for AnkurLocalNet and LocalNet
mcgibbon Mar 24, 2026
97e70a7
Split secondary network outputs into full-field and residual lists
mcgibbon Mar 24, 2026
708d634
Add isotropic noise support for LocalNet
mcgibbon Mar 24, 2026
a525a64
Consolidate noise conditioning into NoiseConditionedModel in stochast…
mcgibbon Mar 24, 2026
7ff1be1
Remove data_grid config from local network builders, hard-code to leg…
mcgibbon Mar 24, 2026
b806b9a
Add __init__ docstring to NoiseConditionedModel
mcgibbon Mar 24, 2026
9e23ed8
Merge branch 'main' into feature/separate_input_types
mcgibbon Mar 24, 2026
47c420c
Add regression tests for AnkurLocalNet and LocalNet
mcgibbon Mar 24, 2026
66e9e66
Remove HybridNet and its tests
mcgibbon Mar 24, 2026
f62ce1b
Merge branch 'feature/separate_input_types' into feature/secondary_st…
mcgibbon Mar 24, 2026
25c70ee
Merge branch 'main' into feature/secondary_stepping
mcgibbon Mar 25, 2026
d514f3c
Address PR review feedback: update defaults, remove duplicate docstri…
mcgibbon Mar 25, 2026
da2f020
Merge branch 'feature/secondary_stepping' of github.com:ai2cm/ace int…
mcgibbon Mar 25, 2026
0009d7e
Fix backwards compatibility: restore conditional_model attribute name
mcgibbon Mar 25, 2026
607932f
Add kernel_shape, basis_type config to LocalNet and isotropic Morlet …
mcgibbon Mar 25, 2026
ad163cf
Merge branch 'main' into feature/secondary_stepping
mcgibbon Mar 25, 2026
4c98ea3
Add test that isotropic DISCO conv commutes with latitude reflection
mcgibbon Mar 25, 2026
c0a911d
Fork DISCO convolution with FFT-based contraction into fme/core/disco
mcgibbon Mar 25, 2026
4d0dfe2
Remove torch-harmonics dependency from fme/core/disco
mcgibbon Mar 25, 2026
0a0711a
Merge branch 'feature/faster_disco' into feature/secondary_stepping
mcgibbon Mar 25, 2026
f84aae3
Replace isotropic Morlet monkey-patch with native filter basis
mcgibbon Mar 25, 2026
138c72d
Fix DISCO test compatibility with torch-harmonics 0.8.0
mcgibbon Mar 26, 2026
cb4ca83
Merge branch 'main' into feature/secondary_stepping
mcgibbon Mar 26, 2026
4a43eb5
Merge branch 'main' into feature/secondary_stepping
mcgibbon Mar 26, 2026
a024340
Add arxiv reference to AnkurLocalNet documentation
mcgibbon Mar 26, 2026
e2b7559
Fix DDP crash on integer DISCO index buffers
mcgibbon Mar 26, 2026
119c531
Add initial draft of vector filter basis notes
mcgibbon Mar 26, 2026
7dae1a2
Revert "Add initial draft of vector filter basis notes"
mcgibbon Mar 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fme/ace/registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from fme.core.models import mlp as _mlp

from . import land_net as _landnet
from . import local_net as _localnet
from . import m2lines as _m2lines
from . import prebuilt as _prebuilt
from . import sfno as _sfno
from . import stochastic_sfno as _sfno_crps
from .registry import ModuleSelector

del _prebuilt, _sfno, _m2lines, _landnet, _sfno_crps, _mlp
del _prebuilt, _sfno, _m2lines, _landnet, _localnet, _sfno_crps, _mlp
236 changes: 236 additions & 0 deletions fme/ace/registry/local_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
import dataclasses
from typing import Literal

import torch
from torch import nn

from fme.ace.registry.registry import ModuleConfig, ModuleSelector
from fme.ace.registry.stochastic_sfno import NoiseConditionedModel
from fme.core.dataset_info import DatasetInfo
from fme.core.distributed import Distributed
from fme.core.models.conditional_sfno.ankur import (
AnkurLocalNetConfig,
get_lat_lon_ankur_localnet,
)
from fme.core.models.conditional_sfno.layers import Context, ContextConfig
from fme.core.models.conditional_sfno.localnet import (
BasisType,
BlockType,
LocalNetConfig,
get_lat_lon_localnet,
)


class _ContextWrappedModule(nn.Module):
    """Adapter exposing a Context-based module via the registry interface.

    Modules built by the conditional_sfno factories take forward(x, context),
    while the Module registry wrapper calls forward(x, labels=None). This shim
    constructs a minimal Context carrying only the labels and delegates.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(
        self, x: torch.Tensor, labels: torch.Tensor | None = None
    ) -> torch.Tensor:
        # Only label conditioning is supported through this adapter; all
        # other context channels are left unset.
        wrapped_context = Context(
            embedding_scalar=None,
            embedding_pos=None,
            labels=labels,
            noise=None,
        )
        return self.module(x, wrapped_context)


@ModuleSelector.register("AnkurLocalNet")
@dataclasses.dataclass
class AnkurLocalNetBuilder(ModuleConfig):
"""Configuration for the AnkurLocalNet architecture.

A simple 3-hidden-layer MLP that optionally uses a DISCO convolution
for the first layer and a learned positional embedding.

Reference: https://arxiv.org/html/2602.16090v1

Attributes:
embed_dim: Dimension of the hidden layers.
use_disco_encoder: Whether to use a DISCO convolution for the first
layer instead of a 1x1 convolution.
disco_kernel_size: Kernel size for the DISCO convolution if used.
pos_embed: Whether to add a learned positional embedding after the
first layer.
activation_function: Activation function name ('relu', 'gelu', 'silu').
"""

embed_dim: int = 256
use_disco_encoder: bool = False
disco_kernel_size: int = 3
pos_embed: bool = False
activation_function: str = "gelu"

def build(
self,
n_in_channels: int,
n_out_channels: int,
dataset_info: DatasetInfo,
) -> nn.Module:
params = AnkurLocalNetConfig(
embed_dim=self.embed_dim,
use_disco_encoder=self.use_disco_encoder,
disco_kernel_size=self.disco_kernel_size,
pos_embed=self.pos_embed,
activation_function=self.activation_function,
)
context_config = ContextConfig(
embed_dim_scalar=0,
embed_dim_noise=0,
embed_dim_labels=len(dataset_info.all_labels),
embed_dim_pos=0,
)
net = get_lat_lon_ankur_localnet(
params=params,
in_chans=n_in_channels,
out_chans=n_out_channels,
img_shape=dataset_info.img_shape,
data_grid="legendre-gauss",
context_config=context_config,
)
return _ContextWrappedModule(net)


@ModuleSelector.register("LocalNet")
@dataclasses.dataclass
class LocalNetBuilder(ModuleConfig):
"""Configuration for the LocalNet architecture.

A noise-conditioned local neural operator network using DISCO convolutions
and/or 1x1 convolutions, with encoder/decoder structure and optional skip
connections. Supports label conditioning when used with conditional=True
on the ModuleSelector.

Attributes:
embed_dim: Dimension of the embeddings.
kernel_shape: Shape of the DISCO convolution filter basis, passed
to the filter basis constructor. For the "piecewise linear" and
"morlet" basis types this is a two-element tuple
(n_radial_modes, n_azimuthal_modes). When n_azimuthal_modes is
1, the "piecewise linear" basis produces isotropic (radially
symmetric) filters. Only affects 'disco' blocks.
basis_type: Type of filter basis for the DISCO convolution
('morlet', 'piecewise linear', or 'zernike'). Only affects
'disco' blocks.
noise_embed_dim: Dimension of the noise conditioning channels.
noise_type: Type of noise for conditioning ('gaussian' or 'isotropic').
Isotropic noise is generated via inverse spherical harmonic
transform.
context_pos_embed_dim: Dimension of the learned positional embedding
used for conditioning. 0 disables.
block_types: List of filter types for each block ('disco', 'conv1x1').
The length determines the number of blocks.
global_layer_norm: Whether to reduce along the spatial domain when
applying layer normalization.
use_mlp: Whether to use an MLP in each block.
mlp_ratio: Ratio of MLP hidden dimension to the embedding dimension.
activation_function: Activation function name ('relu', 'gelu', 'silu').
encoder_layers: Number of convolutional layers in the encoder/decoder.
pos_embed: Whether to use a learned positional embedding inside the
LocalNet (distinct from context_pos_embed_dim).
big_skip: Whether to use a big skip connection from input to decoder.
normalize_big_skip: Whether to normalize the big skip connection.
affine_norms: Whether to use element-wise affine parameters in the
normalization layers.
lora_rank: Rank of LoRA adaptations. 0 disables LoRA.
lora_alpha: Strength of LoRA adaptations. Defaults to lora_rank
if None.
"""

embed_dim: int = 256
kernel_shape: tuple[int, int] = (3, 3)
basis_type: BasisType = "morlet"
noise_embed_dim: int = 256
noise_type: Literal["gaussian", "isotropic"] = "gaussian"
context_pos_embed_dim: int = 0
block_types: list[BlockType] = dataclasses.field(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Is this what we want for default? It seems a little odd to default to a rather deep (12) layer local model.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. I think we will want to tune this and I have no idea what it will look like in the end, but for now I think I want to start at 4 disco blocks followed by 4 conv1x1 blocks. This gives similar depth to what we have in the SFNO case, while biasing it towards more local analysis. I do think we should also test with as low as 2 blocks (and of course Ankur's case which only uses 1 DISCO layer) and compare. For what it's worth, this is analogous to the SFNO class where the default config uses 12 blocks but we always override it to 8.

Unfortunately to some degree this will be resolution-dependent as well, so we may have to tune both on 4 and 1-degree data.

default_factory=lambda: [
"disco",
"disco",
"disco",
"disco",
"conv1x1",
"conv1x1",
"conv1x1",
"conv1x1",
]
)
global_layer_norm: bool = False
use_mlp: bool = True
mlp_ratio: float = 2.0
activation_function: str = "gelu"
encoder_layers: int = 1
pos_embed: bool = True
big_skip: bool = True
normalize_big_skip: bool = False
affine_norms: bool = False
lora_rank: int = 0
lora_alpha: float | None = None

def build(
self,
n_in_channels: int,
n_out_channels: int,
dataset_info: DatasetInfo,
) -> nn.Module:
params = LocalNetConfig(
embed_dim=self.embed_dim,
kernel_shape=self.kernel_shape,
basis_type=self.basis_type,
block_types=self.block_types,
global_layer_norm=self.global_layer_norm,
use_mlp=self.use_mlp,
mlp_ratio=self.mlp_ratio,
activation_function=self.activation_function,
encoder_layers=self.encoder_layers,
pos_embed=self.pos_embed,
big_skip=self.big_skip,
normalize_big_skip=self.normalize_big_skip,
affine_norms=self.affine_norms,
lora_rank=self.lora_rank,
lora_alpha=self.lora_alpha,
)
embed_dim_labels = len(dataset_info.all_labels)
context_config = ContextConfig(
embed_dim_scalar=0,
embed_dim_noise=self.noise_embed_dim,
embed_dim_labels=embed_dim_labels,
embed_dim_pos=self.context_pos_embed_dim,
)
net = get_lat_lon_localnet(
params=params,
in_chans=n_in_channels,
out_chans=n_out_channels,
img_shape=dataset_info.img_shape,
data_grid="legendre-gauss",
context_config=context_config,
)
img_shape = dataset_info.img_shape
if self.noise_type == "isotropic":
dist = Distributed.get_instance()
inverse_sht = dist.get_isht(*img_shape, grid="legendre-gauss")
lmax = inverse_sht.lmax
mmax = inverse_sht.mmax
else:
inverse_sht = None
lmax = 0
mmax = 0
return NoiseConditionedModel(
net,
img_shape=img_shape,
embed_dim_noise=self.noise_embed_dim,
embed_dim_pos=self.context_pos_embed_dim,
embed_dim_labels=embed_dim_labels,
inverse_sht=inverse_sht,
lmax=lmax,
mmax=mmax,
)
66 changes: 47 additions & 19 deletions fme/ace/registry/stochastic_sfno.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
SFNONetConfig,
get_lat_lon_sfnonet,
)
from fme.core.models.conditional_sfno.sfnonet import (
SphericalFourierNeuralOperatorNet as ConditionalSFNO,
)


def isotropic_noise(
Expand Down Expand Up @@ -48,21 +45,42 @@ def isotropic_noise(
return isht(alm)


class NoiseConditionedSFNO(torch.nn.Module):
class NoiseConditionedModel(torch.nn.Module):
"""Wraps a context-based module with noise and optional label conditioning.

Generates noise (gaussian by default, or isotropic via an inverse SHT)
and optional positional embeddings (with label-position interaction),
then calls the wrapped module with a fully populated Context.

Args:
module: An nn.Module with forward signature (x, context: Context).
img_shape: Global spatial dimensions (lat, lon) of the input data.
embed_dim_noise: Dimension of noise channels.
embed_dim_pos: Dimension of learned positional embedding. 0 disables.
embed_dim_labels: Dimension of label embeddings. 0 disables.
inverse_sht: Optional inverse spherical harmonic transform callable.
If provided, isotropic noise is generated via SHT; otherwise
gaussian noise is used.
"""

def __init__(
self,
conditional_model: ConditionalSFNO,
module: torch.nn.Module,
img_shape: tuple[int, int],
noise_type: Literal["isotropic", "gaussian"] = "gaussian",
embed_dim_noise: int = 256,
embed_dim_pos: int = 0,
embed_dim_labels: int = 0,
inverse_sht: Callable[[torch.Tensor], torch.Tensor] | None = None,
lmax: int = 0,
mmax: int = 0,
):
super().__init__()
self.conditional_model = conditional_model
self.conditional_model = module
self.embed_dim = embed_dim_noise
self.noise_type = noise_type
self.img_shape = img_shape
self._inverse_sht = inverse_sht
self._lmax = lmax
self._mmax = mmax
self.label_pos_embed: torch.nn.Parameter | None = None
# register pos embed if pos_embed_dim != 0
if embed_dim_pos != 0:
Expand Down Expand Up @@ -91,24 +109,20 @@ def forward(
self, x: torch.Tensor, labels: torch.Tensor | None = None
) -> torch.Tensor:
x = x.reshape(-1, *x.shape[-3:])
if self.noise_type == "isotropic":
lmax = self.conditional_model.itrans_up.lmax
mmax = self.conditional_model.itrans_up.mmax
if self._inverse_sht is not None:
noise = isotropic_noise(
(x.shape[0], self.embed_dim),
lmax,
mmax,
self.conditional_model.itrans_up,
self._lmax,
self._mmax,
self._inverse_sht,
device=x.device,
)
elif self.noise_type == "gaussian":
else:
noise = torch.randn(
[x.shape[0], self.embed_dim, *x.shape[-2:]],
device=x.device,
dtype=x.dtype,
)
else:
raise ValueError(f"Invalid noise type: {self.noise_type}")

h_slice, w_slice = Distributed.get_instance().get_local_slices(self.img_shape)

Expand All @@ -135,6 +149,10 @@ def forward(
)


# Backward-compatible alias
NoiseConditionedSFNO = NoiseConditionedModel


# this is based on the call signature of SphericalFourierNeuralOperatorNet at
# https://github.com/NVIDIA/modulus/blob/b8e27c5c4ebc409e53adaba9832138743ede2785/modulus/models/sfno/sfnonet.py#L292 # noqa: E501
@ModuleSelector.register("NoiseConditionedSFNO")
Expand Down Expand Up @@ -293,11 +311,21 @@ def build(
embed_dim_labels=len(dataset_info.all_labels),
),
)
return NoiseConditionedSFNO(
if self.noise_type == "isotropic":
inverse_sht = sfno_net.itrans_up
lmax = inverse_sht.lmax
mmax = inverse_sht.mmax
else:
inverse_sht = None
lmax = 0
mmax = 0
return NoiseConditionedModel(
sfno_net,
noise_type=self.noise_type,
embed_dim_noise=self.noise_embed_dim,
embed_dim_pos=self.context_pos_embed_dim,
embed_dim_labels=len(dataset_info.all_labels),
img_shape=dataset_info.img_shape,
inverse_sht=inverse_sht,
lmax=lmax,
mmax=mmax,
)
Loading
Loading