From 012b92adae6fe05367ec40cf2a1e44993375c59c Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 23 Mar 2026 15:10:02 +0000 Subject: [PATCH 01/31] Add LocalNet module using only local (DISCO/conv1x1) filters LocalNet is structurally similar to SphericalFourierNeuralOperatorNet but uses only local operations (DISCO convolutions and 1x1 convolutions) without any spectral transforms or SHT round-tripping. Intended for use as a component in a future hybrid network alongside SFNONet. Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/core/models/conditional_sfno/localnet.py | 486 ++++++++++++++++++ .../models/conditional_sfno/test_localnet.py | 268 ++++++++++ 2 files changed, 754 insertions(+) create mode 100644 fme/core/models/conditional_sfno/localnet.py create mode 100644 fme/core/models/conditional_sfno/test_localnet.py diff --git a/fme/core/models/conditional_sfno/localnet.py b/fme/core/models/conditional_sfno/localnet.py new file mode 100644 index 000000000..003edf14d --- /dev/null +++ b/fme/core/models/conditional_sfno/localnet.py @@ -0,0 +1,486 @@ +import dataclasses +from collections.abc import Callable + +import torch +import torch.nn as nn + +from fme.core.benchmark.timer import NullTimer, Timer +from fme.core.distributed import Distributed + +from .initialization import trunc_normal_ +from .layers import MLP, ConditionalLayerNorm, Context, ContextConfig +from .lora import LoRAConv2d +from .sfnonet import DiscreteContinuousConvS2, NoLayerNorm, _compute_cutoff_radius + + +@dataclasses.dataclass +class LocalNetConfig: + """Configuration parameters for LocalNet. + + Attributes: + embed_dim: Dimension of the embeddings. + filter_type: Type of filter to use ('disco', 'conv1x1'). + global_layer_norm: Whether to reduce along the spatial domain when + applying layer normalization. + num_layers: Number of blocks. + use_mlp: Whether to use an MLP in each block. + mlp_ratio: Ratio of MLP hidden dimension to the embedding dimension. 
+ activation_function: Activation function name ('relu', 'gelu', 'silu'). + encoder_layers: Number of convolutional layers in the encoder/decoder. + pos_embed: Whether to use a learned positional embedding. + big_skip: Whether to use a big skip connection from input to decoder. + conv1x1_blocks: List of block indices to use 1x1 convolutions + instead of the default filter type. + normalize_big_skip: Whether to normalize the big skip connection. + affine_norms: Whether to use element-wise affine parameters in the + normalization layers. + lora_rank: Rank of LoRA adaptations. 0 disables LoRA. + lora_alpha: Strength of LoRA adaptations. Defaults to lora_rank + if None. + data_grid: Grid type for DISCO convolutions + ('equiangular', 'legendre-gauss'). + """ + + embed_dim: int = 256 + filter_type: str = "disco" + global_layer_norm: bool = False + num_layers: int = 12 + use_mlp: bool = True + mlp_ratio: float = 2.0 + activation_function: str = "gelu" + encoder_layers: int = 1 + pos_embed: bool = True + big_skip: bool = True + conv1x1_blocks: list[int] | None = None + normalize_big_skip: bool = False + affine_norms: bool = False + lora_rank: int = 0 + lora_alpha: float | None = None + data_grid: str = "equiangular" + + +class Conv1x1Filter(nn.Module): + """1x1 convolution used as a local filter.""" + + def __init__(self, embed_dim, lora_rank=0, lora_alpha=None): + super().__init__() + self.conv = LoRAConv2d( + embed_dim, + embed_dim, + 1, + bias=True, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + + def forward(self, x, timer: Timer = NullTimer()): + return self.conv(x), x + + +class LocalFilterLayer(nn.Module): + """Local filter layer using either DISCO convolution or 1x1 convolution.""" + + def __init__( + self, + embed_dim, + img_shape: tuple[int, int], + filter_type="disco", + data_grid="equiangular", + lora_rank: int = 0, + lora_alpha: float | None = None, + ): + super().__init__() + + if filter_type == "disco": + nlat, nlon = img_shape + theta_cutoff = 2 * 
_compute_cutoff_radius( + nlat=nlat, + kernel_shape=(3, 3), + basis_type="morlet", + ) + self.filter = DiscreteContinuousConvS2( + embed_dim, + embed_dim, + in_shape=img_shape, + out_shape=img_shape, + kernel_shape=(3, 3), + basis_type="morlet", + basis_norm_mode="mean", + groups=1, + grid_in=data_grid, + grid_out=data_grid, + bias=False, + theta_cutoff=theta_cutoff, + ) + elif filter_type == "conv1x1": + self.filter = Conv1x1Filter( + embed_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + else: + raise NotImplementedError(f"Unknown filter type: {filter_type}") + + def forward(self, x, timer: Timer = NullTimer()): + return self.filter(x, timer=timer) + + +class LocalBlock(nn.Module): + """Block using local (non-spectral) filters.""" + + def __init__( + self, + embed_dim, + img_shape: tuple[int, int], + context_config: ContextConfig, + filter_type="disco", + data_grid="equiangular", + global_layer_norm: bool = False, + mlp_ratio=2.0, + act_layer=nn.GELU, + inner_skip="linear", + outer_skip=None, + use_mlp=False, + affine_norms=False, + lora_rank: int = 0, + lora_alpha: float | None = None, + ): + super().__init__() + + self.input_shape_loc = img_shape + self.output_shape_loc = img_shape + + # norm layer + self.norm0 = ConditionalLayerNorm( + embed_dim, + img_shape=self.input_shape_loc, + global_layer_norm=global_layer_norm, + context_config=context_config, + elementwise_affine=affine_norms, + ) + + # local filter + self.filter = LocalFilterLayer( + embed_dim, + img_shape=img_shape, + filter_type=filter_type, + data_grid=data_grid, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + + if inner_skip == "linear": + self.inner_skip = LoRAConv2d( + embed_dim, + embed_dim, + 1, + 1, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + elif inner_skip == "identity": + self.inner_skip = nn.Identity() + + if filter_type == "conv1x1": + self.act_layer = act_layer() + + # norm layer + self.norm1 = ConditionalLayerNorm( + embed_dim, + 
img_shape=self.output_shape_loc, + global_layer_norm=global_layer_norm, + context_config=context_config, + elementwise_affine=affine_norms, + ) + + if use_mlp: + mlp_hidden_dim = int(embed_dim * mlp_ratio) + self.mlp = MLP( + in_features=embed_dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + + if outer_skip == "linear": + self.outer_skip = LoRAConv2d( + embed_dim, + embed_dim, + 1, + 1, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + ) + elif outer_skip == "identity": + self.outer_skip = nn.Identity() + + def forward(self, x, context_embedding, timer: Timer = NullTimer()): + with timer.child("norm0") as norm0_timer: + x_norm = torch.zeros_like(x) + x_norm[..., : self.input_shape_loc[0], : self.input_shape_loc[1]] = ( + self.norm0( + x[..., : self.input_shape_loc[0], : self.input_shape_loc[1]], + context_embedding, + timer=norm0_timer, + ) + ) + with timer.child("filter") as filter_timer: + x, residual = self.filter(x_norm, timer=filter_timer) + if hasattr(self, "inner_skip"): + with timer.child("inner_skip"): + x = x + self.inner_skip(residual) + + if hasattr(self, "act_layer"): + with timer.child("activation"): + x = self.act_layer(x) + + with timer.child("norm1") as norm1_timer: + x_norm = torch.zeros_like(x) + x_norm[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] = ( + self.norm1( + x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]], + context_embedding, + timer=norm1_timer, + ) + ) + x = x_norm + + if hasattr(self, "mlp"): + with timer.child("mlp"): + x = self.mlp(x) + + if hasattr(self, "outer_skip"): + with timer.child("outer_skip"): + x = x + self.outer_skip(residual) + + return x + + +def get_lat_lon_localnet( + params: LocalNetConfig, + in_chans: int, + out_chans: int, + img_shape: tuple[int, int], + context_config: ContextConfig = ContextConfig( + embed_dim_scalar=0, + embed_dim_noise=0, + embed_dim_labels=0, + embed_dim_pos=0, + ), +) -> "LocalNet": + h, w = 
img_shape + + def get_pos_embed(): + pos_embed = nn.Parameter(torch.zeros(1, params.embed_dim, h, w)) + pos_embed.is_shared_mp = ["matmul"] + trunc_normal_(pos_embed, std=0.02) + return pos_embed + + net = LocalNet( + params, + img_shape=img_shape, + in_chans=in_chans, + out_chans=out_chans, + context_config=context_config, + get_pos_embed=get_pos_embed, + ) + return net + + +class LocalNet(torch.nn.Module): + """Local Neural Operator Network. + + Uses only local operations (DISCO convolutions and 1x1 convolutions) + without any spectral transforms. + + Args: + params: Model configuration. See ``LocalNetConfig`` for details. + img_shape: Spatial dimensions (lat, lon) of the input data. + get_pos_embed: Factory function that returns a learned positional + embedding parameter. + in_chans: Number of input channels. + out_chans: Number of output channels. + context_config: Configuration for conditional context embeddings + (scalar, noise, positional, labels). + """ + + def __init__( + self, + params: LocalNetConfig, + img_shape: tuple[int, int], + get_pos_embed: Callable[[], nn.Parameter], + in_chans: int, + out_chans: int, + context_config: ContextConfig = ContextConfig( + embed_dim_scalar=0, + embed_dim_labels=0, + embed_dim_noise=0, + embed_dim_pos=0, + ), + ): + super().__init__() + + self.filter_type = params.filter_type + self.mlp_ratio = params.mlp_ratio + self.img_shape = img_shape + self._spatial_h_slice, self._spatial_w_slice = ( + Distributed.get_instance().get_local_slices(self.img_shape) + ) + self.global_layer_norm = params.global_layer_norm + self.in_chans = in_chans + self.out_chans = out_chans + self.embed_dim = params.embed_dim + self.num_layers = params.num_layers + self.use_mlp = params.use_mlp + self.encoder_layers = params.encoder_layers + self._use_pos_embed = params.pos_embed + self.big_skip = params.big_skip + if params.conv1x1_blocks is not None: + self.conv1x1_blocks = [ + i for i in range(self.num_layers) if i in params.conv1x1_blocks + ] + 
else: + self.conv1x1_blocks = [] + self.affine_norms = params.affine_norms + self.lora_rank = params.lora_rank + self.lora_alpha = params.lora_alpha + self.data_grid = params.data_grid + + # determine activation function + activation_functions = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU} + if params.activation_function not in activation_functions: + raise ValueError( + f"Unknown activation function {params.activation_function}" + ) + act_layer = activation_functions[params.activation_function] + + # encoder + encoder_hidden_dim = self.embed_dim + current_dim = self.in_chans + encoder_modules = [] + for i in range(self.encoder_layers): + encoder_modules.append( + LoRAConv2d( + current_dim, + encoder_hidden_dim, + 1, + bias=True, + lora_rank=self.lora_rank, + lora_alpha=self.lora_alpha, + ) + ) + encoder_modules.append(act_layer()) + current_dim = encoder_hidden_dim + encoder_modules.append( + LoRAConv2d( + current_dim, + self.embed_dim, + 1, + bias=False, + lora_rank=self.lora_rank, + lora_alpha=self.lora_alpha, + ) + ) + self.encoder = nn.Sequential(*encoder_modules) + + # blocks + self.blocks = nn.ModuleList([]) + for i in range(self.num_layers): + if i in self.conv1x1_blocks: + block_filter_type = "conv1x1" + else: + block_filter_type = self.filter_type + + inner_skip = "linear" + outer_skip = "identity" + + block = LocalBlock( + self.embed_dim, + img_shape=self.img_shape, + context_config=context_config, + filter_type=block_filter_type, + data_grid=self.data_grid, + global_layer_norm=self.global_layer_norm, + mlp_ratio=self.mlp_ratio, + act_layer=act_layer, + inner_skip=inner_skip, + outer_skip=outer_skip, + use_mlp=self.use_mlp, + affine_norms=self.affine_norms, + lora_rank=self.lora_rank, + lora_alpha=self.lora_alpha, + ) + + self.blocks.append(block) + + # decoder + decoder_hidden_dim = self.embed_dim + current_dim = self.embed_dim + self.big_skip * self.in_chans + decoder_modules = [] + for i in range(self.encoder_layers): + decoder_modules.append( 
+ LoRAConv2d( + current_dim, + decoder_hidden_dim, + 1, + bias=True, + lora_rank=self.lora_rank, + lora_alpha=self.lora_alpha, + ) + ) + decoder_modules.append(act_layer()) + current_dim = decoder_hidden_dim + decoder_modules.append( + LoRAConv2d( + current_dim, + self.out_chans, + 1, + bias=False, + lora_rank=self.lora_rank, + lora_alpha=self.lora_alpha, + ) + ) + self.decoder = nn.Sequential(*decoder_modules) + + # learned position embedding + if self._use_pos_embed: + self.pos_embed = get_pos_embed() + else: + self.pos_embed = None + + if params.normalize_big_skip: + self.norm_big_skip = ConditionalLayerNorm( + in_chans, + img_shape=self.img_shape, + global_layer_norm=self.global_layer_norm, + context_config=context_config, + elementwise_affine=self.affine_norms, + ) + else: + self.norm_big_skip = NoLayerNorm() + + def _forward_features(self, x: torch.Tensor, context: Context): + for blk in self.blocks: + x = blk(x, context) + + return x + + def forward(self, x: torch.Tensor, context: Context): + # save big skip + if self.big_skip: + residual = self.norm_big_skip(x, context=context) + + x = self.encoder(x) + + if self.pos_embed is not None: + x = x + self.pos_embed[..., self._spatial_h_slice, self._spatial_w_slice] + + x = self._forward_features(x, context) + + if self.big_skip: + x = torch.cat((x, residual), dim=1) + + x = self.decoder(x) + + return x diff --git a/fme/core/models/conditional_sfno/test_localnet.py b/fme/core/models/conditional_sfno/test_localnet.py new file mode 100644 index 000000000..409d600f7 --- /dev/null +++ b/fme/core/models/conditional_sfno/test_localnet.py @@ -0,0 +1,268 @@ +import pytest +import torch +from torch import nn + +from fme.core.device import get_device + +from .layers import Context, ContextConfig +from .localnet import LocalNetConfig, get_lat_lon_localnet + + +@pytest.mark.parametrize( + "conditional_embed_dim_scalar, conditional_embed_dim_labels, " + "conditional_embed_dim_noise, " + "conditional_embed_dim_pos", + [ + (0, 
0, 0, 0), + (16, 8, 0, 0), + (16, 0, 16, 0), + (16, 15, 14, 13), + (0, 0, 0, 16), + (0, 0, 16, 0), + ], +) +def test_can_call_localnet_disco( + conditional_embed_dim_scalar: int, + conditional_embed_dim_labels: int, + conditional_embed_dim_noise: int, + conditional_embed_dim_pos: int, +): + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 4 + device = get_device() + params = LocalNetConfig( + embed_dim=16, + num_layers=2, + filter_type="disco", + ) + model = get_lat_lon_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + context_config=ContextConfig( + embed_dim_scalar=conditional_embed_dim_scalar, + embed_dim_labels=conditional_embed_dim_labels, + embed_dim_noise=conditional_embed_dim_noise, + embed_dim_pos=conditional_embed_dim_pos, + ), + ).to(device) + x = torch.randn(n_samples, input_channels, *img_shape, device=device) + context_embedding = torch.randn( + n_samples, conditional_embed_dim_scalar, device=device + ) + context_embedding_labels = torch.randn( + n_samples, conditional_embed_dim_labels, device=device + ) + context_embedding_noise = torch.randn( + n_samples, conditional_embed_dim_noise, *img_shape, device=device + ) + context_embedding_pos = torch.randn( + n_samples, conditional_embed_dim_pos, *img_shape, device=device + ) + context = Context( + embedding_scalar=context_embedding, + labels=context_embedding_labels, + noise=context_embedding_noise, + embedding_pos=context_embedding_pos, + ) + output = model(x, context) + assert output.shape == (n_samples, output_channels, *img_shape) + + +def test_can_call_localnet_conv1x1(): + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 4 + device = get_device() + params = LocalNetConfig( + embed_dim=16, + num_layers=2, + filter_type="conv1x1", + ) + model = get_lat_lon_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ).to(device) + x = 
torch.randn(n_samples, input_channels, *img_shape, device=device) + context = Context( + embedding_scalar=torch.randn(n_samples, 0, device=device), + labels=torch.randn(n_samples, 0, device=device), + noise=torch.randn(n_samples, 0, *img_shape, device=device), + embedding_pos=torch.randn(n_samples, 0, *img_shape, device=device), + ) + output = model(x, context) + assert output.shape == (n_samples, output_channels, *img_shape) + + +def test_can_call_localnet_mixed_blocks(): + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 4 + device = get_device() + params = LocalNetConfig( + embed_dim=16, + num_layers=4, + filter_type="disco", + conv1x1_blocks=[1, 3], + ) + model = get_lat_lon_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ).to(device) + x = torch.randn(n_samples, input_channels, *img_shape, device=device) + context = Context( + embedding_scalar=torch.randn(n_samples, 0, device=device), + labels=torch.randn(n_samples, 0, device=device), + noise=torch.randn(n_samples, 0, *img_shape, device=device), + embedding_pos=torch.randn(n_samples, 0, *img_shape, device=device), + ) + output = model(x, context) + assert output.shape == (n_samples, output_channels, *img_shape) + + +@pytest.mark.parametrize("normalize_big_skip", [True, False]) +def test_all_inputs_get_layer_normed(normalize_big_skip: bool): + torch.manual_seed(0) + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 4 + conditional_embed_dim_scalar = 8 + conditional_embed_dim_noise = 16 + conditional_embed_dim_labels = 3 + conditional_embed_dim_pos = 12 + device = get_device() + + class SetToZero(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, x): + return torch.zeros_like(x) + + original_layer_norm = nn.LayerNorm + try: + nn.LayerNorm = SetToZero + params = LocalNetConfig( + embed_dim=16, + num_layers=2, + normalize_big_skip=normalize_big_skip, + 
global_layer_norm=True, # so it uses nn.LayerNorm + ) + model = get_lat_lon_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + context_config=ContextConfig( + embed_dim_scalar=conditional_embed_dim_scalar, + embed_dim_noise=conditional_embed_dim_noise, + embed_dim_labels=conditional_embed_dim_labels, + embed_dim_pos=conditional_embed_dim_pos, + ), + ).to(device) + finally: + nn.LayerNorm = original_layer_norm + x = torch.full((n_samples, input_channels, *img_shape), torch.nan).to(device) + context_embedding = torch.randn(n_samples, conditional_embed_dim_scalar).to(device) + context_embedding_noise = torch.randn( + n_samples, conditional_embed_dim_noise, *img_shape + ).to(device) + context_embedding_labels = torch.randn(n_samples, conditional_embed_dim_labels).to( + device + ) + context_embedding_pos = torch.randn( + n_samples, conditional_embed_dim_pos, *img_shape + ).to(device) + context = Context( + embedding_scalar=context_embedding, + embedding_pos=context_embedding_pos, + noise=context_embedding_noise, + labels=context_embedding_labels, + ) + with torch.no_grad(): + output = model(x, context) + if normalize_big_skip: + assert not torch.isnan(output).any() + else: + assert torch.isnan(output).any() + + +def test_no_big_skip(): + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 4 + device = get_device() + params = LocalNetConfig( + embed_dim=16, + num_layers=2, + big_skip=False, + ) + model = get_lat_lon_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ).to(device) + x = torch.randn(n_samples, input_channels, *img_shape, device=device) + context = Context( + embedding_scalar=torch.randn(n_samples, 0, device=device), + labels=torch.randn(n_samples, 0, device=device), + noise=torch.randn(n_samples, 0, *img_shape, device=device), + embedding_pos=torch.randn(n_samples, 0, *img_shape, device=device), + ) + output = model(x, 
context) + assert output.shape == (n_samples, output_channels, *img_shape) + + +def test_unknown_filter_type_raises(): + with pytest.raises(NotImplementedError): + LocalNetConfig(filter_type="spectral") + get_lat_lon_localnet( + params=LocalNetConfig(filter_type="spectral"), + img_shape=(9, 18), + in_chans=2, + out_chans=3, + ) + + +def test_backward_pass(): + """Test that gradients flow through the network.""" + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 2 + device = get_device() + params = LocalNetConfig( + embed_dim=16, + num_layers=2, + filter_type="disco", + ) + model = get_lat_lon_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ).to(device) + x = torch.randn(n_samples, input_channels, *img_shape, device=device) + context = Context( + embedding_scalar=torch.randn(n_samples, 0, device=device), + labels=torch.randn(n_samples, 0, device=device), + noise=torch.randn(n_samples, 0, *img_shape, device=device), + embedding_pos=torch.randn(n_samples, 0, *img_shape, device=device), + ) + output = model(x, context) + loss = output.sum() + loss.backward() + for name, param in model.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" From 4de511df400a6541e68ba4fceac79d8eb93ca9fd Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 23 Mar 2026 15:23:16 +0000 Subject: [PATCH 02/31] Simplify LocalNet config: remove dropout/checkpointing, use block_types list Replace filter_type, num_layers, and conv1x1_blocks with a single block_types list that determines both the number of blocks and each block's filter type. Also remove concat_skip, drop_rate, drop_path_rate, and checkpointing which were carried over from SFNO but unnecessary here. Validate block_types in __post_init__. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/core/models/conditional_sfno/localnet.py | 43 +++++++++---------- .../models/conditional_sfno/test_localnet.py | 27 ++++-------- 2 files changed, 29 insertions(+), 41 deletions(-) diff --git a/fme/core/models/conditional_sfno/localnet.py b/fme/core/models/conditional_sfno/localnet.py index 003edf14d..4b62a0f83 100644 --- a/fme/core/models/conditional_sfno/localnet.py +++ b/fme/core/models/conditional_sfno/localnet.py @@ -1,5 +1,6 @@ import dataclasses from collections.abc import Callable +from typing import Literal, get_args import torch import torch.nn as nn @@ -12,6 +13,8 @@ from .lora import LoRAConv2d from .sfnonet import DiscreteContinuousConvS2, NoLayerNorm, _compute_cutoff_radius +BlockType = Literal["disco", "conv1x1"] + @dataclasses.dataclass class LocalNetConfig: @@ -19,18 +22,16 @@ class LocalNetConfig: Attributes: embed_dim: Dimension of the embeddings. - filter_type: Type of filter to use ('disco', 'conv1x1'). + block_types: List of filter types for each block ('disco', 'conv1x1'). + The length determines the number of blocks. global_layer_norm: Whether to reduce along the spatial domain when applying layer normalization. - num_layers: Number of blocks. use_mlp: Whether to use an MLP in each block. mlp_ratio: Ratio of MLP hidden dimension to the embedding dimension. activation_function: Activation function name ('relu', 'gelu', 'silu'). encoder_layers: Number of convolutional layers in the encoder/decoder. pos_embed: Whether to use a learned positional embedding. big_skip: Whether to use a big skip connection from input to decoder. - conv1x1_blocks: List of block indices to use 1x1 convolutions - instead of the default filter type. normalize_big_skip: Whether to normalize the big skip connection. affine_norms: Whether to use element-wise affine parameters in the normalization layers. 
@@ -42,22 +43,31 @@ class LocalNetConfig: """ embed_dim: int = 256 - filter_type: str = "disco" + block_types: list[BlockType] = dataclasses.field( + default_factory=lambda: ["disco"] * 12 + ) global_layer_norm: bool = False - num_layers: int = 12 use_mlp: bool = True mlp_ratio: float = 2.0 activation_function: str = "gelu" encoder_layers: int = 1 pos_embed: bool = True big_skip: bool = True - conv1x1_blocks: list[int] | None = None normalize_big_skip: bool = False affine_norms: bool = False lora_rank: int = 0 lora_alpha: float | None = None data_grid: str = "equiangular" + def __post_init__(self): + valid = get_args(BlockType) + for i, bt in enumerate(self.block_types): + if bt not in valid: + raise ValueError( + f"Invalid block type {bt!r} at index {i}, " + f"must be one of {valid}" + ) + class Conv1x1Filter(nn.Module): """1x1 convolution used as a local filter.""" @@ -321,7 +331,7 @@ def __init__( ): super().__init__() - self.filter_type = params.filter_type + self.block_types = params.block_types self.mlp_ratio = params.mlp_ratio self.img_shape = img_shape self._spatial_h_slice, self._spatial_w_slice = ( @@ -331,17 +341,11 @@ def __init__( self.in_chans = in_chans self.out_chans = out_chans self.embed_dim = params.embed_dim - self.num_layers = params.num_layers + self.num_layers = len(params.block_types) self.use_mlp = params.use_mlp self.encoder_layers = params.encoder_layers self._use_pos_embed = params.pos_embed self.big_skip = params.big_skip - if params.conv1x1_blocks is not None: - self.conv1x1_blocks = [ - i for i in range(self.num_layers) if i in params.conv1x1_blocks - ] - else: - self.conv1x1_blocks = [] self.affine_norms = params.affine_norms self.lora_rank = params.lora_rank self.lora_alpha = params.lora_alpha @@ -386,12 +390,7 @@ def __init__( # blocks self.blocks = nn.ModuleList([]) - for i in range(self.num_layers): - if i in self.conv1x1_blocks: - block_filter_type = "conv1x1" - else: - block_filter_type = self.filter_type - + for block_type in 
self.block_types: inner_skip = "linear" outer_skip = "identity" @@ -399,7 +398,7 @@ def __init__( self.embed_dim, img_shape=self.img_shape, context_config=context_config, - filter_type=block_filter_type, + filter_type=block_type, data_grid=self.data_grid, global_layer_norm=self.global_layer_norm, mlp_ratio=self.mlp_ratio, diff --git a/fme/core/models/conditional_sfno/test_localnet.py b/fme/core/models/conditional_sfno/test_localnet.py index 409d600f7..886444d53 100644 --- a/fme/core/models/conditional_sfno/test_localnet.py +++ b/fme/core/models/conditional_sfno/test_localnet.py @@ -34,8 +34,7 @@ def test_can_call_localnet_disco( device = get_device() params = LocalNetConfig( embed_dim=16, - num_layers=2, - filter_type="disco", + block_types=["disco", "disco"], ) model = get_lat_lon_localnet( params=params, @@ -80,8 +79,7 @@ def test_can_call_localnet_conv1x1(): device = get_device() params = LocalNetConfig( embed_dim=16, - num_layers=2, - filter_type="conv1x1", + block_types=["conv1x1", "conv1x1"], ) model = get_lat_lon_localnet( params=params, @@ -108,9 +106,7 @@ def test_can_call_localnet_mixed_blocks(): device = get_device() params = LocalNetConfig( embed_dim=16, - num_layers=4, - filter_type="disco", - conv1x1_blocks=[1, 3], + block_types=["disco", "conv1x1", "disco", "conv1x1"], ) model = get_lat_lon_localnet( params=params, @@ -154,7 +150,7 @@ def forward(self, x): nn.LayerNorm = SetToZero params = LocalNetConfig( embed_dim=16, - num_layers=2, + block_types=["disco", "disco"], normalize_big_skip=normalize_big_skip, global_layer_norm=True, # so it uses nn.LayerNorm ) @@ -205,7 +201,7 @@ def test_no_big_skip(): device = get_device() params = LocalNetConfig( embed_dim=16, - num_layers=2, + block_types=["disco", "disco"], big_skip=False, ) model = get_lat_lon_localnet( @@ -226,14 +222,8 @@ def test_no_big_skip(): def test_unknown_filter_type_raises(): - with pytest.raises(NotImplementedError): - LocalNetConfig(filter_type="spectral") - get_lat_lon_localnet( - 
params=LocalNetConfig(filter_type="spectral"), - img_shape=(9, 18), - in_chans=2, - out_chans=3, - ) + with pytest.raises(ValueError, match="Invalid block type"): + LocalNetConfig(block_types=["spectral"]) # type: ignore[list-item] def test_backward_pass(): @@ -245,8 +235,7 @@ def test_backward_pass(): device = get_device() params = LocalNetConfig( embed_dim=16, - num_layers=2, - filter_type="disco", + block_types=["disco", "disco"], ) model = get_lat_lon_localnet( params=params, From 7fadfee91dad7a99a45582c4f2ab39bea6d7388a Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 23 Mar 2026 16:07:07 +0000 Subject: [PATCH 03/31] Add AnkurLocalNet replicating Ankur's diagnostic MLP architecture A simple sequential network (3 hidden layers + projection) that serves as a drop-in replacement for LocalNet. Optionally uses a DISCO convolution with gcd-grouped channels for the first layer and a learned positional embedding, matching the diagnostic_mlp from Ankur's ColumnDiagnosticSphericalFourierNeuralOperatorNet. Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/core/models/conditional_sfno/ankur.py | 159 ++++++++++++++++++ .../models/conditional_sfno/test_ankur.py | 159 ++++++++++++++++++ 2 files changed, 318 insertions(+) create mode 100644 fme/core/models/conditional_sfno/ankur.py create mode 100644 fme/core/models/conditional_sfno/test_ankur.py diff --git a/fme/core/models/conditional_sfno/ankur.py b/fme/core/models/conditional_sfno/ankur.py new file mode 100644 index 000000000..d6b78f0f9 --- /dev/null +++ b/fme/core/models/conditional_sfno/ankur.py @@ -0,0 +1,159 @@ +import dataclasses +import math + +import torch +import torch.nn as nn + +from fme.core.distributed import Distributed + +from .initialization import trunc_normal_ +from .layers import Context, ContextConfig +from .sfnonet import _compute_cutoff_radius + + +@dataclasses.dataclass +class AnkurLocalNetConfig: + """Configuration parameters for AnkurLocalNet. 
+ + Replicates the diagnostic MLP architecture from Ankur's + ColumnDiagnosticSphericalFourierNeuralOperatorNet. + + Attributes: + embed_dim: Dimension of the hidden layers. + use_disco_encoder: Whether to use a DISCO convolution for the first + layer instead of a 1x1 convolution. + disco_kernel_size: Kernel size for the DISCO convolution if used. + pos_embed: Whether to add a learned positional embedding after the + first layer. + activation_function: Activation function name ('relu', 'gelu', 'silu'). + data_grid: Grid type for DISCO convolutions + ('equiangular', 'legendre-gauss'). + """ + + embed_dim: int = 256 + use_disco_encoder: bool = False + disco_kernel_size: int = 3 + pos_embed: bool = False + activation_function: str = "gelu" + data_grid: str = "equiangular" + + +class GroupedDiscreteContinuousConvS2(nn.Module): + """DISCO convolution using groups=gcd(in_chans, out_chans).""" + + def __init__(self, in_chans, out_chans, img_shape, kernel_size, data_grid): + super().__init__() + nlat, nlon = img_shape + kernel_shape = (kernel_size, kernel_size) + theta_cutoff = _compute_cutoff_radius( + nlat=nlat, + kernel_shape=kernel_shape, + basis_type="morlet", + ) + dist = Distributed.get_instance() + self.conv = dist.get_disco_conv_s2( + in_chans, + out_chans, + in_shape=img_shape, + out_shape=img_shape, + kernel_shape=kernel_shape, + basis_type="morlet", + basis_norm_mode="mean", + groups=math.gcd(in_chans, out_chans), + grid_in=data_grid, + grid_out=data_grid, + bias=False, + theta_cutoff=theta_cutoff, + ) + + def forward(self, x): + return self.conv(x) + + +class AddPosEmbed(nn.Module): + def __init__(self, embed_dim, img_shape): + super().__init__() + h, w = img_shape + self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, h, w)) + self.pos_embed.is_shared_mp = ["matmul"] + trunc_normal_(self.pos_embed, std=0.02) + + def forward(self, x): + return x + self.pos_embed + + +def get_lat_lon_ankur_localnet( + params: AnkurLocalNetConfig, + in_chans: int, + 
out_chans: int, + img_shape: tuple[int, int], + context_config: ContextConfig = ContextConfig( + embed_dim_scalar=0, + embed_dim_noise=0, + embed_dim_labels=0, + embed_dim_pos=0, + ), +) -> "AnkurLocalNet": + return AnkurLocalNet( + params, + img_shape=img_shape, + in_chans=in_chans, + out_chans=out_chans, + ) + + +class AnkurLocalNet(nn.Module): + """Local network replicating Ankur's diagnostic MLP architecture. + + A simple sequential network with 3 hidden layers, optionally using a + DISCO convolution for the first layer and a learned positional embedding. + This is a drop-in replacement for LocalNet with the same forward signature. + + Args: + params: Model configuration. See ``AnkurLocalNetConfig`` for details. + img_shape: Spatial dimensions (lat, lon) of the input data. + in_chans: Number of input channels. + out_chans: Number of output channels. + """ + + def __init__( + self, + params: AnkurLocalNetConfig, + img_shape: tuple[int, int], + in_chans: int, + out_chans: int, + ): + super().__init__() + + activation_functions = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU} + if params.activation_function not in activation_functions: + raise ValueError( + f"Unknown activation function {params.activation_function}" + ) + act_layer = activation_functions[params.activation_function] + + hidden_dim = params.embed_dim + current_dim = in_chans + modules: list[nn.Module] = [] + for i in range(3): + if i == 0 and params.use_disco_encoder: + modules.append( + GroupedDiscreteContinuousConvS2( + current_dim, + hidden_dim, + img_shape=img_shape, + kernel_size=params.disco_kernel_size, + data_grid=params.data_grid, + ) + ) + else: + modules.append(nn.Conv2d(current_dim, hidden_dim, 1, bias=True)) + if i == 0 and params.pos_embed: + modules.append(AddPosEmbed(hidden_dim, img_shape)) + modules.append(act_layer()) + current_dim = hidden_dim + modules.append(nn.Conv2d(current_dim, out_chans, 1, bias=False)) + self.mlp = nn.Sequential(*modules) + + def forward(self, x: 
torch.Tensor, context: Context): + return self.mlp(x) diff --git a/fme/core/models/conditional_sfno/test_ankur.py b/fme/core/models/conditional_sfno/test_ankur.py new file mode 100644 index 000000000..58b042953 --- /dev/null +++ b/fme/core/models/conditional_sfno/test_ankur.py @@ -0,0 +1,159 @@ +import pytest +import torch + +from fme.core.device import get_device + +from .ankur import AnkurLocalNetConfig, get_lat_lon_ankur_localnet +from .layers import Context, ContextConfig + + +@pytest.mark.parametrize("use_disco_encoder", [True, False]) +@pytest.mark.parametrize("pos_embed", [True, False]) +def test_can_call_ankur_localnet(use_disco_encoder: bool, pos_embed: bool): + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 4 + device = get_device() + params = AnkurLocalNetConfig( + embed_dim=16, + use_disco_encoder=use_disco_encoder, + pos_embed=pos_embed, + ) + model = get_lat_lon_ankur_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ).to(device) + x = torch.randn(n_samples, input_channels, *img_shape, device=device) + context = Context( + embedding_scalar=torch.randn(n_samples, 0, device=device), + labels=torch.randn(n_samples, 0, device=device), + noise=torch.randn(n_samples, 0, *img_shape, device=device), + embedding_pos=torch.randn(n_samples, 0, *img_shape, device=device), + ) + output = model(x, context) + assert output.shape == (n_samples, output_channels, *img_shape) + + +def test_ankur_localnet_with_context_config(): + """AnkurLocalNet accepts context_config in factory but ignores context.""" + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 4 + device = get_device() + params = AnkurLocalNetConfig(embed_dim=16) + model = get_lat_lon_ankur_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + context_config=ContextConfig( + embed_dim_scalar=8, + embed_dim_labels=4, + embed_dim_noise=16, + 
embed_dim_pos=0, + ), + ).to(device) + x = torch.randn(n_samples, input_channels, *img_shape, device=device) + context = Context( + embedding_scalar=torch.randn(n_samples, 8, device=device), + labels=torch.randn(n_samples, 4, device=device), + noise=torch.randn(n_samples, 16, *img_shape, device=device), + embedding_pos=None, + ) + output = model(x, context) + assert output.shape == (n_samples, output_channels, *img_shape) + + +def test_ankur_localnet_context_does_not_affect_output(): + """Verify that different context values produce the same output.""" + torch.manual_seed(0) + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 2 + device = get_device() + params = AnkurLocalNetConfig(embed_dim=16) + model = get_lat_lon_ankur_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ).to(device) + x = torch.randn(n_samples, input_channels, *img_shape, device=device) + context1 = Context( + embedding_scalar=torch.randn(n_samples, 0, device=device), + labels=torch.randn(n_samples, 0, device=device), + noise=torch.randn(n_samples, 0, *img_shape, device=device), + embedding_pos=torch.randn(n_samples, 0, *img_shape, device=device), + ) + context2 = Context( + embedding_scalar=torch.randn(n_samples, 0, device=device), + labels=torch.randn(n_samples, 0, device=device), + noise=torch.randn(n_samples, 0, *img_shape, device=device), + embedding_pos=torch.randn(n_samples, 0, *img_shape, device=device), + ) + with torch.no_grad(): + output1 = model(x, context1) + output2 = model(x, context2) + torch.testing.assert_close(output1, output2) + + +def test_ankur_localnet_backward(): + """Test that gradients flow through the network.""" + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 2 + device = get_device() + params = AnkurLocalNetConfig(embed_dim=16, use_disco_encoder=True) + model = get_lat_lon_ankur_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, 
+ out_chans=output_channels, + ).to(device) + x = torch.randn(n_samples, input_channels, *img_shape, device=device) + context = Context( + embedding_scalar=torch.randn(n_samples, 0, device=device), + labels=torch.randn(n_samples, 0, device=device), + noise=torch.randn(n_samples, 0, *img_shape, device=device), + embedding_pos=torch.randn(n_samples, 0, *img_shape, device=device), + ) + output = model(x, context) + loss = output.sum() + loss.backward() + for name, param in model.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + + +def test_ankur_localnet_disco_kernel_size(): + """Test that a non-default kernel size works.""" + input_channels = 4 + output_channels = 3 + img_shape = (9, 18) + n_samples = 2 + device = get_device() + params = AnkurLocalNetConfig( + embed_dim=16, + use_disco_encoder=True, + disco_kernel_size=5, + ) + model = get_lat_lon_ankur_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ).to(device) + x = torch.randn(n_samples, input_channels, *img_shape, device=device) + context = Context( + embedding_scalar=torch.randn(n_samples, 0, device=device), + labels=torch.randn(n_samples, 0, device=device), + noise=torch.randn(n_samples, 0, *img_shape, device=device), + embedding_pos=torch.randn(n_samples, 0, *img_shape, device=device), + ) + output = model(x, context) + assert output.shape == (n_samples, output_channels, *img_shape) From b1567ffe8ed4a438377505ba15b75788a7a6bd29 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 23 Mar 2026 17:02:42 +0000 Subject: [PATCH 04/31] Move data_grid from config dataclasses to factory function arguments data_grid is determined automatically from the data, not configured in YAML, so it should not be an attribute on the config dataclasses. Moved it to an explicit parameter on get_lat_lon_sfnonet, get_lat_lon_localnet, and get_lat_lon_ankur_localnet factory functions. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/ace/registry/stochastic_sfno.py | 2 +- fme/core/models/conditional_sfno/ankur.py | 8 ++++---- fme/core/models/conditional_sfno/localnet.py | 8 ++++---- fme/core/models/conditional_sfno/sfnonet.py | 8 +++----- fme/diffusion/registry/sfno.py | 2 +- 5 files changed, 13 insertions(+), 15 deletions(-) diff --git a/fme/ace/registry/stochastic_sfno.py b/fme/ace/registry/stochastic_sfno.py index 87e9bf7cd..99bb7f75f 100644 --- a/fme/ace/registry/stochastic_sfno.py +++ b/fme/ace/registry/stochastic_sfno.py @@ -269,7 +269,6 @@ def build( pos_embed=self.pos_embed, big_skip=self.big_skip, checkpointing=self.checkpointing, - data_grid=self.data_grid, filter_residual=self.filter_residual, filter_output=self.filter_output, local_blocks=self.local_blocks, @@ -286,6 +285,7 @@ def build( in_chans=n_in_channels, out_chans=n_out_channels, img_shape=dataset_info.img_shape, + data_grid=self.data_grid, context_config=ContextConfig( embed_dim_scalar=0, embed_dim_pos=self.context_pos_embed_dim, diff --git a/fme/core/models/conditional_sfno/ankur.py b/fme/core/models/conditional_sfno/ankur.py index d6b78f0f9..e33ded53e 100644 --- a/fme/core/models/conditional_sfno/ankur.py +++ b/fme/core/models/conditional_sfno/ankur.py @@ -26,8 +26,6 @@ class AnkurLocalNetConfig: pos_embed: Whether to add a learned positional embedding after the first layer. activation_function: Activation function name ('relu', 'gelu', 'silu'). - data_grid: Grid type for DISCO convolutions - ('equiangular', 'legendre-gauss'). 
""" embed_dim: int = 256 @@ -35,7 +33,6 @@ class AnkurLocalNetConfig: disco_kernel_size: int = 3 pos_embed: bool = False activation_function: str = "gelu" - data_grid: str = "equiangular" class GroupedDiscreteContinuousConvS2(nn.Module): @@ -87,6 +84,7 @@ def get_lat_lon_ankur_localnet( in_chans: int, out_chans: int, img_shape: tuple[int, int], + data_grid: str = "equiangular", context_config: ContextConfig = ContextConfig( embed_dim_scalar=0, embed_dim_noise=0, @@ -99,6 +97,7 @@ def get_lat_lon_ankur_localnet( img_shape=img_shape, in_chans=in_chans, out_chans=out_chans, + data_grid=data_grid, ) @@ -122,6 +121,7 @@ def __init__( img_shape: tuple[int, int], in_chans: int, out_chans: int, + data_grid: str = "equiangular", ): super().__init__() @@ -143,7 +143,7 @@ def __init__( hidden_dim, img_shape=img_shape, kernel_size=params.disco_kernel_size, - data_grid=params.data_grid, + data_grid=data_grid, ) ) else: diff --git a/fme/core/models/conditional_sfno/localnet.py b/fme/core/models/conditional_sfno/localnet.py index 4b62a0f83..2c1d7bee1 100644 --- a/fme/core/models/conditional_sfno/localnet.py +++ b/fme/core/models/conditional_sfno/localnet.py @@ -38,8 +38,6 @@ class LocalNetConfig: lora_rank: Rank of LoRA adaptations. 0 disables LoRA. lora_alpha: Strength of LoRA adaptations. Defaults to lora_rank if None. - data_grid: Grid type for DISCO convolutions - ('equiangular', 'legendre-gauss'). 
""" embed_dim: int = 256 @@ -57,7 +55,6 @@ class LocalNetConfig: affine_norms: bool = False lora_rank: int = 0 lora_alpha: float | None = None - data_grid: str = "equiangular" def __post_init__(self): valid = get_args(BlockType) @@ -272,6 +269,7 @@ def get_lat_lon_localnet( in_chans: int, out_chans: int, img_shape: tuple[int, int], + data_grid: str = "equiangular", context_config: ContextConfig = ContextConfig( embed_dim_scalar=0, embed_dim_noise=0, @@ -292,6 +290,7 @@ def get_pos_embed(): img_shape=img_shape, in_chans=in_chans, out_chans=out_chans, + data_grid=data_grid, context_config=context_config, get_pos_embed=get_pos_embed, ) @@ -322,6 +321,7 @@ def __init__( get_pos_embed: Callable[[], nn.Parameter], in_chans: int, out_chans: int, + data_grid: str = "equiangular", context_config: ContextConfig = ContextConfig( embed_dim_scalar=0, embed_dim_labels=0, @@ -349,7 +349,7 @@ def __init__( self.affine_norms = params.affine_norms self.lora_rank = params.lora_rank self.lora_alpha = params.lora_alpha - self.data_grid = params.data_grid + self.data_grid = data_grid # determine activation function activation_functions = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU} diff --git a/fme/core/models/conditional_sfno/sfnonet.py b/fme/core/models/conditional_sfno/sfnonet.py index 9b8f1b1dc..1dbc5e7a8 100644 --- a/fme/core/models/conditional_sfno/sfnonet.py +++ b/fme/core/models/conditional_sfno/sfnonet.py @@ -85,8 +85,6 @@ class SFNONetConfig: 0 disables LoRA. spectral_lora_alpha: Strength of LoRA adaptations for spectral convolutions. Defaults to spectral_lora_rank if None. - data_grid: Grid type for spherical harmonic transforms - ('equiangular', 'legendre-gauss'). 
""" embed_dim: int = 256 @@ -114,7 +112,6 @@ class SFNONetConfig: lora_alpha: float | None = None spectral_lora_rank: int = 0 spectral_lora_alpha: float | None = None - data_grid: str = "equiangular" # heuristic for finding theta_cutoff @@ -399,6 +396,7 @@ def get_lat_lon_sfnonet( in_chans: int, out_chans: int, img_shape: Tuple[int, int], + data_grid: str = "equiangular", context_config: ContextConfig = ContextConfig( embed_dim_scalar=0, embed_dim_noise=0, @@ -413,10 +411,10 @@ def get_lat_lon_sfnonet( dist = Distributed.get_instance() trans_down = dist.get_sht( - *img_shape, lmax=modes_lat, mmax=modes_lon, grid=params.data_grid + *img_shape, lmax=modes_lat, mmax=modes_lon, grid=data_grid ) itrans_up = dist.get_isht( - *img_shape, lmax=modes_lat, mmax=modes_lon, grid=params.data_grid + *img_shape, lmax=modes_lat, mmax=modes_lon, grid=data_grid ) trans = dist.get_sht( *img_shape, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss" diff --git a/fme/diffusion/registry/sfno.py b/fme/diffusion/registry/sfno.py index d9d93e5cc..26900636e 100644 --- a/fme/diffusion/registry/sfno.py +++ b/fme/diffusion/registry/sfno.py @@ -99,13 +99,13 @@ def build( pos_embed=self.pos_embed, big_skip=self.big_skip, checkpointing=self.checkpointing, - data_grid=self.data_grid, ) sfno_net = get_lat_lon_sfnonet( params=sfno_config, in_chans=n_in_channels, out_chans=n_out_channels, img_shape=img_shape, + data_grid=self.data_grid, context_config=ContextConfig( embed_dim_scalar=n_sigma_embedding_channels, embed_dim_labels=0, From c8950fa6aa901fe3dd41b5ee47914da37e302447 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 23 Mar 2026 20:21:15 +0000 Subject: [PATCH 05/31] Add HybridNet combining SFNO backbone with local diagnostic network Introduces a HybridNet module that compositionally pairs an SFNO backbone (for prognostic output) with a local network (for diagnostic output), analogous to Ankur's ColumnDiagnosticSFNO but using the separated forcing/prognostic/diagnostic interface. 
Adds type discriminator tags to AnkurLocalNetConfig and LocalNetConfig for union deserialization. Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/core/models/conditional_sfno/ankur.py | 3 + fme/core/models/conditional_sfno/hybrid.py | 168 +++++++++++ fme/core/models/conditional_sfno/localnet.py | 2 + .../models/conditional_sfno/test_hybrid.py | 278 ++++++++++++++++++ .../test_hybridnet_diagnostic_output.pt | Bin 0 -> 12248 bytes .../test_hybridnet_prognostic_output.pt | Bin 0 -> 7064 bytes 6 files changed, 451 insertions(+) create mode 100644 fme/core/models/conditional_sfno/hybrid.py create mode 100644 fme/core/models/conditional_sfno/test_hybrid.py create mode 100644 fme/core/models/conditional_sfno/testdata/test_hybridnet_diagnostic_output.pt create mode 100644 fme/core/models/conditional_sfno/testdata/test_hybridnet_prognostic_output.pt diff --git a/fme/core/models/conditional_sfno/ankur.py b/fme/core/models/conditional_sfno/ankur.py index e33ded53e..c90b7e69e 100644 --- a/fme/core/models/conditional_sfno/ankur.py +++ b/fme/core/models/conditional_sfno/ankur.py @@ -1,5 +1,6 @@ import dataclasses import math +from typing import Literal import torch import torch.nn as nn @@ -19,6 +20,7 @@ class AnkurLocalNetConfig: ColumnDiagnosticSphericalFourierNeuralOperatorNet. Attributes: + type: Discriminator tag for union deserialization. embed_dim: Dimension of the hidden layers. use_disco_encoder: Whether to use a DISCO convolution for the first layer instead of a 1x1 convolution. @@ -28,6 +30,7 @@ class AnkurLocalNetConfig: activation_function: Activation function name ('relu', 'gelu', 'silu'). 
""" + type: Literal["ankurlocalnet"] = "ankurlocalnet" embed_dim: int = 256 use_disco_encoder: bool = False disco_kernel_size: int = 3 diff --git a/fme/core/models/conditional_sfno/hybrid.py b/fme/core/models/conditional_sfno/hybrid.py new file mode 100644 index 000000000..ba96a85c7 --- /dev/null +++ b/fme/core/models/conditional_sfno/hybrid.py @@ -0,0 +1,168 @@ +import dataclasses +from typing import Literal + +import torch +import torch.nn as nn + +from .ankur import AnkurLocalNetConfig, get_lat_lon_ankur_localnet +from .layers import Context, ContextConfig +from .localnet import LocalNetConfig, get_lat_lon_localnet +from .sfnonet import ( + SFNONetConfig, + SphericalFourierNeuralOperatorNet, + get_lat_lon_sfnonet, +) + +LocalConfig = AnkurLocalNetConfig | LocalNetConfig + + +@dataclasses.dataclass +class HybridNetConfig: + """Configuration for HybridNet. + + Attributes: + backbone: Configuration for the SFNO backbone that produces + prognostic output. + local: Configuration for the local network that produces + diagnostic output. Use ``AnkurLocalNetConfig`` (type="ankur") + or ``LocalNetConfig`` (type="localnet"). + learn_residual: Whether to add the prognostic input directly + to the prognostic output (identity skip connection). + data_grid: Grid type for spherical harmonic transforms used + by the SFNO backbone. + """ + + backbone: SFNONetConfig = dataclasses.field(default_factory=SFNONetConfig) + local: LocalConfig = dataclasses.field(default_factory=AnkurLocalNetConfig) + learn_residual: bool = False + data_grid: Literal["legendre-gauss", "equiangular"] = "equiangular" + + +def get_lat_lon_hybridnet( + params: HybridNetConfig, + n_forcing_channels: int, + n_prognostic_channels: int, + n_diagnostic_channels: int, + img_shape: tuple[int, int], + embed_dim_labels: int = 0, +) -> "HybridNet": + """Factory function to build a HybridNet. + + Args: + params: HybridNet configuration. + n_forcing_channels: Number of input-only (forcing) channels. 
+ n_prognostic_channels: Number of input-output (prognostic) channels. + n_diagnostic_channels: Number of output-only (diagnostic) channels. + img_shape: Spatial dimensions (lat, lon) of the input data. + embed_dim_labels: Dimension of label embeddings for conditional + layer normalization. 0 disables label conditioning. + + Returns: + A configured HybridNet instance. + """ + n_in = n_forcing_channels + n_prognostic_channels + + context_config = ContextConfig( + embed_dim_scalar=0, + embed_dim_noise=0, + embed_dim_labels=embed_dim_labels, + embed_dim_pos=0, + ) + + backbone = get_lat_lon_sfnonet( + params=params.backbone, + in_chans=n_in, + out_chans=n_prognostic_channels, + img_shape=img_shape, + data_grid=params.data_grid, + context_config=context_config, + ) + + local_config = params.local + if isinstance(local_config, AnkurLocalNetConfig): + local_net: nn.Module = get_lat_lon_ankur_localnet( + params=local_config, + in_chans=n_in, + out_chans=n_diagnostic_channels, + img_shape=img_shape, + data_grid=params.data_grid, + context_config=context_config, + ) + elif isinstance(local_config, LocalNetConfig): + local_net = get_lat_lon_localnet( + params=local_config, + in_chans=n_in, + out_chans=n_diagnostic_channels, + img_shape=img_shape, + data_grid=params.data_grid, + context_config=context_config, + ) + else: + raise ValueError(f"Unknown local config type: {type(local_config)}") + + return HybridNet( + backbone=backbone, + local_net=local_net, + learn_residual=params.learn_residual, + n_prognostic_channels=n_prognostic_channels, + embed_dim_labels=embed_dim_labels, + ) + + +class HybridNet(nn.Module): + """Hybrid network combining an SFNO backbone with a local diagnostic network. + + Analogous to Ankur's ColumnDiagnosticSphericalFourierNeuralOperatorNet, + but using the conditional SFNO and local networks compositionally. + + The SFNO backbone processes the concatenated forcing and prognostic input + to produce a prognostic output. 
The local network processes the same + concatenated input to produce a diagnostic output. + + Args: + backbone: SFNO network for prognostic prediction. + local_net: Local network for diagnostic prediction. + learn_residual: Whether to add the prognostic input to the + backbone output (identity residual connection). + n_prognostic_channels: Number of prognostic channels, used + to slice the input when learn_residual is True. + embed_dim_labels: Dimension of label embeddings. 0 means no labels. + """ + + def __init__( + self, + backbone: SphericalFourierNeuralOperatorNet, + local_net: nn.Module, + learn_residual: bool = False, + n_prognostic_channels: int = 0, + embed_dim_labels: int = 0, + ): + super().__init__() + self.backbone = backbone + self.local_net = local_net + self.learn_residual = learn_residual + self.n_prognostic_channels = n_prognostic_channels + self.embed_dim_labels = embed_dim_labels + + def forward( + self, + forcing: torch.Tensor, + prognostic: torch.Tensor, + labels: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + x = torch.cat([forcing, prognostic], dim=1) + + context = Context( + embedding_scalar=None, + embedding_pos=None, + labels=labels, + noise=None, + ) + + prognostic_out = self.backbone(x, context) + diagnostic_out = self.local_net(x, context) + + if self.learn_residual: + prognostic_out = prognostic_out + prognostic + + return prognostic_out, diagnostic_out diff --git a/fme/core/models/conditional_sfno/localnet.py b/fme/core/models/conditional_sfno/localnet.py index 2c1d7bee1..82347707c 100644 --- a/fme/core/models/conditional_sfno/localnet.py +++ b/fme/core/models/conditional_sfno/localnet.py @@ -21,6 +21,7 @@ class LocalNetConfig: """Configuration parameters for LocalNet. Attributes: + type: Discriminator tag for union deserialization. embed_dim: Dimension of the embeddings. block_types: List of filter types for each block ('disco', 'conv1x1'). The length determines the number of blocks. 
@@ -40,6 +41,7 @@ class LocalNetConfig: if None. """ + type: Literal["localnet"] = "localnet" embed_dim: int = 256 block_types: list[BlockType] = dataclasses.field( default_factory=lambda: ["disco"] * 12 diff --git a/fme/core/models/conditional_sfno/test_hybrid.py b/fme/core/models/conditional_sfno/test_hybrid.py new file mode 100644 index 000000000..acaf5eb26 --- /dev/null +++ b/fme/core/models/conditional_sfno/test_hybrid.py @@ -0,0 +1,278 @@ +import os + +import pytest +import torch + +from fme.core.device import get_device +from fme.core.testing.regression import validate_tensor + +from .ankur import AnkurLocalNetConfig +from .hybrid import HybridNetConfig, get_lat_lon_hybridnet +from .localnet import LocalNetConfig +from .sfnonet import SFNONetConfig + +DIR = os.path.abspath(os.path.dirname(__file__)) + + +@pytest.mark.parametrize( + "learn_residual", + [True, False], +) +def test_can_call_hybridnet_with_ankur_local(learn_residual: bool): + n_forcing = 3 + n_prognostic = 2 + n_diagnostic = 4 + img_shape = (9, 18) + n_samples = 4 + device = get_device() + config = HybridNetConfig( + backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), + local=AnkurLocalNetConfig(embed_dim=16), + learn_residual=learn_residual, + ) + model = get_lat_lon_hybridnet( + params=config, + n_forcing_channels=n_forcing, + n_prognostic_channels=n_prognostic, + n_diagnostic_channels=n_diagnostic, + img_shape=img_shape, + ).to(device) + forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) + prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) + prog_out, diag_out = model(forcing, prognostic) + assert prog_out.shape == (n_samples, n_prognostic, *img_shape) + assert diag_out.shape == (n_samples, n_diagnostic, *img_shape) + + +def test_can_call_hybridnet_with_localnet(): + n_forcing = 3 + n_prognostic = 2 + n_diagnostic = 4 + img_shape = (9, 18) + n_samples = 4 + device = get_device() + config = HybridNetConfig( + 
backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), + local=LocalNetConfig( + embed_dim=16, + block_types=["disco", "disco"], + ), + ) + model = get_lat_lon_hybridnet( + params=config, + n_forcing_channels=n_forcing, + n_prognostic_channels=n_prognostic, + n_diagnostic_channels=n_diagnostic, + img_shape=img_shape, + ).to(device) + forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) + prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) + prog_out, diag_out = model(forcing, prognostic) + assert prog_out.shape == (n_samples, n_prognostic, *img_shape) + assert diag_out.shape == (n_samples, n_diagnostic, *img_shape) + + +def test_can_call_hybridnet_with_localnet_conv1x1(): + n_forcing = 3 + n_prognostic = 2 + n_diagnostic = 4 + img_shape = (9, 18) + n_samples = 4 + device = get_device() + config = HybridNetConfig( + backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), + local=LocalNetConfig( + embed_dim=16, + block_types=["conv1x1", "conv1x1"], + ), + ) + model = get_lat_lon_hybridnet( + params=config, + n_forcing_channels=n_forcing, + n_prognostic_channels=n_prognostic, + n_diagnostic_channels=n_diagnostic, + img_shape=img_shape, + ).to(device) + forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) + prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) + prog_out, diag_out = model(forcing, prognostic) + assert prog_out.shape == (n_samples, n_prognostic, *img_shape) + assert diag_out.shape == (n_samples, n_diagnostic, *img_shape) + + +def test_hybridnet_with_labels(): + n_forcing = 3 + n_prognostic = 2 + n_diagnostic = 4 + img_shape = (9, 18) + n_samples = 4 + embed_dim_labels = 5 + device = get_device() + config = HybridNetConfig( + backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), + local=AnkurLocalNetConfig(embed_dim=16), + ) + model = get_lat_lon_hybridnet( + params=config, + 
n_forcing_channels=n_forcing, + n_prognostic_channels=n_prognostic, + n_diagnostic_channels=n_diagnostic, + img_shape=img_shape, + embed_dim_labels=embed_dim_labels, + ).to(device) + forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) + prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) + labels = torch.randn(n_samples, embed_dim_labels, device=device) + prog_out, diag_out = model(forcing, prognostic, labels=labels) + assert prog_out.shape == (n_samples, n_prognostic, *img_shape) + assert diag_out.shape == (n_samples, n_diagnostic, *img_shape) + + +def test_learn_residual_adds_prognostic_input(): + """Verify learn_residual adds the prognostic input to the backbone output.""" + torch.manual_seed(0) + n_forcing = 2 + n_prognostic = 3 + n_diagnostic = 2 + img_shape = (9, 18) + n_samples = 2 + device = get_device() + + config_no_residual = HybridNetConfig( + backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), + local=AnkurLocalNetConfig(embed_dim=16), + learn_residual=False, + ) + config_residual = HybridNetConfig( + backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), + local=AnkurLocalNetConfig(embed_dim=16), + learn_residual=True, + ) + + model_no = get_lat_lon_hybridnet( + params=config_no_residual, + n_forcing_channels=n_forcing, + n_prognostic_channels=n_prognostic, + n_diagnostic_channels=n_diagnostic, + img_shape=img_shape, + ).to(device) + model_yes = get_lat_lon_hybridnet( + params=config_residual, + n_forcing_channels=n_forcing, + n_prognostic_channels=n_prognostic, + n_diagnostic_channels=n_diagnostic, + img_shape=img_shape, + ).to(device) + + # Copy weights from model_no to model_yes + model_yes.load_state_dict(model_no.state_dict()) + + forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) + prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) + + with torch.no_grad(): + prog_no, diag_no = model_no(forcing, 
prognostic) + prog_yes, diag_yes = model_yes(forcing, prognostic) + + # Diagnostic outputs should be identical + torch.testing.assert_close(diag_no, diag_yes) + # Prognostic output with residual = without residual + prognostic input + torch.testing.assert_close(prog_yes, prog_no + prognostic) + + +def test_backward_pass(): + """Test that gradients flow through both sub-networks.""" + n_forcing = 2 + n_prognostic = 3 + n_diagnostic = 2 + img_shape = (9, 18) + n_samples = 2 + device = get_device() + config = HybridNetConfig( + backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), + local=AnkurLocalNetConfig(embed_dim=16), + learn_residual=True, + ) + model = get_lat_lon_hybridnet( + params=config, + n_forcing_channels=n_forcing, + n_prognostic_channels=n_prognostic, + n_diagnostic_channels=n_diagnostic, + img_shape=img_shape, + ).to(device) + forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) + prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) + prog_out, diag_out = model(forcing, prognostic) + loss = prog_out.sum() + diag_out.sum() + loss.backward() + for name, param in model.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + + +def test_ankur_disco_encoder(): + """Test HybridNet with AnkurLocalNet using DISCO encoder.""" + n_forcing = 3 + n_prognostic = 2 + n_diagnostic = 4 + img_shape = (9, 18) + n_samples = 2 + device = get_device() + config = HybridNetConfig( + backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), + local=AnkurLocalNetConfig(embed_dim=16, use_disco_encoder=True, pos_embed=True), + ) + model = get_lat_lon_hybridnet( + params=config, + n_forcing_channels=n_forcing, + n_prognostic_channels=n_prognostic, + n_diagnostic_channels=n_diagnostic, + img_shape=img_shape, + ).to(device) + forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) + prognostic = torch.randn(n_samples, n_prognostic, *img_shape, 
device=device) + prog_out, diag_out = model(forcing, prognostic) + assert prog_out.shape == (n_samples, n_prognostic, *img_shape) + assert diag_out.shape == (n_samples, n_diagnostic, *img_shape) + + +def setup_hybridnet(): + n_forcing = 3 + n_prognostic = 2 + n_diagnostic = 4 + img_shape = (9, 18) + n_samples = 4 + embed_dim_labels = 3 + device = get_device() + config = HybridNetConfig( + backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="linear"), + local=AnkurLocalNetConfig(embed_dim=16), + learn_residual=True, + ) + model = get_lat_lon_hybridnet( + params=config, + n_forcing_channels=n_forcing, + n_prognostic_channels=n_prognostic, + n_diagnostic_channels=n_diagnostic, + img_shape=img_shape, + embed_dim_labels=embed_dim_labels, + ).to(device) + # Initialize on CPU for reproducibility, then move to device + forcing = torch.randn(n_samples, n_forcing, *img_shape).to(device) + prognostic = torch.randn(n_samples, n_prognostic, *img_shape).to(device) + labels = torch.randn(n_samples, embed_dim_labels).to(device) + return model, forcing, prognostic, labels + + +def test_hybridnet_output_is_unchanged(): + torch.manual_seed(0) + model, forcing, prognostic, labels = setup_hybridnet() + with torch.no_grad(): + prog_out, diag_out = model(forcing, prognostic, labels=labels) + validate_tensor( + prog_out, + os.path.join(DIR, "testdata/test_hybridnet_prognostic_output.pt"), + ) + validate_tensor( + diag_out, + os.path.join(DIR, "testdata/test_hybridnet_diagnostic_output.pt"), + ) diff --git a/fme/core/models/conditional_sfno/testdata/test_hybridnet_diagnostic_output.pt b/fme/core/models/conditional_sfno/testdata/test_hybridnet_diagnostic_output.pt new file mode 100644 index 0000000000000000000000000000000000000000..27205a1e2a057f0c7dcb32f750b99a26d2a25cfc GIT binary patch literal 12248 zcmbVy2{@MD`tE0*=c3F~<~f<(=aEP%G*HoigiuN(rBnz-g9eo*k(3fCA-=Vgk}?$~ ziIfZ>DG_PF`P%#c|DV10Is5$fb=G&S>wT~HUF%-=_pIky>v^AbFLC7IWw8VWS^r1T 
zW?8br{I-U91w^b14fNgY7v|*~=X23kwX|D(n^Nw{BZtkgr#m-{!5sp-u`|e zTz~H(JTEBNJIwX(4~0Xx-MKhzj6!(ac{o>}E!#qPomGXyLim97!f zE$0K=j+Y#z9l63n1QL9Ff`fwme8K{QH*XbQ9O~;A>gPK*&?ihdL{Qr$L?|po*d;{7 zV2PuEV8=t@6IVFMf9jHdYc(Sl{I%5#)(3}f@(%L~_Y2+18F-$p#s8o2I0iTVL#^n` zn|*)&{Ga;osm)|5{|D6m%gA2dL4p38H~DQ2n`bNk|6pfk`iH>m?IFF~oWcLpe@{S% zrSaDU)Q@WXY%xMaTZF0SW?3jy5u^HrKgsQqD7MZ;LD~_1ne?BiB6&^6 zi0P9O^0}gq)HK=9O({s-FTW&;K{B*-RF$#~GT50J9|#kEMr zntbXz$&(r)Pw)AXWu8Jbet{afbHt7&`=pbDZzjXE20hh0ljtrlBvVp^=~;DSdUyRrlFlzhgG&=ZC)J9kPnk}Sem9^|>io2#bpy?G zm!mniVoTiXH0i5_Z;Aak5!!6UPv;-x#Z$$W^o!Li@}x_ix%6I+4w@by@n+W4q5BaT zr7T)sc9^Jzj1 zhLDFtpGm4f8M!VbOy6vrN*~oskmYjD@Fc;Ideyj77wZo4{FNd-?tY0>X*?z|<+_xg zcbudNzaVb+XVCV`6$C+^IwUI6N~|I7ZP!V~tXFKoh6d7nd=)6=cv8iq3>j{$AF0yCcmIuuX@|bb`4h$EU2z!Z*QnuQZefrx3I*4i zQvd^kgnN$~L`*pY8Yu~oEZqtz!?bO#%&tgL44V9iJgXUF zmdD?O#}Y%#g4vo7=W`ZT&Huz?yak-Rr3mgNC4kI*X_#}x7L(N)*phB98FNt+Xy2G< zX))#s4)3MV_eeDi_vgd)wkS9##6w+#_mtEY`+;-NcBl~g1$#sqiz^J?Le9=?;xZwQ z{X4p?*v2j3DXfng1ES#3odxRZsieqs0;;v@;f7Z{{5)@mXV&cpL6v4pJL?;;e%&N5`FLNfQ?_r;Hg&(<0kwPqz{Ec;yf*Unjnah zO#79n{)z5^if7;n@UZG;gd2;d_#p zTgi>z7G8q@i%r=7h!3YXp98;weAuN_2QtOO@cKd#*zGtB%OA8uR6#XdS#}jx{u%-| z&pu#VC%}zG_3)t96P&n57-+X0(Xl9)*EMofItOQDA&Q=!^LoD#|xNe=MO7w)$nynDyVna$?|1{`@%{ux@+w4NnKl_4lD0yOstg9+n}pk-6}tGRVqIqeD2wu9uZ#p{SM7(LtD3+>=o%dH zOoE=rTOoVTEPOUk41Zd3;m2h>sQa`Zie45$&&e!En*R%C>OO(>`;LNeIUnpaDTlh? 
zb1=B`3hX^Q86&G3z;?YQHil7n^Ii?Zud-oGDg{KwgmAO}0Wfj2!4YW+m)}>xhsWPQ zMs7YF*!>fHws7OHdJniheGi&RzhV5~5ZI4SMwl}KJAc@~+G&ZPI6M5j?3b`?^<|wpC=y8cRp_{$&~HL4q#+MD%o zYTqX~w_Jc2R~3PZg(f!sJOKr_wqd~?8_3upguGuC;;)Kee1p>1zGgP25*JL&jKIZe zPMFaW3+tCzVYBmZxacH`$xgdbS$P!>elY=3HOwv^wg({3Q-g+j5 z!Uj%QvBC}C7i`1$51cvd^ug61f^g0oHMD!3|MheRKsjz6d}|S805iupEn{VsJ-Z4+LDZh2S(B)Eb;i$Bj3@?fPpew#8Ay`OY}iPl%B7ni`f zdw9^oavjdu?166|Jb+VMUqDNyKXyzx4vjjRSaLxV?H@>BlfOP1=?0^xq8$#i)Ucb| z2G8~Rz~>JNSh!XW2eVrs!h1HZ+&T^If-JB?QJrL|1>$aQWn5hCgP+HI@%(XltQn9- zZMhX7&~A;d6xN~bt%dkLE&@+UY2w2&5u7*bie5uDXx3l|)p1U!!j@*nPg!ANSUU`b zt;CbO+tEA83TG|$Lfw&f5M(_Qy*xEgC&(97GFD+=XBuZccR-bvg(&FejVVs9j9O4b9&S3Yp@Z;`rNQPe;KgWGU444{!4x`DJxDNFivB2mMZAJ7L9TZpwK-HmX18Y$73(%)ajl=|T~nc6 z_D#gUkVQ8Zh|q7xG$@^4Pww=XlMTOI=+HnGGut$mIMz?1pRQ?8Bi9e)qKO;L8NW_e z&Kw}Rdo&4GbuBUf*hS7AbD+FCKas1GZ0P19J}RFr1>C&RB(m3uE|}Rwwks)7vwAJM zz5D}FSgegNzp{zWQ7ziGYASuHltNzH{vBp22r^*Ak95MVDzgPDBt zaPBk0uIMF-T7u-T;{>?^KS*XgA9Wb!qlW6+>C{qXx=F2t)S9J}_r@_qh-;nIFxM1X zXIVwK;`Y)jXV1c=FLHE4r#6-CokYD}8`5o(^T^m9ewwn|omg}xlL4M`GNN&ttmg5c zp~m90P0E%EazC`na-U42r!A&eni_~wax*z`?+TIGWlpV27$W+Whu+o@rt#~w=wN>i zx!L{O>dO39(%jWWmYVRJi#;sbqbyGE?>|a%MwF;~!kH3_`}yppIug{nDv4O2PS_zUXLa;ld7_+24L8}|cmiX~~?ZjAa#%V_zt zm(1%6_aRJI07spg!6!``6~BFex>8?wJ#rVeMH1MzbC}tsRt+m|orYEjWQH zCUq;}`o7zY&39k8u|Nk;C9tsR%P7Quyve@OwGGy-SHq3BE`eOT9o#5QBxM@gf%WVb z98Ww5H(A|`LDqfC=L%xTb{%K$lu*FW^*_M$kTB>?7a+$ry@o+=4JOJ$9Ih8Wfr>wS z#_{(|dFu+W@HhxE9)c+RDxb5SUxmYA3Rn_m3Ka$|%!=JDC7t1^khw`1Kkc{xLYW&F zyNmhE{&}LEFefmYfX za0|*|$MYYAm#hd#tZanzdVlaqJO|UGXFyT<2j;+oQE*RgWXxeZ>@VcRi>qQ7zI#!? 
z7cFab^Jxb&ToMcCRtgc<-)CWC+i`eddmP-CZU9;LaS)%(z?1!-z@|b9HO{OmT?+(Nyera~; z-fBj1i7?rHZ5UQ9Z-N(tc`zv9fFn||a8g?gbF(i%;_^h89URSc46yLn>2h%X6a{ae zN#fN#(zsTq7C!nVf$)bPu){A09;AEWf$VOWE5wZnmlnadY&EO<@=lPWWDL#iLYT~Y z2fr`2fPVfY?BdRcCPozR9LRyX$yG2hdIBzQ%Z8X)DrjUf2~WEJW;Sm=2AdQcAvoLw z1Ow8cWhe{Y1ISm_o zS@?SPWjG1?DE_7aCiUEg*u$40|0Xxi;`$Aj=N^M~7gX@1!Z{FJ+X7w<+3<=bg-_dL zQCGPZOm!_WIlc&{IBKCL&qLT5$b&+c@?g6DLy&ad41VR;p=cosOFN61xD{_f^XOun z_&Eh{P91>C<)@eu;|DO^gbkZlyI{-6ZBQMrg7h|JREgRLXZB1+uF6XIzC9MkuDig# zjg`O`u7cAp_JT=j0XTer2#JPW;5%O&Z6zwfK0+ARtMtIvzHiWx-UiwVmm$DS17$~B zpy-Y=T85WEQfVt3wbj5DI|0lIlf-35Uqcxqj_GaHz;*FEu;gZ|pWF_C>BaD^ zmk(2Jq(M(<9AvEufSvxOaFr8Fix-}S!9-)QUKt5V=Cv@dG6!mZ$HL;`GezTt-Ua3MKx5Gs-jXxB)mQ{2R~h(fj>40A%EUxOpop-I=idkpu%?$PdCEm zY&#qcV^Cwx7M!$56K#WSaahwG$5TZx@3IxHADWC8+fA@>-4dL4bUJ=>Sc^y2y#aQ( zI$l2$gLZY^2<5Bser+y1>tBF@;%+FlZ69Rls$#wSYB-{1fUXW5AEqRj*Mruku}D1!q} z95CpCJ1(s<#+wIkGvaRqaCor`z8nY4oU{~Uvo_$Fl2qn)*b-DcV2_`&gz@`<`RFKJ z&s=Jl3nr)TfSGt4j*aP|zP$nVc}1|bZv~(xvk0%~nBYvN6mSx}&C(4Mx8(K_~5xpr*A3cg&4I)!}iFaF4@V+ZJQYt!b#D z7KfsnICG|U7gq4iLe_d)yco3|BWsUfMtToy_tHjQ_h}ed^$8%t1!GI+VCIbow9HV) zw|eVwSlJHaKdRuNH?eSoF2&;C9wg|(O7i2ADQ(DBq@{a%7)2~3uclVk!5MOi2E)j+90}vUMq+pgYCZw zPiqG|y|tg5sGdT1w^tKGcuHiWU#z9+9@bQ^ zrJQJ$s?n9^qSX8BJbHZf3@SguM=dP4=(ozNB!*|1zSzm=h;OekAt1Z%8he0}ZYgqP>Pgo##QvLlT?`J`lqD&-lFf`>K!q%e$+zVWRiCN9s2YHk$m`?QdTzL2MN z;j`$1r<1AOusA)s;4QmK+=>PSDO0xzpjuIw zNP{8!nJX`Kls75q&XJ)C3)N`FHG7&_<3#+`^r>{&dt#ozqKP-Yuv^smXndSFmDpEJ zEc(*P;rb`UyLCSGRac@X)`u_|oiWVkLksAfPGcHmwt`lxI?#I!_B1u{J(=(+Al@-2 zN#ETP=-8D4@dxfenW_s6I!r+BhDUJX2^TOn{_w555k%hUfgWoT?$W7(E3XD2L42|auvlHIoBDU>z*L8 zuZu~XrHL`W9>QwdS*R922gej%f@7&E78!4Y(+eW4x;9ypfn#EL%}fJ4e|EyVpPAsC zRS!y7H7Aeb+t30I7pE4nyX~;YY5)#yn~&^^ z(ID$5i!~+sDB5_0xoOY`F-sKaLK?&#@#%!<3~Njny9|z8Pnpn4Me=Co38?)dhEs1= zuqFN80y|_fW@Z#HBRdW-!w=69SD7TxE7OIAoW1wh*Gy>l`OJJf^#Ha%A0ina6|qj8 z%`A7h1o2@@A#2(TczJVC@w$M!?3~I2K<}|YbqyD@`Pl~IYJV9v8ERt1q2o-?X2O6A zkXySw*s~Zf+d-IoH91L;s+IS9PV9kITkgosQF1s-E- 
z(CBi;>n>kNV`3W|OL)L=N8hm$?li)rzTY6Xpbzd|8iN@rFTp%m4)0FpgFUQbIGs@g zh243e#)%DiXDXwiODkMiD~kNwzhLqG+u)G;iU~b`1$5QT!A5xn8WgobTgW{0T=4)F z^E`!hH{8J>F9!sdX2YXBS#ZCh3G`Gs`Id`waK213@QS?xp%cY0dPy1s+&_R}nlskL ze*^vO2rxeXma$Y{2<1!mz{ZDZ2zG=Sj3J|>fIFBmCsV{X}1NFQ>;N7w5?Q@H`$hR#FH z+6-`ν6W3ZnVco3M1G2n;zki}zu})Yv+>q7(ya+Wo-13qwb_1_<;XgoyLmkgr+{ z_lt!wEWZ*yKCwmqNk`!(cOyV9KRR9(!bXOL_b+?`|E-y@c8(w>ig01tp*lFtu7y+6 z(kx{)Q^6ri8jo`Fp;O*9Lz5pDerPX%@522s0NUX6^fT}ulR=^D9k4*F2rhoDgTRDb z*cG=Av`2(+ypspT)3q>mpA#g;$YLrdzv#d#kD)2TNbFaEGygn@6cxca2lQ}xw*wkC zwt=dp6PEWmprF1O&dAfm?zi)Bb-XN&TKM8~xp{clWC{8<2cea<5yq_aK>p{xXlM1A zDV4Xv*7HLkQt=bsS*=2w@MUOYvmHMs>_VkfAI#mp2Dj>;ffT6$C=`jpnjRbId>n)` z#Eg-3awcZT93{_ZJ%<|?b@5a7KJ4AR4;!2V@b<;8aP5gUis<)2j`I@iubhRxAH%RT zG8%uVDWG~@Fm9Jzgg)a_v3+zKl&s!@HOr^tYU4H7)T4?^?XB^gj~IUGXYjYG5T1GN zk1M`upj_=)aM`a{FKr7NCdu4cS8d$8QgW>1$S_Axf3IoVUd;%UOeXtitG8X{-psn zU30`D#YMPg!xBs$^h9}6bG*3sF|dcMaig>jCY{)aAyKkuw8;xMLNq=(%8w0c9=J#H z7cp9(h%%|m@kQfyth^$QX(zkj^^${(efD&$$zO>2Zrva^W{kfI{E_)4gi$-U!&2T$ z5Vb-JZ!VaI>Pg;stVsIfLnV_xnD4tG);9RTrc2)yc>o{s6>6 zccbi>BpnuWqar`g5FW84(h@d@?%j8h#D%P(;?aUMU_zIsg!xg8Me@{9;|7_PeVla4 zE77ji4Wwt?X(G6wiP(JPqDR~-iOqfiy4!sUEsD-3QlD8=cYO&tYY|GGwh7W!wGOsV zhXNJndq-Z)yhPT#^B{3=dFX=0_cqZ*jP^PKJf0LHVz2x?H3)<)SlVmTw!~TAKA$@j0 zm7MNwBUZb4sO?!J+S;!}T%s+DcfFFN7I)TAXR{baO=mW>Szb>*oOwhttJLW2-eMBH zQ<5sS-yqgE^y!Gw0_xZ3K;v3pkd7(lUk^)aZ>TiF7Em=}RF5J0JzL_Ku z$ASW)<#>fO9=XO8?^U2*{$;j`yok)x)_~3DSntSXO|3Hq@gZbp+Y3E=W?z%TjdZ%unQo2sd4Op3lVC20TpWR0%_?`#jf}(n^h{`Mn#CLnt_Q)>w_)R67T!3K z3TH#Va?bcV;n~p!*z-;gXP+s9(>10T|H8qNFZK(wE}ah>)68J5jU|*P4M0FpI7)cbERRWp!`b?H zsrx=#cdQ=j)h@xNu4l0J%4c|3KEcdWW#J38cbwQ|4%$uSVTQhJhBL{X%%ZscjMlYC z=7fwl@Iu# z8Bn!rC{ZToKrEu2nY~vA_%zPIwh#e;@eX*>!H3VET7j!)Iun@Fc zgk^+Ch=Ce5&AkIRJ$m56JWafQ!VHR)?l5i6aZs^88y5HwJiCf1oc)j-`uPq%jauUGaY0#<8Fe7 zupgZ6C-C7)A>`&jM0$;k)bvL4uRUK9Oay@X}Uk|9sx0f>8^glJ73WHt>! 
z!uqRVkUJf#xsrj#+v2MgFJO&h3IqnS@Rig|OdKwQgYR=;p413z&F7rmn-yAJ@Vfw6 z4MMo<4IdtS9)V@|BH-}uOyFLw0*{G41`pK09lKK4=GX$}b_>ABhX>CkYQyUr8BEaO zZJ?f21`&-zu=Q&%sJga-N)s0<+lyl9s}xu_7!C7`AHwF>eNgXo8ZJOw%iZr4O26;_H40Tv7wN)kEO4 zN)K;elHW%J<7Df5YW(eYc z1TwY{!MYdV`iCA!n%fIs0w14$V-EGw1!Vv$qi*PPWFWDH=Fe zi4&U~HpP$Y43T%*Bv?=&j< z6<4=8VGAd>m}avKS@VT4T-F6wMC8H!j%g^iWGm7RBV2i!L7OTAjBHwnZ6WKh_PGfj zXbMA{U(1l5T#Zp>ZjgWLJXkD>!*A`vID0OG;^*{mlKEwDIH`jsxdFJ_rT`KjE2Cg< z3XJG}0`2!4jSWXPQ*b{BKZOZk)dCBQaE(Q!jq72Ttu!Wh>*2|X<#>Ov0ebejq3+M8 zu&|0jP!`Aip>N^rNeb+3>d0*H!O9(rFzc=mj&_g1%?dfZ5Iq44?}*^BIa-j+ABKLT zpP+oh7F;4W2^B?8g0yEamKyV8(a}H@I?j)1Vu$8D(=lG|H(0)u1Cv01yzIOYuVEYb zritOjyn}eXARI^2reN#6c1YOd#XL~ihnEdK@%UI25)T$mHh02mK~0pHy$pZlT44K< zAMhCigv03sQYFb-STfyS}fL| z`tOyGS}dKvRzCiZB30wBrH_A;V9pf&!=OIZfB9eQKh0TF|C)ip-^7<&{Xzpd3e=t6 ze}uVSfxh!>C;#6G8P75|u`sr@oMmimG0W7-(<*{G+B;=br{};aKl4>Hm+G_g``U(PK6L3Agb-!Tq~F z_g_K((Fy-d@&7J$y8i*$|8o6*Mg2#oG5_C?oI4l2{{Z#>?#@db`T04$|5lWF^#ALQ w#rmUi|5N_^y60~`%3_WFS7T1Uzr6}K{Cm5<3%>dONS--;1XzFC|Gn-10+338UjP6A literal 0 HcmV?d00001 diff --git a/fme/core/models/conditional_sfno/testdata/test_hybridnet_prognostic_output.pt b/fme/core/models/conditional_sfno/testdata/test_hybridnet_prognostic_output.pt new file mode 100644 index 0000000000000000000000000000000000000000..ecb45d27fa330e2489ff9acf69d66af27e13a7ab GIT binary patch literal 7064 zcmbW62{@J8_s37hL#9NC(1ao(bKbp9l5U|?DotdJ%tN}VWGq7|C6qC#P*Nho+3Pe3 zX+%Xr8c30D8Z`d7_cz>of4}?dKL7PR>v_*P@7~|_J^NkndCvRUwpRQCEEb#1`j42# zGGPU~1qC~LY;g|sboF%$cJvSQ+u-XL6zu8Z=(j1@e^c;GSEpbnZGSIsOLGI(>iJ8la0^bQhr40LneT5-b|Pp0+eVBsf5HX~1Z0 zTPtDq>r=%Z$A^_i()KS`O^0Rn<5kmk_Y3rK3U&-}3k(`AY-w)%|69D_1=k(1R{pvz zYUIn0Bx#Zvwo*8Z``j!xd58+?7-e1k2`$Nvv^hWaB0U~|9=zTx5{`PT-fv($cU zzuoHKmwT!pA`GUi@7gyu5TqWUs9 z^-hXYeB2$*g-%AvcWq>Md>6OHIv>8s{laZ2G=xbHjOk==U2<7$9k$D7!Y{Xzh>znP z$g)i^OOs5aHs24^Q9>{1%#5qdzD5-g@|%o`8^2eJKNIAM+ummC=u%Rqa-8OPsKF$j z3huJ4gR5C5V9wXQBs@+ZhdTJ^-Xt$FvC#(ih-=}Q6o5+eFx)I3Mf;-IJc&4O@Da!X 
zjeN7&OF#M0P6;u3>yrjfSBvBBOb8*nYjdD-$pCXW)(vhZCD5*{GT8gJof|#b3-;Tx zVe5VakiL78D=9w%^0g)~`4{)&K^YG;PAEi^>JFIBBoW7~1em15f-o!rReo!j(Ncyg zCWWL%Oc>KU(?DSyMLN?Qi!SEElAUL0_VOH*s@qG0^iJT{bP3QBD#Af_BsdDzqS=%{ z{Q3GpWG|A3$ZgRevPvFwZ#9Csj}_?ey2*??wUK(R<4KbA+Gct5C%Fw2jAzSb&LFjq%fp0+_K`5vw#WlTOb? zu=WPx_%)A+jn8!ehdD49-9rn674XYiSK4u5J(hf21z{?mF`^(B#ojd0P@!GW`B;T4 zl3N7t%w=%fg&6WqdOfc9btXB&rj+f!he+27lY35@Xnmp(c6EhA$s|D-7##^!v4 zxP}YA$Ke9EyVUi@AVj?2P=QD3usn`Jb3r!vu#}LknShngL#b(CIiqx7Hkc0O!92}9 z_p4i<=Z1iTTpUxo#2H>$l#)8huhjpM5Kgx;1WxYTYL`nVQDY$+jnzdluYNo7>hDzL z9gD(^Eo0$q#CG_6sDP>$beP!}vvJeQbEte?4boPb(NCfg)Zxqouy4q~fz+Dn0qN(= z-3#NPc=JV6jte7+(tPND%g4{&BH02Ve%)CgSbzeg%5fx&kTufx#2(}v?BDb9t z(27qD90JcVQ_V|AyG$Sks6L`HE*xyK?`(4oUb&IS z-R)sTnSxp}CMc4G96m(f<}AXn+%*UbuW`d-Md{~lJMr#Miabt;VAb6*rf{T97xYt| z82RY==>K{aasSN1e*X>F;$VmS8jMhUUlk~YS;E?!SM_Vidneyaw6<~a*&EG7XY$sfW&enc+0G- zG0#R7OB2rnGq#Gn@9}1QUn!#D`U)_-?}GacU1>~>5=nC9=e2v?$4!@yVO2&E_6+ym zI&E*bIYR)#D-Gbrq^&UFnGIsE2%az+gVk&9&~(mwTvHW>DS9Q;#W)GfB2E(b1N}74 zDjWhhDkOG_1Vp5Sp|byTOu0OobBH+&6{Uf2{zexaotZ+XCO*Ln@!HVrmxt98pTe4p zCb+iu0Djq|M_9xJUf6C2ukKt}t&#y18RwXfcAn_IS%@GKy&}qKr>SXY2_s_^Pe*S{1+~ZX z!1(ShXv`{tnH5rCJWZB&f1wikSllGz!~*GUzf@d33dvFtBMjZN8=vX*Q2jA|p#SPR z4%zi^O=P|?vRn(|JA7aJT6^K^%1-dp@}=D_rBET(YnFVZ5UiW7!0rCAjNS?@aM$8M zU+^h>u5Qm|lsA%;f@0XNHxYd#7)o9v|_Z%o%9xJIJ*;9f7>_ zhGyaSKa**?tMHjzEcqhH!!!BM@#LjCDrU436&@?VMzQsTIK-f?W+S!u6vaF%jw9dB z)q;w7JWaPX1kVjMRM>-qk7RbzkLyBdOk6${Y_$UZldI5bxhQzwnLrn4Y=r2Z6s)L= zL-%4SdT^c^OPq^;!Hz^*w_u#h1n}|DB+Hh_@g{9o$F`wvIAnGKrTdieO3XSG zIiZOAKBZu{LW>@=4IMoZ<9=GExwXZnc7f!vJkoE;yAnJ2(+79ArU$ODdz-0KPUorR=3ix zw{=K^(Ou@DWG->WV(Mui%&|Lp4#H#BLAq)wE*bxZ^vBB3r~RqelgNiJ+{t29Pl=wW2%Cu950d-TKAg*-8pbbRv3gT|ct85ijqL3j|0iH-BY zV<8;ev#SH$N~d!t&szu|?irE8SLZ;e-CUSt-^K_k_~S3HpF@v94^tg@0K-bdiSOzz zd?zvnLN{#Wc4{=?p7;>7N(g|7wK3$ahAoMia|z0CmV#St8kEj6zkJ29w`JPCAFDihyugjT${NCmHo@eYec;ZN)#ia9u`qu+2;@axTnR zP!9Mg2j_av!~O+Z$y9YqjBG#3bZU2!pNdm(sznDnSPOAjFFAB_#0HSdsv=XWMq}rq 
zI?Q+&iVs@laC>PKI%zINjhZsBe5=86=j$MHWtrgq^B1!5b_=6jHv`{m%Q4p)OUUxD zMYvO=9Da7ng8i+Iu;FDD*UIiKOz@Euuiug1KlpX2ax_YpLtKWo(MmhI{25)L!@#b+cOmQij?@?^Gf8c2E+Q#h)QQSq@m>oP*!q97Ux7f4n6A z3B{X*;Oz}}Zm+x{&NemU^pfLLfA$vQCA1jq>Ktj;LlMyJIEHgq%5tP0uAm<0hlr5j zHYy&x4+O`)!qT_VC~2z5o!+*Ei5!=SrzhLs1?zC|l8ivMY6Gko{;$%Ms>op%R1iOn zU9exsold#Nqpyx%!`f3{==E9$sIJ|FYC|fh&@)IC?x{j{Z6?j@iZxptbqUAqkmb~H zci^5att3C)7siMr0PFn%x-2aO`R8WB43k0V-gSzZbo4OtFEj<4?j+jVUjfY)m&jLp z9-L8J4@(V|Az*hW5qrKGp2o4^3x6WBaaS33$j6a{@*1ezFU?VPT>_G^g`o4~47P6O zF&b`KAUZmfhIYl1+o^uc{WtTmc#0N`N>@R?vES*6$LVzNWhdjLtP5lOBXL2~Dd2O6 zqQcs9pxQ&6)URqq%l8GSty+u|qbzBKcNCots+2H&l96oN92A}y~@xjw<_%OH< z<0aX=#~XIQbjMAozhXU1eItcMGFh3)M zd3>;jp4ZREaj6Cn9vp~A+7=VPIjMA8RRqfGRHDS0XSnm+T9oHu4PDo?s)4fM`!<4HB@x7cX z-uyygkM(ivZ_6UKqrA9VKeXV*m<-|>ITO=*RWWCpJfNT66lCvm~d1u`QiE)iYb^3WQ z=@a63S?|L&1|4+Og*I{__%KF}ZzO$s%jnMnq}p)J3n-;5(7OKwtLl4EveX5RRt_*t zVR5)HxdQVOE8&~vSmb}!K@zW4k)J#tQ`K*aU|o?1o_}J>ox6vHPNp~TN@)RbOcXI~ z-Wjx1e}=Q~DZ{2@K2-cT2JO4T@Wfhq+@9_X&8>ZycFr4D>rY0$^9m5SeipeF<$;y1 zdvQa2IGxxq8H$}s$#n^Sj@epw3|}e4%lDlQKJ^K}lq>R9r18^NUrwO(n$7q*-w4Od zY9rfw*}&I57VEgqMAK~$KAZQ#amAH%xs?pmt2dLK)AZ>}Ee2d)7ogI|2b9g1!8O*m zhj0CvbZz)pUj7Sf*cI^-WOF0v?zvatm##;&aG5UbTOO<*NerHRD_!IHV^9Ed4bc*&XkKrwU zA48L(BQQlY6Yi)df?0_Om9gUE$)B4JLgDf7dbSlx3M1sY^+94v1|E1g1CAZaLdDuN zT>7;HRjc^8Jh*dUA-T3&kaOEe69o%Z z;iJ3;3`R+CPIs(E+n3R#B%%i-GNOpmf@pv$HO{rK8d$D-1SZmhSjs+2WM>B8vf*cD zA0;)oJuRF#EG_`M&%NZa^a03^E~JK63yCHE5!?#%HctrEQwyW!0bgH*1a50xYp;p~S(a(>%=lrcoC&>V-~+p8f^u9_Mzu0#KR2Sz;PF{4(p z7d2S|V38BUJ$fyPy8rZo$Rs|1teq`%_2pR1zLt#vJQrxD0T`M)2X47v$A%;>G>okP zgTj>{Fn&G6?|4j?rX^#Wi70-l3c{${33OM3HBqBB%n8$8SkjyacPD0J$)z+Z5|s>D zHf3BZQDHh&R)$CW-cl{*ZQFePNAnMPYQm7T9glXuvnYuykv-8Vpr))o-+rT}8P<(rv& z{atF3SwrrWNzksytN2{5kZzWf0DblxJh$Zwat>Vsl`=gD(al3=&9mTOpbhy2d0c~j z0giZr6UeMwPGwbxdBu%vH13|p%rtw7FUuBS{H6%pBHc%BE!>Jx+6z$a#sp65U=tHk zZGemGR5`|%_@QU91LT=3f`${y@aSMI)`fl`#YOk&nVL)trhAe7=o0ObEu5X)e~-J7 zPmcukYT%suJjy4?rrCN$B&VpE78KpYFRxgzOyqVZWh@QloGQADfJj%MZMP 
z4vq|qHIjc#E;U%we@rg_MU>Tkj4ppgsIx>z9Nfv;u>8mT)`+F>V+Y#5!qXtPK+j=< z9O^Uz)*U@vEzM>ByCGczBRykX6B7enU1I}%L;YDgdWJgsCdMN%%TU)u*U(5;U*Fhd z#AU#JWoeCJx;v7#R-;A?|7+*s_jG>R$caD1-`>agv+VE5y3QXY4ToX-4`u%Z`v0u{ zdu%@I59)XRmim9N`ac)=Jrke)hXPB7SMP`R|Hb0}to!?pRsVzTy}zUT8(;sk={5-yKaXL* m_9%;0F Date: Mon, 23 Mar 2026 20:27:11 +0000 Subject: [PATCH 06/31] Remove superfluous HybridNet tests for sub-network internals Co-Authored-By: Claude Opus 4.6 (1M context) --- .../models/conditional_sfno/test_hybrid.py | 54 ------------------- 1 file changed, 54 deletions(-) diff --git a/fme/core/models/conditional_sfno/test_hybrid.py b/fme/core/models/conditional_sfno/test_hybrid.py index acaf5eb26..fa1a61bb8 100644 --- a/fme/core/models/conditional_sfno/test_hybrid.py +++ b/fme/core/models/conditional_sfno/test_hybrid.py @@ -72,34 +72,6 @@ def test_can_call_hybridnet_with_localnet(): assert diag_out.shape == (n_samples, n_diagnostic, *img_shape) -def test_can_call_hybridnet_with_localnet_conv1x1(): - n_forcing = 3 - n_prognostic = 2 - n_diagnostic = 4 - img_shape = (9, 18) - n_samples = 4 - device = get_device() - config = HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), - local=LocalNetConfig( - embed_dim=16, - block_types=["conv1x1", "conv1x1"], - ), - ) - model = get_lat_lon_hybridnet( - params=config, - n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - ).to(device) - forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) - prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) - prog_out, diag_out = model(forcing, prognostic) - assert prog_out.shape == (n_samples, n_prognostic, *img_shape) - assert diag_out.shape == (n_samples, n_diagnostic, *img_shape) - - def test_hybridnet_with_labels(): n_forcing = 3 n_prognostic = 2 @@ -209,32 +181,6 @@ def test_backward_pass(): assert param.grad is not None, f"No gradient for 
{name}" -def test_ankur_disco_encoder(): - """Test HybridNet with AnkurLocalNet using DISCO encoder.""" - n_forcing = 3 - n_prognostic = 2 - n_diagnostic = 4 - img_shape = (9, 18) - n_samples = 2 - device = get_device() - config = HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), - local=AnkurLocalNetConfig(embed_dim=16, use_disco_encoder=True, pos_embed=True), - ) - model = get_lat_lon_hybridnet( - params=config, - n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - ).to(device) - forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) - prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) - prog_out, diag_out = model(forcing, prognostic) - assert prog_out.shape == (n_samples, n_prognostic, *img_shape) - assert diag_out.shape == (n_samples, n_diagnostic, *img_shape) - - def setup_hybridnet(): n_forcing = 3 n_prognostic = 2 From 1c5b252e7bb0fc75c927a5ffa95e8a1f1247db54 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 24 Mar 2026 16:04:25 +0000 Subject: [PATCH 07/31] Add optional secondary network to SingleModuleStep and remove HybridNet Extend SingleModuleStepConfig with secondary_builder, secondary_out_names, and secondary_residual_names fields to support a secondary network that receives the same input as the backbone. The secondary network can predict diagnostic-only variables and/or add residual corrections to the backbone's prognostic outputs. This subsumes the functionality of HybridNet, which is deleted along with its tests and test data. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/core/models/conditional_sfno/hybrid.py | 168 ------------ .../models/conditional_sfno/test_hybrid.py | 224 ---------------- .../test_hybridnet_diagnostic_output.pt | Bin 12248 -> 0 bytes .../test_hybridnet_prognostic_output.pt | Bin 7064 -> 0 bytes fme/core/step/single_module.py | 96 ++++++- fme/core/step/test_step.py | 244 ++++++++++++++++++ 6 files changed, 337 insertions(+), 395 deletions(-) delete mode 100644 fme/core/models/conditional_sfno/hybrid.py delete mode 100644 fme/core/models/conditional_sfno/test_hybrid.py delete mode 100644 fme/core/models/conditional_sfno/testdata/test_hybridnet_diagnostic_output.pt delete mode 100644 fme/core/models/conditional_sfno/testdata/test_hybridnet_prognostic_output.pt diff --git a/fme/core/models/conditional_sfno/hybrid.py b/fme/core/models/conditional_sfno/hybrid.py deleted file mode 100644 index ba96a85c7..000000000 --- a/fme/core/models/conditional_sfno/hybrid.py +++ /dev/null @@ -1,168 +0,0 @@ -import dataclasses -from typing import Literal - -import torch -import torch.nn as nn - -from .ankur import AnkurLocalNetConfig, get_lat_lon_ankur_localnet -from .layers import Context, ContextConfig -from .localnet import LocalNetConfig, get_lat_lon_localnet -from .sfnonet import ( - SFNONetConfig, - SphericalFourierNeuralOperatorNet, - get_lat_lon_sfnonet, -) - -LocalConfig = AnkurLocalNetConfig | LocalNetConfig - - -@dataclasses.dataclass -class HybridNetConfig: - """Configuration for HybridNet. - - Attributes: - backbone: Configuration for the SFNO backbone that produces - prognostic output. - local: Configuration for the local network that produces - diagnostic output. Use ``AnkurLocalNetConfig`` (type="ankur") - or ``LocalNetConfig`` (type="localnet"). - learn_residual: Whether to add the prognostic input directly - to the prognostic output (identity skip connection). - data_grid: Grid type for spherical harmonic transforms used - by the SFNO backbone. 
- """ - - backbone: SFNONetConfig = dataclasses.field(default_factory=SFNONetConfig) - local: LocalConfig = dataclasses.field(default_factory=AnkurLocalNetConfig) - learn_residual: bool = False - data_grid: Literal["legendre-gauss", "equiangular"] = "equiangular" - - -def get_lat_lon_hybridnet( - params: HybridNetConfig, - n_forcing_channels: int, - n_prognostic_channels: int, - n_diagnostic_channels: int, - img_shape: tuple[int, int], - embed_dim_labels: int = 0, -) -> "HybridNet": - """Factory function to build a HybridNet. - - Args: - params: HybridNet configuration. - n_forcing_channels: Number of input-only (forcing) channels. - n_prognostic_channels: Number of input-output (prognostic) channels. - n_diagnostic_channels: Number of output-only (diagnostic) channels. - img_shape: Spatial dimensions (lat, lon) of the input data. - embed_dim_labels: Dimension of label embeddings for conditional - layer normalization. 0 disables label conditioning. - - Returns: - A configured HybridNet instance. 
- """ - n_in = n_forcing_channels + n_prognostic_channels - - context_config = ContextConfig( - embed_dim_scalar=0, - embed_dim_noise=0, - embed_dim_labels=embed_dim_labels, - embed_dim_pos=0, - ) - - backbone = get_lat_lon_sfnonet( - params=params.backbone, - in_chans=n_in, - out_chans=n_prognostic_channels, - img_shape=img_shape, - data_grid=params.data_grid, - context_config=context_config, - ) - - local_config = params.local - if isinstance(local_config, AnkurLocalNetConfig): - local_net: nn.Module = get_lat_lon_ankur_localnet( - params=local_config, - in_chans=n_in, - out_chans=n_diagnostic_channels, - img_shape=img_shape, - data_grid=params.data_grid, - context_config=context_config, - ) - elif isinstance(local_config, LocalNetConfig): - local_net = get_lat_lon_localnet( - params=local_config, - in_chans=n_in, - out_chans=n_diagnostic_channels, - img_shape=img_shape, - data_grid=params.data_grid, - context_config=context_config, - ) - else: - raise ValueError(f"Unknown local config type: {type(local_config)}") - - return HybridNet( - backbone=backbone, - local_net=local_net, - learn_residual=params.learn_residual, - n_prognostic_channels=n_prognostic_channels, - embed_dim_labels=embed_dim_labels, - ) - - -class HybridNet(nn.Module): - """Hybrid network combining an SFNO backbone with a local diagnostic network. - - Analogous to Ankur's ColumnDiagnosticSphericalFourierNeuralOperatorNet, - but using the conditional SFNO and local networks compositionally. - - The SFNO backbone processes the concatenated forcing and prognostic input - to produce a prognostic output. The local network processes the same - concatenated input to produce a diagnostic output. - - Args: - backbone: SFNO network for prognostic prediction. - local_net: Local network for diagnostic prediction. - learn_residual: Whether to add the prognostic input to the - backbone output (identity residual connection). 
- n_prognostic_channels: Number of prognostic channels, used - to slice the input when learn_residual is True. - embed_dim_labels: Dimension of label embeddings. 0 means no labels. - """ - - def __init__( - self, - backbone: SphericalFourierNeuralOperatorNet, - local_net: nn.Module, - learn_residual: bool = False, - n_prognostic_channels: int = 0, - embed_dim_labels: int = 0, - ): - super().__init__() - self.backbone = backbone - self.local_net = local_net - self.learn_residual = learn_residual - self.n_prognostic_channels = n_prognostic_channels - self.embed_dim_labels = embed_dim_labels - - def forward( - self, - forcing: torch.Tensor, - prognostic: torch.Tensor, - labels: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - x = torch.cat([forcing, prognostic], dim=1) - - context = Context( - embedding_scalar=None, - embedding_pos=None, - labels=labels, - noise=None, - ) - - prognostic_out = self.backbone(x, context) - diagnostic_out = self.local_net(x, context) - - if self.learn_residual: - prognostic_out = prognostic_out + prognostic - - return prognostic_out, diagnostic_out diff --git a/fme/core/models/conditional_sfno/test_hybrid.py b/fme/core/models/conditional_sfno/test_hybrid.py deleted file mode 100644 index fa1a61bb8..000000000 --- a/fme/core/models/conditional_sfno/test_hybrid.py +++ /dev/null @@ -1,224 +0,0 @@ -import os - -import pytest -import torch - -from fme.core.device import get_device -from fme.core.testing.regression import validate_tensor - -from .ankur import AnkurLocalNetConfig -from .hybrid import HybridNetConfig, get_lat_lon_hybridnet -from .localnet import LocalNetConfig -from .sfnonet import SFNONetConfig - -DIR = os.path.abspath(os.path.dirname(__file__)) - - -@pytest.mark.parametrize( - "learn_residual", - [True, False], -) -def test_can_call_hybridnet_with_ankur_local(learn_residual: bool): - n_forcing = 3 - n_prognostic = 2 - n_diagnostic = 4 - img_shape = (9, 18) - n_samples = 4 - device = get_device() - config = 
HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), - local=AnkurLocalNetConfig(embed_dim=16), - learn_residual=learn_residual, - ) - model = get_lat_lon_hybridnet( - params=config, - n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - ).to(device) - forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) - prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) - prog_out, diag_out = model(forcing, prognostic) - assert prog_out.shape == (n_samples, n_prognostic, *img_shape) - assert diag_out.shape == (n_samples, n_diagnostic, *img_shape) - - -def test_can_call_hybridnet_with_localnet(): - n_forcing = 3 - n_prognostic = 2 - n_diagnostic = 4 - img_shape = (9, 18) - n_samples = 4 - device = get_device() - config = HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), - local=LocalNetConfig( - embed_dim=16, - block_types=["disco", "disco"], - ), - ) - model = get_lat_lon_hybridnet( - params=config, - n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - ).to(device) - forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) - prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) - prog_out, diag_out = model(forcing, prognostic) - assert prog_out.shape == (n_samples, n_prognostic, *img_shape) - assert diag_out.shape == (n_samples, n_diagnostic, *img_shape) - - -def test_hybridnet_with_labels(): - n_forcing = 3 - n_prognostic = 2 - n_diagnostic = 4 - img_shape = (9, 18) - n_samples = 4 - embed_dim_labels = 5 - device = get_device() - config = HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), - local=AnkurLocalNetConfig(embed_dim=16), - ) - model = get_lat_lon_hybridnet( - params=config, - 
n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - embed_dim_labels=embed_dim_labels, - ).to(device) - forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) - prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) - labels = torch.randn(n_samples, embed_dim_labels, device=device) - prog_out, diag_out = model(forcing, prognostic, labels=labels) - assert prog_out.shape == (n_samples, n_prognostic, *img_shape) - assert diag_out.shape == (n_samples, n_diagnostic, *img_shape) - - -def test_learn_residual_adds_prognostic_input(): - """Verify learn_residual adds the prognostic input to the backbone output.""" - torch.manual_seed(0) - n_forcing = 2 - n_prognostic = 3 - n_diagnostic = 2 - img_shape = (9, 18) - n_samples = 2 - device = get_device() - - config_no_residual = HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), - local=AnkurLocalNetConfig(embed_dim=16), - learn_residual=False, - ) - config_residual = HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), - local=AnkurLocalNetConfig(embed_dim=16), - learn_residual=True, - ) - - model_no = get_lat_lon_hybridnet( - params=config_no_residual, - n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - ).to(device) - model_yes = get_lat_lon_hybridnet( - params=config_residual, - n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - ).to(device) - - # Copy weights from model_no to model_yes - model_yes.load_state_dict(model_no.state_dict()) - - forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) - prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) - - with torch.no_grad(): - prog_no, diag_no = model_no(forcing, 
prognostic) - prog_yes, diag_yes = model_yes(forcing, prognostic) - - # Diagnostic outputs should be identical - torch.testing.assert_close(diag_no, diag_yes) - # Prognostic output with residual = without residual + prognostic input - torch.testing.assert_close(prog_yes, prog_no + prognostic) - - -def test_backward_pass(): - """Test that gradients flow through both sub-networks.""" - n_forcing = 2 - n_prognostic = 3 - n_diagnostic = 2 - img_shape = (9, 18) - n_samples = 2 - device = get_device() - config = HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), - local=AnkurLocalNetConfig(embed_dim=16), - learn_residual=True, - ) - model = get_lat_lon_hybridnet( - params=config, - n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - ).to(device) - forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) - prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) - prog_out, diag_out = model(forcing, prognostic) - loss = prog_out.sum() + diag_out.sum() - loss.backward() - for name, param in model.named_parameters(): - assert param.grad is not None, f"No gradient for {name}" - - -def setup_hybridnet(): - n_forcing = 3 - n_prognostic = 2 - n_diagnostic = 4 - img_shape = (9, 18) - n_samples = 4 - embed_dim_labels = 3 - device = get_device() - config = HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="linear"), - local=AnkurLocalNetConfig(embed_dim=16), - learn_residual=True, - ) - model = get_lat_lon_hybridnet( - params=config, - n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - embed_dim_labels=embed_dim_labels, - ).to(device) - # Initialize on CPU for reproducibility, then move to device - forcing = torch.randn(n_samples, n_forcing, *img_shape).to(device) - prognostic = torch.randn(n_samples, 
n_prognostic, *img_shape).to(device) - labels = torch.randn(n_samples, embed_dim_labels).to(device) - return model, forcing, prognostic, labels - - -def test_hybridnet_output_is_unchanged(): - torch.manual_seed(0) - model, forcing, prognostic, labels = setup_hybridnet() - with torch.no_grad(): - prog_out, diag_out = model(forcing, prognostic, labels=labels) - validate_tensor( - prog_out, - os.path.join(DIR, "testdata/test_hybridnet_prognostic_output.pt"), - ) - validate_tensor( - diag_out, - os.path.join(DIR, "testdata/test_hybridnet_diagnostic_output.pt"), - ) diff --git a/fme/core/models/conditional_sfno/testdata/test_hybridnet_diagnostic_output.pt b/fme/core/models/conditional_sfno/testdata/test_hybridnet_diagnostic_output.pt deleted file mode 100644 index 27205a1e2a057f0c7dcb32f750b99a26d2a25cfc..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12248 zcmbVy2{@MD`tE0*=c3F~<~f<(=aEP%G*HoigiuN(rBnz-g9eo*k(3fCA-=Vgk}?$~ ziIfZ>DG_PF`P%#c|DV10Is5$fb=G&S>wT~HUF%-=_pIky>v^AbFLC7IWw8VWS^r1T zW?8br{I-U91w^b14fNgY7v|*~=X23kwX|D(n^Nw{BZtkgr#m-{!5sp-u`|e zTz~H(JTEBNJIwX(4~0Xx-MKhzj6!(ac{o>}E!#qPomGXyLim97!f zE$0K=j+Y#z9l63n1QL9Ff`fwme8K{QH*XbQ9O~;A>gPK*&?ihdL{Qr$L?|po*d;{7 zV2PuEV8=t@6IVFMf9jHdYc(Sl{I%5#)(3}f@(%L~_Y2+18F-$p#s8o2I0iTVL#^n` zn|*)&{Ga;osm)|5{|D6m%gA2dL4p38H~DQ2n`bNk|6pfk`iH>m?IFF~oWcLpe@{S% zrSaDU)Q@WXY%xMaTZF0SW?3jy5u^HrKgsQqD7MZ;LD~_1ne?BiB6&^6 zi0P9O^0}gq)HK=9O({s-FTW&;K{B*-RF$#~GT50J9|#kEMr zntbXz$&(r)Pw)AXWu8Jbet{afbHt7&`=pbDZzjXE20hh0ljtrlBvVp^=~;DSdUyRrlFlzhgG&=ZC)J9kPnk}Sem9^|>io2#bpy?G zm!mniVoTiXH0i5_Z;Aak5!!6UPv;-x#Z$$W^o!Li@}x_ix%6I+4w@by@n+W4q5BaT zr7T)sc9^Jzj1 zhLDFtpGm4f8M!VbOy6vrN*~oskmYjD@Fc;Ideyj77wZo4{FNd-?tY0>X*?z|<+_xg zcbudNzaVb+XVCV`6$C+^IwUI6N~|I7ZP!V~tXFKoh6d7nd=)6=cv8iq3>j{$AF0yCcmIuuX@|bb`4h$EU2z!Z*QnuQZefrx3I*4i zQvd^kgnN$~L`*pY8Yu~oEZqtz!?bO#%&tgL44V9iJgXUF zmdD?O#}Y%#g4vo7=W`ZT&Huz?yak-Rr3mgNC4kI*X_#}x7L(N)*phB98FNt+Xy2G< zX))#s4)3MV_eeDi_vgd)wkS9##6w+#_mtEY`+;-NcBl~g1$#sqiz^J?Le9=?;xZwQ 
z{X4p?*v2j3DXfng1ES#3odxRZsieqs0;;v@;f7Z{{5)@mXV&cpL6v4pJL?;;e%&N5`FLNfQ?_r;Hg&(<0kwPqz{Ec;yf*Unjnah zO#79n{)z5^if7;n@UZG;gd2;d_#p zTgi>z7G8q@i%r=7h!3YXp98;weAuN_2QtOO@cKd#*zGtB%OA8uR6#XdS#}jx{u%-| z&pu#VC%}zG_3)t96P&n57-+X0(Xl9)*EMofItOQDA&Q=!^LoD#|xNe=MO7w)$nynDyVna$?|1{`@%{ux@+w4NnKl_4lD0yOstg9+n}pk-6}tGRVqIqeD2wu9uZ#p{SM7(LtD3+>=o%dH zOoE=rTOoVTEPOUk41Zd3;m2h>sQa`Zie45$&&e!En*R%C>OO(>`;LNeIUnpaDTlh? zb1=B`3hX^Q86&G3z;?YQHil7n^Ii?Zud-oGDg{KwgmAO}0Wfj2!4YW+m)}>xhsWPQ zMs7YF*!>fHws7OHdJniheGi&RzhV5~5ZI4SMwl}KJAc@~+G&ZPI6M5j?3b`?^<|wpC=y8cRp_{$&~HL4q#+MD%o zYTqX~w_Jc2R~3PZg(f!sJOKr_wqd~?8_3upguGuC;;)Kee1p>1zGgP25*JL&jKIZe zPMFaW3+tCzVYBmZxacH`$xgdbS$P!>elY=3HOwv^wg({3Q-g+j5 z!Uj%QvBC}C7i`1$51cvd^ug61f^g0oHMD!3|MheRKsjz6d}|S805iupEn{VsJ-Z4+LDZh2S(B)Eb;i$Bj3@?fPpew#8Ay`OY}iPl%B7ni`f zdw9^oavjdu?166|Jb+VMUqDNyKXyzx4vjjRSaLxV?H@>BlfOP1=?0^xq8$#i)Ucb| z2G8~Rz~>JNSh!XW2eVrs!h1HZ+&T^If-JB?QJrL|1>$aQWn5hCgP+HI@%(XltQn9- zZMhX7&~A;d6xN~bt%dkLE&@+UY2w2&5u7*bie5uDXx3l|)p1U!!j@*nPg!ANSUU`b zt;CbO+tEA83TG|$Lfw&f5M(_Qy*xEgC&(97GFD+=XBuZccR-bvg(&FejVVs9j9O4b9&S3Yp@Z;`rNQPe;KgWGU444{!4x`DJxDNFivB2mMZAJ7L9TZpwK-HmX18Y$73(%)ajl=|T~nc6 z_D#gUkVQ8Zh|q7xG$@^4Pww=XlMTOI=+HnGGut$mIMz?1pRQ?8Bi9e)qKO;L8NW_e z&Kw}Rdo&4GbuBUf*hS7AbD+FCKas1GZ0P19J}RFr1>C&RB(m3uE|}Rwwks)7vwAJM zz5D}FSgegNzp{zWQ7ziGYASuHltNzH{vBp22r^*Ak95MVDzgPDBt zaPBk0uIMF-T7u-T;{>?^KS*XgA9Wb!qlW6+>C{qXx=F2t)S9J}_r@_qh-;nIFxM1X zXIVwK;`Y)jXV1c=FLHE4r#6-CokYD}8`5o(^T^m9ewwn|omg}xlL4M`GNN&ttmg5c zp~m90P0E%EazC`na-U42r!A&eni_~wax*z`?+TIGWlpV27$W+Whu+o@rt#~w=wN>i zx!L{O>dO39(%jWWmYVRJi#;sbqbyGE?>|a%MwF;~!kH3_`}yppIug{nDv4O2PS_zUXLa;ld7_+24L8}|cmiX~~?ZjAa#%V_zt zm(1%6_aRJI07spg!6!``6~BFex>8?wJ#rVeMH1MzbC}tsRt+m|orYEjWQH zCUq;}`o7zY&39k8u|Nk;C9tsR%P7Quyve@OwGGy-SHq3BE`eOT9o#5QBxM@gf%WVb z98Ww5H(A|`LDqfC=L%xTb{%K$lu*FW^*_M$kTB>?7a+$ry@o+=4JOJ$9Ih8Wfr>wS z#_{(|dFu+W@HhxE9)c+RDxb5SUxmYA3Rn_m3Ka$|%!=JDC7t1^khw`1Kkc{xLYW&F zyNmhE{&}LEFefmYfX za0|*|$MYYAm#hd#tZanzdVlaqJO|UGXFyT<2j;+oQE*RgWXxeZ>@VcRi>qQ7zI#!? 
z7cFab^Jxb&ToMcCRtgc<-)CWC+i`eddmP-CZU9;LaS)%(z?1!-z@|b9HO{OmT?+(Nyera~; z-fBj1i7?rHZ5UQ9Z-N(tc`zv9fFn||a8g?gbF(i%;_^h89URSc46yLn>2h%X6a{ae zN#fN#(zsTq7C!nVf$)bPu){A09;AEWf$VOWE5wZnmlnadY&EO<@=lPWWDL#iLYT~Y z2fr`2fPVfY?BdRcCPozR9LRyX$yG2hdIBzQ%Z8X)DrjUf2~WEJW;Sm=2AdQcAvoLw z1Ow8cWhe{Y1ISm_o zS@?SPWjG1?DE_7aCiUEg*u$40|0Xxi;`$Aj=N^M~7gX@1!Z{FJ+X7w<+3<=bg-_dL zQCGPZOm!_WIlc&{IBKCL&qLT5$b&+c@?g6DLy&ad41VR;p=cosOFN61xD{_f^XOun z_&Eh{P91>C<)@eu;|DO^gbkZlyI{-6ZBQMrg7h|JREgRLXZB1+uF6XIzC9MkuDig# zjg`O`u7cAp_JT=j0XTer2#JPW;5%O&Z6zwfK0+ARtMtIvzHiWx-UiwVmm$DS17$~B zpy-Y=T85WEQfVt3wbj5DI|0lIlf-35Uqcxqj_GaHz;*FEu;gZ|pWF_C>BaD^ zmk(2Jq(M(<9AvEufSvxOaFr8Fix-}S!9-)QUKt5V=Cv@dG6!mZ$HL;`GezTt-Ua3MKx5Gs-jXxB)mQ{2R~h(fj>40A%EUxOpop-I=idkpu%?$PdCEm zY&#qcV^Cwx7M!$56K#WSaahwG$5TZx@3IxHADWC8+fA@>-4dL4bUJ=>Sc^y2y#aQ( zI$l2$gLZY^2<5Bser+y1>tBF@;%+FlZ69Rls$#wSYB-{1fUXW5AEqRj*Mruku}D1!q} z95CpCJ1(s<#+wIkGvaRqaCor`z8nY4oU{~Uvo_$Fl2qn)*b-DcV2_`&gz@`<`RFKJ z&s=Jl3nr)TfSGt4j*aP|zP$nVc}1|bZv~(xvk0%~nBYvN6mSx}&C(4Mx8(K_~5xpr*A3cg&4I)!}iFaF4@V+ZJQYt!b#D z7KfsnICG|U7gq4iLe_d)yco3|BWsUfMtToy_tHjQ_h}ed^$8%t1!GI+VCIbow9HV) zw|eVwSlJHaKdRuNH?eSoF2&;C9wg|(O7i2ADQ(DBq@{a%7)2~3uclVk!5MOi2E)j+90}vUMq+pgYCZw zPiqG|y|tg5sGdT1w^tKGcuHiWU#z9+9@bQ^ zrJQJ$s?n9^qSX8BJbHZf3@SguM=dP4=(ozNB!*|1zSzm=h;OekAt1Z%8he0}ZYgqP>Pgo##QvLlT?`J`lqD&-lFf`>K!q%e$+zVWRiCN9s2YHk$m`?QdTzL2MN z;j`$1r<1AOusA)s;4QmK+=>PSDO0xzpjuIw zNP{8!nJX`Kls75q&XJ)C3)N`FHG7&_<3#+`^r>{&dt#ozqKP-Yuv^smXndSFmDpEJ zEc(*P;rb`UyLCSGRac@X)`u_|oiWVkLksAfPGcHmwt`lxI?#I!_B1u{J(=(+Al@-2 zN#ETP=-8D4@dxfenW_s6I!r+BhDUJX2^TOn{_w555k%hUfgWoT?$W7(E3XD2L42|auvlHIoBDU>z*L8 zuZu~XrHL`W9>QwdS*R922gej%f@7&E78!4Y(+eW4x;9ypfn#EL%}fJ4e|EyVpPAsC zRS!y7H7Aeb+t30I7pE4nyX~;YY5)#yn~&^^ z(ID$5i!~+sDB5_0xoOY`F-sKaLK?&#@#%!<3~Njny9|z8Pnpn4Me=Co38?)dhEs1= zuqFN80y|_fW@Z#HBRdW-!w=69SD7TxE7OIAoW1wh*Gy>l`OJJf^#Ha%A0ina6|qj8 z%`A7h1o2@@A#2(TczJVC@w$M!?3~I2K<}|YbqyD@`Pl~IYJV9v8ERt1q2o-?X2O6A zkXySw*s~Zf+d-IoH91L;s+IS9PV9kITkgosQF1s-E- 
z(CBi;>n>kNV`3W|OL)L=N8hm$?li)rzTY6Xpbzd|8iN@rFTp%m4)0FpgFUQbIGs@g zh243e#)%DiXDXwiODkMiD~kNwzhLqG+u)G;iU~b`1$5QT!A5xn8WgobTgW{0T=4)F z^E`!hH{8J>F9!sdX2YXBS#ZCh3G`Gs`Id`waK213@QS?xp%cY0dPy1s+&_R}nlskL ze*^vO2rxeXma$Y{2<1!mz{ZDZ2zG=Sj3J|>fIFBmCsV{X}1NFQ>;N7w5?Q@H`$hR#FH z+6-`ν6W3ZnVco3M1G2n;zki}zu})Yv+>q7(ya+Wo-13qwb_1_<;XgoyLmkgr+{ z_lt!wEWZ*yKCwmqNk`!(cOyV9KRR9(!bXOL_b+?`|E-y@c8(w>ig01tp*lFtu7y+6 z(kx{)Q^6ri8jo`Fp;O*9Lz5pDerPX%@522s0NUX6^fT}ulR=^D9k4*F2rhoDgTRDb z*cG=Av`2(+ypspT)3q>mpA#g;$YLrdzv#d#kD)2TNbFaEGygn@6cxca2lQ}xw*wkC zwt=dp6PEWmprF1O&dAfm?zi)Bb-XN&TKM8~xp{clWC{8<2cea<5yq_aK>p{xXlM1A zDV4Xv*7HLkQt=bsS*=2w@MUOYvmHMs>_VkfAI#mp2Dj>;ffT6$C=`jpnjRbId>n)` z#Eg-3awcZT93{_ZJ%<|?b@5a7KJ4AR4;!2V@b<;8aP5gUis<)2j`I@iubhRxAH%RT zG8%uVDWG~@Fm9Jzgg)a_v3+zKl&s!@HOr^tYU4H7)T4?^?XB^gj~IUGXYjYG5T1GN zk1M`upj_=)aM`a{FKr7NCdu4cS8d$8QgW>1$S_Axf3IoVUd;%UOeXtitG8X{-psn zU30`D#YMPg!xBs$^h9}6bG*3sF|dcMaig>jCY{)aAyKkuw8;xMLNq=(%8w0c9=J#H z7cp9(h%%|m@kQfyth^$QX(zkj^^${(efD&$$zO>2Zrva^W{kfI{E_)4gi$-U!&2T$ z5Vb-JZ!VaI>Pg;stVsIfLnV_xnD4tG);9RTrc2)yc>o{s6>6 zccbi>BpnuWqar`g5FW84(h@d@?%j8h#D%P(;?aUMU_zIsg!xg8Me@{9;|7_PeVla4 zE77ji4Wwt?X(G6wiP(JPqDR~-iOqfiy4!sUEsD-3QlD8=cYO&tYY|GGwh7W!wGOsV zhXNJndq-Z)yhPT#^B{3=dFX=0_cqZ*jP^PKJf0LHVz2x?H3)<)SlVmTw!~TAKA$@j0 zm7MNwBUZb4sO?!J+S;!}T%s+DcfFFN7I)TAXR{baO=mW>Szb>*oOwhttJLW2-eMBH zQ<5sS-yqgE^y!Gw0_xZ3K;v3pkd7(lUk^)aZ>TiF7Em=}RF5J0JzL_Ku z$ASW)<#>fO9=XO8?^U2*{$;j`yok)x)_~3DSntSXO|3Hq@gZbp+Y3E=W?z%TjdZ%unQo2sd4Op3lVC20TpWR0%_?`#jf}(n^h{`Mn#CLnt_Q)>w_)R67T!3K z3TH#Va?bcV;n~p!*z-;gXP+s9(>10T|H8qNFZK(wE}ah>)68J5jU|*P4M0FpI7)cbERRWp!`b?H zsrx=#cdQ=j)h@xNu4l0J%4c|3KEcdWW#J38cbwQ|4%$uSVTQhJhBL{X%%ZscjMlYC z=7fwl@Iu# z8Bn!rC{ZToKrEu2nY~vA_%zPIwh#e;@eX*>!H3VET7j!)Iun@Fc zgk^+Ch=Ce5&AkIRJ$m56JWafQ!VHR)?l5i6aZs^88y5HwJiCf1oc)j-`uPq%jauUGaY0#<8Fe7 zupgZ6C-C7)A>`&jM0$;k)bvL4uRUK9Oay@X}Uk|9sx0f>8^glJ73WHt>! 
z!uqRVkUJf#xsrj#+v2MgFJO&h3IqnS@Rig|OdKwQgYR=;p413z&F7rmn-yAJ@Vfw6 z4MMo<4IdtS9)V@|BH-}uOyFLw0*{G41`pK09lKK4=GX$}b_>ABhX>CkYQyUr8BEaO zZJ?f21`&-zu=Q&%sJga-N)s0<+lyl9s}xu_7!C7`AHwF>eNgXo8ZJOw%iZr4O26;_H40Tv7wN)kEO4 zN)K;elHW%J<7Df5YW(eYc z1TwY{!MYdV`iCA!n%fIs0w14$V-EGw1!Vv$qi*PPWFWDH=Fe zi4&U~HpP$Y43T%*Bv?=&j< z6<4=8VGAd>m}avKS@VT4T-F6wMC8H!j%g^iWGm7RBV2i!L7OTAjBHwnZ6WKh_PGfj zXbMA{U(1l5T#Zp>ZjgWLJXkD>!*A`vID0OG;^*{mlKEwDIH`jsxdFJ_rT`KjE2Cg< z3XJG}0`2!4jSWXPQ*b{BKZOZk)dCBQaE(Q!jq72Ttu!Wh>*2|X<#>Ov0ebejq3+M8 zu&|0jP!`Aip>N^rNeb+3>d0*H!O9(rFzc=mj&_g1%?dfZ5Iq44?}*^BIa-j+ABKLT zpP+oh7F;4W2^B?8g0yEamKyV8(a}H@I?j)1Vu$8D(=lG|H(0)u1Cv01yzIOYuVEYb zritOjyn}eXARI^2reN#6c1YOd#XL~ihnEdK@%UI25)T$mHh02mK~0pHy$pZlT44K< zAMhCigv03sQYFb-STfyS}fL| z`tOyGS}dKvRzCiZB30wBrH_A;V9pf&!=OIZfB9eQKh0TF|C)ip-^7<&{Xzpd3e=t6 ze}uVSfxh!>C;#6G8P75|u`sr@oMmimG0W7-(<*{G+B;=br{};aKl4>Hm+G_g``U(PK6L3Agb-!Tq~F z_g_K((Fy-d@&7J$y8i*$|8o6*Mg2#oG5_C?oI4l2{{Z#>?#@db`T04$|5lWF^#ALQ w#rmUi|5N_^y60~`%3_WFS7T1Uzr6}K{Cm5<3%>dONS--;1XzFC|Gn-10+338UjP6A diff --git a/fme/core/models/conditional_sfno/testdata/test_hybridnet_prognostic_output.pt b/fme/core/models/conditional_sfno/testdata/test_hybridnet_prognostic_output.pt deleted file mode 100644 index ecb45d27fa330e2489ff9acf69d66af27e13a7ab..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7064 zcmbW62{@J8_s37hL#9NC(1ao(bKbp9l5U|?DotdJ%tN}VWGq7|C6qC#P*Nho+3Pe3 zX+%Xr8c30D8Z`d7_cz>of4}?dKL7PR>v_*P@7~|_J^NkndCvRUwpRQCEEb#1`j42# zGGPU~1qC~LY;g|sboF%$cJvSQ+u-XL6zu8Z=(j1@e^c;GSEpbnZGSIsOLGI(>iJ8la0^bQhr40LneT5-b|Pp0+eVBsf5HX~1Z0 zTPtDq>r=%Z$A^_i()KS`O^0Rn<5kmk_Y3rK3U&-}3k(`AY-w)%|69D_1=k(1R{pvz zYUIn0Bx#Zvwo*8Z``j!xd58+?7-e1k2`$Nvv^hWaB0U~|9=zTx5{`PT-fv($cU zzuoHKmwT!pA`GUi@7gyu5TqWUs9 z^-hXYeB2$*g-%AvcWq>Md>6OHIv>8s{laZ2G=xbHjOk==U2<7$9k$D7!Y{Xzh>znP z$g)i^OOs5aHs24^Q9>{1%#5qdzD5-g@|%o`8^2eJKNIAM+ummC=u%Rqa-8OPsKF$j z3huJ4gR5C5V9wXQBs@+ZhdTJ^-Xt$FvC#(ih-=}Q6o5+eFx)I3Mf;-IJc&4O@Da!X 
zjeN7&OF#M0P6;u3>yrjfSBvBBOb8*nYjdD-$pCXW)(vhZCD5*{GT8gJof|#b3-;Tx zVe5VakiL78D=9w%^0g)~`4{)&K^YG;PAEi^>JFIBBoW7~1em15f-o!rReo!j(Ncyg zCWWL%Oc>KU(?DSyMLN?Qi!SEElAUL0_VOH*s@qG0^iJT{bP3QBD#Af_BsdDzqS=%{ z{Q3GpWG|A3$ZgRevPvFwZ#9Csj}_?ey2*??wUK(R<4KbA+Gct5C%Fw2jAzSb&LFjq%fp0+_K`5vw#WlTOb? zu=WPx_%)A+jn8!ehdD49-9rn674XYiSK4u5J(hf21z{?mF`^(B#ojd0P@!GW`B;T4 zl3N7t%w=%fg&6WqdOfc9btXB&rj+f!he+27lY35@Xnmp(c6EhA$s|D-7##^!v4 zxP}YA$Ke9EyVUi@AVj?2P=QD3usn`Jb3r!vu#}LknShngL#b(CIiqx7Hkc0O!92}9 z_p4i<=Z1iTTpUxo#2H>$l#)8huhjpM5Kgx;1WxYTYL`nVQDY$+jnzdluYNo7>hDzL z9gD(^Eo0$q#CG_6sDP>$beP!}vvJeQbEte?4boPb(NCfg)Zxqouy4q~fz+Dn0qN(= z-3#NPc=JV6jte7+(tPND%g4{&BH02Ve%)CgSbzeg%5fx&kTufx#2(}v?BDb9t z(27qD90JcVQ_V|AyG$Sks6L`HE*xyK?`(4oUb&IS z-R)sTnSxp}CMc4G96m(f<}AXn+%*UbuW`d-Md{~lJMr#Miabt;VAb6*rf{T97xYt| z82RY==>K{aasSN1e*X>F;$VmS8jMhUUlk~YS;E?!SM_Vidneyaw6<~a*&EG7XY$sfW&enc+0G- zG0#R7OB2rnGq#Gn@9}1QUn!#D`U)_-?}GacU1>~>5=nC9=e2v?$4!@yVO2&E_6+ym zI&E*bIYR)#D-Gbrq^&UFnGIsE2%az+gVk&9&~(mwTvHW>DS9Q;#W)GfB2E(b1N}74 zDjWhhDkOG_1Vp5Sp|byTOu0OobBH+&6{Uf2{zexaotZ+XCO*Ln@!HVrmxt98pTe4p zCb+iu0Djq|M_9xJUf6C2ukKt}t&#y18RwXfcAn_IS%@GKy&}qKr>SXY2_s_^Pe*S{1+~ZX z!1(ShXv`{tnH5rCJWZB&f1wikSllGz!~*GUzf@d33dvFtBMjZN8=vX*Q2jA|p#SPR z4%zi^O=P|?vRn(|JA7aJT6^K^%1-dp@}=D_rBET(YnFVZ5UiW7!0rCAjNS?@aM$8M zU+^h>u5Qm|lsA%;f@0XNHxYd#7)o9v|_Z%o%9xJIJ*;9f7>_ zhGyaSKa**?tMHjzEcqhH!!!BM@#LjCDrU436&@?VMzQsTIK-f?W+S!u6vaF%jw9dB z)q;w7JWaPX1kVjMRM>-qk7RbzkLyBdOk6${Y_$UZldI5bxhQzwnLrn4Y=r2Z6s)L= zL-%4SdT^c^OPq^;!Hz^*w_u#h1n}|DB+Hh_@g{9o$F`wvIAnGKrTdieO3XSG zIiZOAKBZu{LW>@=4IMoZ<9=GExwXZnc7f!vJkoE;yAnJ2(+79ArU$ODdz-0KPUorR=3ix zw{=K^(Ou@DWG->WV(Mui%&|Lp4#H#BLAq)wE*bxZ^vBB3r~RqelgNiJ+{t29Pl=wW2%Cu950d-TKAg*-8pbbRv3gT|ct85ijqL3j|0iH-BY zV<8;ev#SH$N~d!t&szu|?irE8SLZ;e-CUSt-^K_k_~S3HpF@v94^tg@0K-bdiSOzz zd?zvnLN{#Wc4{=?p7;>7N(g|7wK3$ahAoMia|z0CmV#St8kEj6zkJ29w`JPCAFDihyugjT${NCmHo@eYec;ZN)#ia9u`qu+2;@axTnR zP!9Mg2j_av!~O+Z$y9YqjBG#3bZU2!pNdm(sznDnSPOAjFFAB_#0HSdsv=XWMq}rq 
zI?Q+&iVs@laC>PKI%zINjhZsBe5=86=j$MHWtrgq^B1!5b_=6jHv`{m%Q4p)OUUxD zMYvO=9Da7ng8i+Iu;FDD*UIiKOz@Euuiug1KlpX2ax_YpLtKWo(MmhI{25)L!@#b+cOmQij?@?^Gf8c2E+Q#h)QQSq@m>oP*!q97Ux7f4n6A z3B{X*;Oz}}Zm+x{&NemU^pfLLfA$vQCA1jq>Ktj;LlMyJIEHgq%5tP0uAm<0hlr5j zHYy&x4+O`)!qT_VC~2z5o!+*Ei5!=SrzhLs1?zC|l8ivMY6Gko{;$%Ms>op%R1iOn zU9exsold#Nqpyx%!`f3{==E9$sIJ|FYC|fh&@)IC?x{j{Z6?j@iZxptbqUAqkmb~H zci^5att3C)7siMr0PFn%x-2aO`R8WB43k0V-gSzZbo4OtFEj<4?j+jVUjfY)m&jLp z9-L8J4@(V|Az*hW5qrKGp2o4^3x6WBaaS33$j6a{@*1ezFU?VPT>_G^g`o4~47P6O zF&b`KAUZmfhIYl1+o^uc{WtTmc#0N`N>@R?vES*6$LVzNWhdjLtP5lOBXL2~Dd2O6 zqQcs9pxQ&6)URqq%l8GSty+u|qbzBKcNCots+2H&l96oN92A}y~@xjw<_%OH< z<0aX=#~XIQbjMAozhXU1eItcMGFh3)M zd3>;jp4ZREaj6Cn9vp~A+7=VPIjMA8RRqfGRHDS0XSnm+T9oHu4PDo?s)4fM`!<4HB@x7cX z-uyygkM(ivZ_6UKqrA9VKeXV*m<-|>ITO=*RWWCpJfNT66lCvm~d1u`QiE)iYb^3WQ z=@a63S?|L&1|4+Og*I{__%KF}ZzO$s%jnMnq}p)J3n-;5(7OKwtLl4EveX5RRt_*t zVR5)HxdQVOE8&~vSmb}!K@zW4k)J#tQ`K*aU|o?1o_}J>ox6vHPNp~TN@)RbOcXI~ z-Wjx1e}=Q~DZ{2@K2-cT2JO4T@Wfhq+@9_X&8>ZycFr4D>rY0$^9m5SeipeF<$;y1 zdvQa2IGxxq8H$}s$#n^Sj@epw3|}e4%lDlQKJ^K}lq>R9r18^NUrwO(n$7q*-w4Od zY9rfw*}&I57VEgqMAK~$KAZQ#amAH%xs?pmt2dLK)AZ>}Ee2d)7ogI|2b9g1!8O*m zhj0CvbZz)pUj7Sf*cI^-WOF0v?zvatm##;&aG5UbTOO<*NerHRD_!IHV^9Ed4bc*&XkKrwU zA48L(BQQlY6Yi)df?0_Om9gUE$)B4JLgDf7dbSlx3M1sY^+94v1|E1g1CAZaLdDuN zT>7;HRjc^8Jh*dUA-T3&kaOEe69o%Z z;iJ3;3`R+CPIs(E+n3R#B%%i-GNOpmf@pv$HO{rK8d$D-1SZmhSjs+2WM>B8vf*cD zA0;)oJuRF#EG_`M&%NZa^a03^E~JK63yCHE5!?#%HctrEQwyW!0bgH*1a50xYp;p~S(a(>%=lrcoC&>V-~+p8f^u9_Mzu0#KR2Sz;PF{4(p z7d2S|V38BUJ$fyPy8rZo$Rs|1teq`%_2pR1zLt#vJQrxD0T`M)2X47v$A%;>G>okP zgTj>{Fn&G6?|4j?rX^#Wi70-l3c{${33OM3HBqBB%n8$8SkjyacPD0J$)z+Z5|s>D zHf3BZQDHh&R)$CW-cl{*ZQFePNAnMPYQm7T9glXuvnYuykv-8Vpr))o-+rT}8P<(rv& z{atF3SwrrWNzksytN2{5kZzWf0DblxJh$Zwat>Vsl`=gD(al3=&9mTOpbhy2d0c~j z0giZr6UeMwPGwbxdBu%vH13|p%rtw7FUuBS{H6%pBHc%BE!>Jx+6z$a#sp65U=tHk zZGemGR5`|%_@QU91LT=3f`${y@aSMI)`fl`#YOk&nVL)trhAe7=o0ObEu5X)e~-J7 zPmcukYT%suJjy4?rrCN$B&VpE78KpYFRxgzOyqVZWh@QloGQADfJj%MZMP 
z4vq|qHIjc#E;U%we@rg_MU>Tkj4ppgsIx>z9Nfv;u>8mT)`+F>V+Y#5!qXtPK+j=< z9O^Uz)*U@vEzM>ByCGczBRykX6B7enU1I}%L;YDgdWJgsCdMN%%TU)u*U(5;U*Fhd z#AU#JWoeCJx;v7#R-;A?|7+*s_jG>R$caD1-`>agv+VE5y3QXY4ToX-4`u%Z`v0u{ zdu%@I59)XRmim9N`ac)=Jrke)hXPB7SMP`R|Hb0}to!?pRsVzTy}zUT8(;sk={5-yKaXL* m_9%;0F int: @@ -151,12 +200,16 @@ def diagnostic_names(self) -> list[str]: @property def output_names(self) -> list[str]: - secondary_names = ( + secondary_decoder_names = ( self.secondary_decoder.secondary_diagnostic_names if self.secondary_decoder is not None else [] ) - return list(set(self.out_names).union(secondary_names)) + return list( + set(self.out_names) + .union(secondary_decoder_names) + .union(self.secondary_out_names) + ) @property def next_step_input_names(self) -> list[str]: @@ -267,6 +320,20 @@ def __init__( dist = Distributed.get_instance() + if config.secondary_builder is not None: + secondary_module = config.secondary_builder.build( + n_in_channels=n_in_channels, + n_out_channels=len(config.secondary_out_names), + dataset_info=dataset_info, + ) + self.secondary_module: Module | None = secondary_module.to(get_device()) + self.secondary_out_packer: Packer | None = Packer( + config.secondary_out_names + ) + else: + self.secondary_module = None + self.secondary_out_packer = None + if config.secondary_decoder is not None: self.secondary_decoder: SecondaryDecoder | NoSecondaryDecoder = ( config.secondary_decoder.build( @@ -282,6 +349,8 @@ def __init__( self._no_optimization = NullOptimization() self.module = self.module.wrap_module(dist.wrap_module) + if self.secondary_module is not None: + self.secondary_module = self.secondary_module.wrap_module(dist.wrap_module) self.secondary_decoder = self.secondary_decoder.wrap_module(dist.wrap_module) self._timestep = dataset_info.timestep @@ -329,6 +398,8 @@ def modules(self) -> nn.ModuleList: A list of modules being trained. 
""" modules = [self.module.torch_module] + if self.secondary_module is not None: + modules.append(self.secondary_module.torch_module) modules.extend(self.secondary_decoder.torch_modules) return nn.ModuleList(modules) @@ -355,6 +426,21 @@ def network_call(input_norm: TensorDict) -> TensorDict: labels=args.labels, ) output_dict = self.out_packer.unpack(output_tensor, axis=self.CHANNEL_DIM) + if self.secondary_module is not None: + if self.secondary_out_packer is None: + raise RuntimeError("secondary_out_packer is unexpectedly None") + secondary_tensor = self.secondary_module.wrap_module(wrapper)( + input_tensor, + labels=args.labels, + ) + secondary_dict = self.secondary_out_packer.unpack( + secondary_tensor, axis=self.CHANNEL_DIM + ) + for name in self._config.secondary_residual_names: + output_dict[name] = output_dict[name] + secondary_dict[name] + for name in self._config.secondary_out_names: + if name not in self._config.secondary_residual_names: + output_dict[name] = secondary_dict[name] secondary_output_dict = self.secondary_decoder.wrap_module(wrapper)( output_tensor.detach() # detach avoids changing base outputs ) @@ -385,6 +471,8 @@ def get_state(self): "module": self.module.get_state(), "secondary_decoder": self.secondary_decoder.get_module_state(), } + if self.secondary_module is not None: + state["secondary_module"] = self.secondary_module.get_state() return state def load_state(self, state: dict[str, Any]) -> None: @@ -399,6 +487,8 @@ def load_state(self, state: dict[str, Any]) -> None: # for backwards compatibility with old checkpoints del module["module.device_buffer"] self.module.load_state(module) + if "secondary_module" in state and self.secondary_module is not None: + self.secondary_module.load_state(state["secondary_module"]) if "secondary_decoder" in state: self.secondary_decoder.load_module_state(state["secondary_decoder"]) diff --git a/fme/core/step/test_step.py b/fme/core/step/test_step.py index 5d98df1cf..8b281265f 100644 --- 
a/fme/core/step/test_step.py +++ b/fme/core/step/test_step.py @@ -374,6 +374,41 @@ def get_fcn3_selector( ) +def get_single_module_with_secondary_selector( + dir: pathlib.Path | None = None, +) -> StepSelector: + normalization = get_network_and_loss_normalization_config( + names=[ + "forcing_a", + "prog_a", + "prog_b", + "diag_a", + ], + dir=dir, + ) + return StepSelector( + type="single_module", + config=dataclasses.asdict( + SingleModuleStepConfig( + builder=ModuleSelector( + type="SphericalFourierNeuralOperatorNet", + config={ + "scale_factor": 1, + "embed_dim": 4, + "num_layers": 2, + }, + ), + in_names=["forcing_a", "prog_a", "prog_b"], + out_names=["prog_a", "prog_b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=["prog_a", "diag_a"], + secondary_residual_names=["prog_a"], + ), + ), + ) + + def get_multi_call_selector( dir: pathlib.Path | None = None, ) -> StepSelector: @@ -403,6 +438,7 @@ def get_multi_call_selector( get_separate_radiation_selector, get_single_module_selector, get_single_module_noise_conditioned_selector, + get_single_module_with_secondary_selector, get_multi_call_selector, ] @@ -666,6 +702,214 @@ def test_input_output_names_secondary_decoder_conflict(conflict: str): assert f"secondary_diagnostic_name is an {conflict} variable:" in str(err.value) +def test_secondary_builder_none_with_out_names_raises(): + normalization = get_network_and_loss_normalization_config( + names=["a", "b"], + ) + with pytest.raises(ValueError, match="secondary_out_names must be empty"): + SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_out_names=["c"], + ) + + +def test_secondary_builder_none_with_residual_names_raises(): + normalization = get_network_and_loss_normalization_config( + names=["a", "b"], + ) + with pytest.raises(ValueError, match="secondary_residual_names must be empty"): + 
SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_residual_names=["b"], + ) + + +def test_secondary_builder_with_empty_out_names_raises(): + normalization = get_network_and_loss_normalization_config( + names=["a", "b"], + ) + with pytest.raises(ValueError, match="secondary_out_names must not be empty"): + SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=[], + ) + + +def test_secondary_residual_name_not_in_secondary_out_names_raises(): + normalization = get_network_and_loss_normalization_config( + names=["a", "b", "c"], + ) + with pytest.raises(ValueError, match="secondary_residual_name 'b'.*secondary_out"): + SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=["c"], + secondary_residual_names=["b"], + ) + + +def test_secondary_residual_name_not_in_out_names_raises(): + normalization = get_network_and_loss_normalization_config( + names=["a", "b", "c"], + ) + with pytest.raises(ValueError, match="secondary_residual_name 'c'.*out_names"): + SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=["c"], + secondary_residual_names=["c"], + ) + + +def test_secondary_out_name_overlaps_out_name_without_residual_raises(): + normalization = get_network_and_loss_normalization_config( + names=["a", "b", "c"], + ) + with pytest.raises(ValueError, match="secondary_residual_names"): + SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + 
out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=["b", "c"], + secondary_residual_names=[], + ) + + +def test_secondary_out_name_overlaps_secondary_decoder_raises(): + normalization = get_network_and_loss_normalization_config( + names=["a", "b", "c"], + ) + with pytest.raises(ValueError, match="secondary_diagnostic_name is an output"): + SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=["c"], + secondary_decoder=SecondaryDecoderConfig( + secondary_diagnostic_names=["c"], + network=ModuleSelector(type="MLP", config={}), + ), + ) + + +@pytest.mark.parallel +def test_secondary_network_residual_adds_to_backbone(): + """Test that secondary_residual_names outputs are added to backbone outputs.""" + torch.manual_seed(0) + normalization = get_network_and_loss_normalization_config( + names=["forcing", "prog", "diag"], + ) + config = StepSelector( + type="single_module", + config=dataclasses.asdict( + SingleModuleStepConfig( + builder=ModuleSelector( + type="SphericalFourierNeuralOperatorNet", + config={ + "scale_factor": 1, + "embed_dim": 4, + "num_layers": 2, + }, + ), + in_names=["forcing", "prog"], + out_names=["prog"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=["prog", "diag"], + secondary_residual_names=["prog"], + ), + ), + ) + img_shape = DEFAULT_IMG_SHAPE + step = get_step(config, img_shape) + assert "prog" in step.output_names + assert "diag" in step.output_names + assert "prog" in step.prognostic_names + input_data = get_tensor_dict(step.input_names, img_shape, n_samples=2) + next_step_input_data = get_tensor_dict( + step.next_step_input_names, img_shape, n_samples=2 + ) + output = step.step( + args=StepArgs( + input=input_data, 
next_step_input_data=next_step_input_data, labels=None + ), + ) + assert "prog" in output + assert "diag" in output + assert output["prog"].shape == (2, *img_shape) + assert output["diag"].shape == (2, *img_shape) + + +@pytest.mark.parallel +def test_secondary_network_state_round_trip(): + """Test get_state/load_state with secondary module.""" + torch.manual_seed(0) + normalization = get_network_and_loss_normalization_config( + names=["forcing", "prog", "diag"], + ) + config = SingleModuleStepConfig( + builder=ModuleSelector( + type="SphericalFourierNeuralOperatorNet", + config={ + "scale_factor": 1, + "embed_dim": 4, + "num_layers": 2, + }, + ), + in_names=["forcing", "prog"], + out_names=["prog"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=["prog", "diag"], + secondary_residual_names=["prog"], + ) + img_shape = DEFAULT_IMG_SHAPE + step1 = get_step( + StepSelector(type="single_module", config=dataclasses.asdict(config)), + img_shape, + ) + state = step1.get_state() + assert "secondary_module" in state + + step2 = get_step( + StepSelector(type="single_module", config=dataclasses.asdict(config)), + img_shape, + ) + step2.load_state(state) + + input_data = get_tensor_dict(step1.input_names, img_shape, n_samples=1) + next_step_input_data = get_tensor_dict( + step1.next_step_input_names, img_shape, n_samples=1 + ) + args = StepArgs( + input=input_data, next_step_input_data=next_step_input_data, labels=None + ) + out1 = step1.step(args=args) + out2 = step2.step(args=args) + for name in out1: + torch.testing.assert_close(out1[name], out2[name]) + + def test_step_with_prescribed_prognostic_overwrites_output(): normalization = get_network_and_loss_normalization_config( names=["forcing_shared", "forcing_rad", "diagnostic_main", "diagnostic_rad"], From 473a53dcfd4bc1b9d93c4044480504577c7bab56 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 24 Mar 2026 16:15:56 +0000 Subject: [PATCH 08/31] Add 
AnkurLocalNet and LocalNet to the module registry Register AnkurLocalNet (simple context-wrapped MLP) and LocalNet (noise-conditioned local neural operator) as ModuleSelector builders. LocalNet uses a NoiseConditionedModule wrapper with gaussian noise and optional positional/label conditioning, and is added to CONDITIONAL_BUILDERS to support label-conditioned training. Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/ace/registry/__init__.py | 3 +- fme/ace/registry/local_net.py | 282 +++++++++++++++++++++++++++++ fme/ace/registry/test_local_net.py | 197 ++++++++++++++++++++ fme/core/registry/module.py | 1 + 4 files changed, 482 insertions(+), 1 deletion(-) create mode 100644 fme/ace/registry/local_net.py create mode 100644 fme/ace/registry/test_local_net.py diff --git a/fme/ace/registry/__init__.py b/fme/ace/registry/__init__.py index 739d6e4bc..814a928db 100644 --- a/fme/ace/registry/__init__.py +++ b/fme/ace/registry/__init__.py @@ -2,10 +2,11 @@ from fme.core.models import mlp as _mlp from . import land_net as _landnet +from . import local_net as _localnet from . import m2lines as _m2lines from . import prebuilt as _prebuilt from . import sfno as _sfno from . 
import stochastic_sfno as _sfno_crps from .registry import ModuleSelector -del _prebuilt, _sfno, _m2lines, _landnet, _sfno_crps, _mlp +del _prebuilt, _sfno, _m2lines, _landnet, _localnet, _sfno_crps, _mlp diff --git a/fme/ace/registry/local_net.py b/fme/ace/registry/local_net.py new file mode 100644 index 000000000..cf0ca9472 --- /dev/null +++ b/fme/ace/registry/local_net.py @@ -0,0 +1,282 @@ +import dataclasses +from typing import Literal + +import torch +from torch import nn + +from fme.ace.registry.registry import ModuleConfig, ModuleSelector +from fme.core.dataset_info import DatasetInfo +from fme.core.distributed import Distributed +from fme.core.models.conditional_sfno.ankur import ( + AnkurLocalNetConfig, + get_lat_lon_ankur_localnet, +) +from fme.core.models.conditional_sfno.layers import Context, ContextConfig +from fme.core.models.conditional_sfno.localnet import ( + BlockType, + LocalNetConfig, + get_lat_lon_localnet, +) + + +class _ContextWrappedModule(nn.Module): + """Wraps a module that takes (x, context: Context) to accept (x, labels=None). + + This adapts the conditional_sfno forward signature to the interface + expected by the Module registry wrapper. + """ + + def __init__(self, module: nn.Module): + super().__init__() + self.module = module + + def forward( + self, x: torch.Tensor, labels: torch.Tensor | None = None + ) -> torch.Tensor: + context = Context( + embedding_scalar=None, + embedding_pos=None, + labels=labels, + noise=None, + ) + return self.module(x, context) + + +class NoiseConditionedModule(nn.Module): + """Wraps a context-based module with gaussian noise conditioning. + + Generates gaussian noise and optional positional embeddings (with + label-position interaction), then calls the wrapped module with a + fully populated Context. + + Args: + module: An nn.Module with forward signature (x, context: Context). + img_shape: Global spatial dimensions (lat, lon) of the input data. + embed_dim_noise: Dimension of gaussian noise channels. 
+ embed_dim_pos: Dimension of learned positional embedding. 0 disables. + embed_dim_labels: Dimension of label embeddings. 0 disables. + """ + + def __init__( + self, + module: nn.Module, + img_shape: tuple[int, int], + embed_dim_noise: int = 256, + embed_dim_pos: int = 0, + embed_dim_labels: int = 0, + ): + super().__init__() + self.module = module + self.embed_dim_noise = embed_dim_noise + self.img_shape = img_shape + self.label_pos_embed: nn.Parameter | None = None + if embed_dim_pos != 0: + self.pos_embed: nn.Parameter | None = nn.Parameter( + torch.zeros( + 1, embed_dim_pos, img_shape[0], img_shape[1], requires_grad=True + ) + ) + nn.init.trunc_normal_(self.pos_embed, std=0.02) + if embed_dim_labels > 0: + self.label_pos_embed = nn.Parameter( + torch.zeros( + embed_dim_labels, + embed_dim_pos, + img_shape[0], + img_shape[1], + requires_grad=True, + ) + ) + nn.init.trunc_normal_(self.label_pos_embed, std=0.02) + else: + self.pos_embed = None + + def forward( + self, x: torch.Tensor, labels: torch.Tensor | None = None + ) -> torch.Tensor: + x = x.reshape(-1, *x.shape[-3:]) + noise = torch.randn( + [x.shape[0], self.embed_dim_noise, *x.shape[-2:]], + device=x.device, + dtype=x.dtype, + ) + + h_slice, w_slice = Distributed.get_instance().get_local_slices(self.img_shape) + + embedding_pos: torch.Tensor | None = None + if self.pos_embed is not None: + pos_local = self.pos_embed[..., h_slice, w_slice] + embedding_pos = pos_local.repeat(x.shape[0], 1, 1, 1) + if self.label_pos_embed is not None and labels is not None: + label_local = self.label_pos_embed[..., h_slice, w_slice] + label_embedding_pos = torch.einsum( + "bl, lpxy -> bpxy", labels, label_local + ) + embedding_pos = embedding_pos + label_embedding_pos + + return self.module( + x, + Context( + embedding_scalar=None, + embedding_pos=embedding_pos, + labels=labels, + noise=noise, + ), + ) + + +@ModuleSelector.register("AnkurLocalNet") +@dataclasses.dataclass +class AnkurLocalNetBuilder(ModuleConfig): + 
"""Configuration for the AnkurLocalNet architecture. + + A simple 3-hidden-layer MLP that optionally uses a DISCO convolution + for the first layer and a learned positional embedding. + + Attributes: + embed_dim: Dimension of the hidden layers. + use_disco_encoder: Whether to use a DISCO convolution for the first + layer instead of a 1x1 convolution. + disco_kernel_size: Kernel size for the DISCO convolution if used. + pos_embed: Whether to add a learned positional embedding after the + first layer. + activation_function: Activation function name ('relu', 'gelu', 'silu'). + data_grid: Grid type for spherical harmonic transforms used by + DISCO convolutions. + """ + + embed_dim: int = 256 + use_disco_encoder: bool = False + disco_kernel_size: int = 3 + pos_embed: bool = False + activation_function: str = "gelu" + data_grid: Literal["legendre-gauss", "equiangular"] = "equiangular" + + def build( + self, + n_in_channels: int, + n_out_channels: int, + dataset_info: DatasetInfo, + ) -> nn.Module: + params = AnkurLocalNetConfig( + embed_dim=self.embed_dim, + use_disco_encoder=self.use_disco_encoder, + disco_kernel_size=self.disco_kernel_size, + pos_embed=self.pos_embed, + activation_function=self.activation_function, + ) + context_config = ContextConfig( + embed_dim_scalar=0, + embed_dim_noise=0, + embed_dim_labels=len(dataset_info.all_labels), + embed_dim_pos=0, + ) + net = get_lat_lon_ankur_localnet( + params=params, + in_chans=n_in_channels, + out_chans=n_out_channels, + img_shape=dataset_info.img_shape, + data_grid=self.data_grid, + context_config=context_config, + ) + return _ContextWrappedModule(net) + + +@ModuleSelector.register("LocalNet") +@dataclasses.dataclass +class LocalNetBuilder(ModuleConfig): + """Configuration for the LocalNet architecture. + + A noise-conditioned local neural operator network using DISCO convolutions + and/or 1x1 convolutions, with encoder/decoder structure and optional skip + connections. 
Supports label conditioning when used with conditional=True + on the ModuleSelector. + + Attributes: + embed_dim: Dimension of the embeddings. + noise_embed_dim: Dimension of the gaussian noise conditioning channels. + context_pos_embed_dim: Dimension of the learned positional embedding + used for conditioning. 0 disables. + block_types: List of filter types for each block ('disco', 'conv1x1'). + The length determines the number of blocks. + global_layer_norm: Whether to reduce along the spatial domain when + applying layer normalization. + use_mlp: Whether to use an MLP in each block. + mlp_ratio: Ratio of MLP hidden dimension to the embedding dimension. + activation_function: Activation function name ('relu', 'gelu', 'silu'). + encoder_layers: Number of convolutional layers in the encoder/decoder. + pos_embed: Whether to use a learned positional embedding inside the + LocalNet (distinct from context_pos_embed_dim). + big_skip: Whether to use a big skip connection from input to decoder. + normalize_big_skip: Whether to normalize the big skip connection. + affine_norms: Whether to use element-wise affine parameters in the + normalization layers. + lora_rank: Rank of LoRA adaptations. 0 disables LoRA. + lora_alpha: Strength of LoRA adaptations. Defaults to lora_rank + if None. + data_grid: Grid type for spherical harmonic transforms used by + DISCO convolutions. 
+ """ + + embed_dim: int = 256 + noise_embed_dim: int = 256 + context_pos_embed_dim: int = 0 + block_types: list[BlockType] = dataclasses.field( + default_factory=lambda: ["disco"] * 12 + ) + global_layer_norm: bool = False + use_mlp: bool = True + mlp_ratio: float = 2.0 + activation_function: str = "gelu" + encoder_layers: int = 1 + pos_embed: bool = True + big_skip: bool = True + normalize_big_skip: bool = False + affine_norms: bool = False + lora_rank: int = 0 + lora_alpha: float | None = None + data_grid: Literal["legendre-gauss", "equiangular"] = "equiangular" + + def build( + self, + n_in_channels: int, + n_out_channels: int, + dataset_info: DatasetInfo, + ) -> nn.Module: + params = LocalNetConfig( + embed_dim=self.embed_dim, + block_types=self.block_types, + global_layer_norm=self.global_layer_norm, + use_mlp=self.use_mlp, + mlp_ratio=self.mlp_ratio, + activation_function=self.activation_function, + encoder_layers=self.encoder_layers, + pos_embed=self.pos_embed, + big_skip=self.big_skip, + normalize_big_skip=self.normalize_big_skip, + affine_norms=self.affine_norms, + lora_rank=self.lora_rank, + lora_alpha=self.lora_alpha, + ) + embed_dim_labels = len(dataset_info.all_labels) + context_config = ContextConfig( + embed_dim_scalar=0, + embed_dim_noise=self.noise_embed_dim, + embed_dim_labels=embed_dim_labels, + embed_dim_pos=self.context_pos_embed_dim, + ) + net = get_lat_lon_localnet( + params=params, + in_chans=n_in_channels, + out_chans=n_out_channels, + img_shape=dataset_info.img_shape, + data_grid=self.data_grid, + context_config=context_config, + ) + return NoiseConditionedModule( + net, + img_shape=dataset_info.img_shape, + embed_dim_noise=self.noise_embed_dim, + embed_dim_pos=self.context_pos_embed_dim, + embed_dim_labels=embed_dim_labels, + ) diff --git a/fme/ace/registry/test_local_net.py b/fme/ace/registry/test_local_net.py new file mode 100644 index 000000000..7dfae35d2 --- /dev/null +++ b/fme/ace/registry/test_local_net.py @@ -0,0 +1,197 @@ +import 
dataclasses + +import pytest +import torch +from torch import nn + +import fme +from fme.ace.registry.local_net import AnkurLocalNetBuilder, LocalNetBuilder +from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates +from fme.core.dataset_info import DatasetInfo +from fme.core.labels import BatchLabels +from fme.core.models.conditional_sfno.localnet import BlockType +from fme.core.registry import ModuleSelector + +IMG_SHAPE = (9, 18) + + +def _get_dataset_info(all_labels: set[str] | None = None) -> DatasetInfo: + device = fme.get_device() + return DatasetInfo( + horizontal_coordinates=LatLonCoordinates( + lat=torch.zeros(IMG_SHAPE[0], device=device), + lon=torch.zeros(IMG_SHAPE[1], device=device), + ), + vertical_coordinate=HybridSigmaPressureCoordinate( + ak=torch.arange(7, device=device), + bk=torch.arange(7, device=device), + ), + all_labels=all_labels, + ) + + +def test_ankur_local_net_is_registered(): + assert "AnkurLocalNet" in ModuleSelector.get_available_types() + + +def test_local_net_is_registered(): + assert "LocalNet" in ModuleSelector.get_available_types() + + +@pytest.mark.parametrize("use_disco_encoder", [True, False]) +@pytest.mark.parametrize("pos_embed", [True, False]) +def test_ankur_local_net_build_and_forward(use_disco_encoder: bool, pos_embed: bool): + n_in, n_out = 3, 2 + dataset_info = _get_dataset_info() + builder = AnkurLocalNetBuilder( + embed_dim=16, + use_disco_encoder=use_disco_encoder, + pos_embed=pos_embed, + ) + module = builder.build(n_in, n_out, dataset_info).to(fme.get_device()) + x = torch.randn(2, n_in, *IMG_SHAPE, device=fme.get_device()) + out = module(x) + assert out.shape == (2, n_out, *IMG_SHAPE) + + +def test_ankur_local_net_via_selector(): + selector = ModuleSelector( + type="AnkurLocalNet", + config=dataclasses.asdict(AnkurLocalNetBuilder(embed_dim=16)), + ) + dataset_info = _get_dataset_info() + module = selector.build( + n_in_channels=3, n_out_channels=2, dataset_info=dataset_info + ) + module 
= module.to(fme.get_device()) + x = torch.randn(2, 3, *IMG_SHAPE, device=fme.get_device()) + out = module(x) + assert out.shape == (2, 2, *IMG_SHAPE) + + +@pytest.mark.parametrize( + "block_types", + [ + ["disco", "disco"], + ["conv1x1", "conv1x1"], + ["disco", "conv1x1"], + ], +) +def test_local_net_build_and_forward(block_types: list[BlockType]): + n_in, n_out = 3, 2 + dataset_info = _get_dataset_info() + builder = LocalNetBuilder( + embed_dim=16, + block_types=block_types, + ) + module = builder.build(n_in, n_out, dataset_info).to(fme.get_device()) + x = torch.randn(2, n_in, *IMG_SHAPE, device=fme.get_device()) + out = module(x) + assert out.shape == (2, n_out, *IMG_SHAPE) + + +def test_local_net_via_selector(): + selector = ModuleSelector( + type="LocalNet", + config=dataclasses.asdict( + LocalNetBuilder(embed_dim=16, block_types=["disco", "disco"]) + ), + ) + dataset_info = _get_dataset_info() + module = selector.build( + n_in_channels=3, n_out_channels=2, dataset_info=dataset_info + ) + module = module.to(fme.get_device()) + x = torch.randn(2, 3, *IMG_SHAPE, device=fme.get_device()) + out = module(x) + assert out.shape == (2, 2, *IMG_SHAPE) + + +def test_ankur_local_net_backward(): + n_in, n_out = 3, 2 + dataset_info = _get_dataset_info() + builder = AnkurLocalNetBuilder(embed_dim=16) + module = builder.build(n_in, n_out, dataset_info).to(fme.get_device()) + x = torch.randn(2, n_in, *IMG_SHAPE, device=fme.get_device()) + out = module(x) + out.sum().backward() + for name, param in module.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + + +def test_local_net_backward(): + n_in, n_out = 3, 2 + dataset_info = _get_dataset_info() + builder = LocalNetBuilder(embed_dim=16, block_types=["disco", "disco"]) + module = builder.build(n_in, n_out, dataset_info).to(fme.get_device()) + x = torch.randn(2, n_in, *IMG_SHAPE, device=fme.get_device()) + out = module(x) + out.sum().backward() + for name, param in module.named_parameters(): + assert 
param.grad is not None, f"No gradient for {name}" + + +def test_local_net_noise_produces_stochastic_output(): + """Noise conditioning should produce different outputs after training.""" + torch.manual_seed(0) + n_in, n_out = 3, 2 + dataset_info = _get_dataset_info() + builder = LocalNetBuilder( + embed_dim=16, + noise_embed_dim=8, + block_types=["disco", "disco"], + affine_norms=True, + ) + module = builder.build(n_in, n_out, dataset_info).to(fme.get_device()) + # At init, noise scale/bias weights are zero so noise has no effect. + # Set them to nonzero to verify the noise path works. + for name, param in module.named_parameters(): + if "W_scale_2d" in name or "W_bias_2d" in name: + nn.init.normal_(param, std=1.0) + x = torch.randn(2, n_in, *IMG_SHAPE, device=fme.get_device()) + with torch.no_grad(): + out1 = module(x) + out2 = module(x) + assert not torch.allclose(out1, out2) + + +def test_local_net_with_context_pos_embed(): + n_in, n_out = 3, 2 + dataset_info = _get_dataset_info() + builder = LocalNetBuilder( + embed_dim=16, + noise_embed_dim=8, + context_pos_embed_dim=4, + block_types=["disco", "disco"], + ) + module = builder.build(n_in, n_out, dataset_info).to(fme.get_device()) + x = torch.randn(2, n_in, *IMG_SHAPE, device=fme.get_device()) + out = module(x) + assert out.shape == (2, n_out, *IMG_SHAPE) + + +def test_local_net_conditional_with_labels(): + """LocalNet with conditional=True should accept and use labels.""" + n_in, n_out = 3, 2 + all_labels = {"label_a", "label_b"} + dataset_info = _get_dataset_info(all_labels=all_labels) + selector = ModuleSelector( + type="LocalNet", + conditional=True, + config=dataclasses.asdict( + LocalNetBuilder( + embed_dim=16, + noise_embed_dim=8, + block_types=["disco", "disco"], + affine_norms=True, + ) + ), + ) + module = selector.build( + n_in_channels=n_in, n_out_channels=n_out, dataset_info=dataset_info + ) + module = module.to(fme.get_device()) + x = torch.randn(2, n_in, *IMG_SHAPE, device=fme.get_device()) + labels 
= BatchLabels.new_from_set(all_labels, n_samples=2, device=fme.get_device()) + out = module(x, labels=labels) + assert out.shape == (2, n_out, *IMG_SHAPE) diff --git a/fme/core/registry/module.py b/fme/core/registry/module.py index 5b229a570..8064db886 100644 --- a/fme/core/registry/module.py +++ b/fme/core/registry/module.py @@ -60,6 +60,7 @@ def from_state(cls, state: Mapping[str, Any]) -> "ModuleConfig": CONDITIONAL_BUILDERS = [ "NoiseConditionedSFNO", + "LocalNet", ] From 703f13857f513d52248698e7892c641a555ee8c3 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 24 Mar 2026 16:17:38 +0000 Subject: [PATCH 09/31] Use training step instead of parameter surgery in noise stochasticity test Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/ace/registry/test_local_net.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/fme/ace/registry/test_local_net.py b/fme/ace/registry/test_local_net.py index 7dfae35d2..c3cf78191 100644 --- a/fme/ace/registry/test_local_net.py +++ b/fme/ace/registry/test_local_net.py @@ -2,7 +2,6 @@ import pytest import torch -from torch import nn import fme from fme.ace.registry.local_net import AnkurLocalNetBuilder, LocalNetBuilder @@ -143,12 +142,13 @@ def test_local_net_noise_produces_stochastic_output(): affine_norms=True, ) module = builder.build(n_in, n_out, dataset_info).to(fme.get_device()) - # At init, noise scale/bias weights are zero so noise has no effect. - # Set them to nonzero to verify the noise path works. - for name, param in module.named_parameters(): - if "W_scale_2d" in name or "W_bias_2d" in name: - nn.init.normal_(param, std=1.0) x = torch.randn(2, n_in, *IMG_SHAPE, device=fme.get_device()) + # At init, noise scale/bias weights are zero so noise has no effect. + # A training step makes them nonzero, enabling stochastic output. 
+ loss = module(x).sum() + loss.backward() + optimizer = torch.optim.SGD(module.parameters(), lr=1.0) + optimizer.step() with torch.no_grad(): out1 = module(x) out2 = module(x) From 48c7c5be1bb6713fe3496f2e496875d42bab4f64 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 24 Mar 2026 17:04:56 +0000 Subject: [PATCH 10/31] Extract shared NoiseConditionedModule base class Move the noise conditioning, positional embedding, and label-position interaction logic into a shared NoiseConditionedModule in noise_conditioned.py. NoiseConditionedSFNO now subclasses it, overriding _generate_noise to support isotropic noise via SHT. LocalNet uses the base class directly with gaussian noise. Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/ace/registry/local_net.py | 85 +-------------------- fme/ace/registry/noise_conditioned.py | 97 ++++++++++++++++++++++++ fme/ace/registry/stochastic_sfno.py | 89 ++++++---------------- fme/ace/registry/test_stochastic_sfno.py | 7 +- 4 files changed, 122 insertions(+), 156 deletions(-) create mode 100644 fme/ace/registry/noise_conditioned.py diff --git a/fme/ace/registry/local_net.py b/fme/ace/registry/local_net.py index cf0ca9472..130f015d2 100644 --- a/fme/ace/registry/local_net.py +++ b/fme/ace/registry/local_net.py @@ -4,9 +4,9 @@ import torch from torch import nn +from fme.ace.registry.noise_conditioned import NoiseConditionedModule from fme.ace.registry.registry import ModuleConfig, ModuleSelector from fme.core.dataset_info import DatasetInfo -from fme.core.distributed import Distributed from fme.core.models.conditional_sfno.ankur import ( AnkurLocalNetConfig, get_lat_lon_ankur_localnet, @@ -42,89 +42,6 @@ def forward( return self.module(x, context) -class NoiseConditionedModule(nn.Module): - """Wraps a context-based module with gaussian noise conditioning. - - Generates gaussian noise and optional positional embeddings (with - label-position interaction), then calls the wrapped module with a - fully populated Context. 
- - Args: - module: An nn.Module with forward signature (x, context: Context). - img_shape: Global spatial dimensions (lat, lon) of the input data. - embed_dim_noise: Dimension of gaussian noise channels. - embed_dim_pos: Dimension of learned positional embedding. 0 disables. - embed_dim_labels: Dimension of label embeddings. 0 disables. - """ - - def __init__( - self, - module: nn.Module, - img_shape: tuple[int, int], - embed_dim_noise: int = 256, - embed_dim_pos: int = 0, - embed_dim_labels: int = 0, - ): - super().__init__() - self.module = module - self.embed_dim_noise = embed_dim_noise - self.img_shape = img_shape - self.label_pos_embed: nn.Parameter | None = None - if embed_dim_pos != 0: - self.pos_embed: nn.Parameter | None = nn.Parameter( - torch.zeros( - 1, embed_dim_pos, img_shape[0], img_shape[1], requires_grad=True - ) - ) - nn.init.trunc_normal_(self.pos_embed, std=0.02) - if embed_dim_labels > 0: - self.label_pos_embed = nn.Parameter( - torch.zeros( - embed_dim_labels, - embed_dim_pos, - img_shape[0], - img_shape[1], - requires_grad=True, - ) - ) - nn.init.trunc_normal_(self.label_pos_embed, std=0.02) - else: - self.pos_embed = None - - def forward( - self, x: torch.Tensor, labels: torch.Tensor | None = None - ) -> torch.Tensor: - x = x.reshape(-1, *x.shape[-3:]) - noise = torch.randn( - [x.shape[0], self.embed_dim_noise, *x.shape[-2:]], - device=x.device, - dtype=x.dtype, - ) - - h_slice, w_slice = Distributed.get_instance().get_local_slices(self.img_shape) - - embedding_pos: torch.Tensor | None = None - if self.pos_embed is not None: - pos_local = self.pos_embed[..., h_slice, w_slice] - embedding_pos = pos_local.repeat(x.shape[0], 1, 1, 1) - if self.label_pos_embed is not None and labels is not None: - label_local = self.label_pos_embed[..., h_slice, w_slice] - label_embedding_pos = torch.einsum( - "bl, lpxy -> bpxy", labels, label_local - ) - embedding_pos = embedding_pos + label_embedding_pos - - return self.module( - x, - Context( - 
embedding_scalar=None, - embedding_pos=embedding_pos, - labels=labels, - noise=noise, - ), - ) - - @ModuleSelector.register("AnkurLocalNet") @dataclasses.dataclass class AnkurLocalNetBuilder(ModuleConfig): diff --git a/fme/ace/registry/noise_conditioned.py b/fme/ace/registry/noise_conditioned.py new file mode 100644 index 000000000..420077e02 --- /dev/null +++ b/fme/ace/registry/noise_conditioned.py @@ -0,0 +1,97 @@ +"""Noise-conditioned wrapper for modules with a Context-based forward signature.""" + +import torch +from torch import nn + +from fme.core.distributed import Distributed +from fme.core.models.conditional_sfno.layers import Context + + +class NoiseConditionedModule(nn.Module): + """Wraps a context-based module with noise conditioning. + + Generates noise (gaussian by default, or via a custom factory) and + optional positional embeddings (with label-position interaction), + then calls the wrapped module with a fully populated Context. + + Args: + module: An nn.Module with forward signature (x, context: Context). + img_shape: Global spatial dimensions (lat, lon) of the input data. + embed_dim_noise: Dimension of noise channels. + embed_dim_pos: Dimension of learned positional embedding. 0 disables. + embed_dim_labels: Dimension of label embeddings. 0 disables. + noise_factory: Optional callable (batch_size, device, dtype) -> Tensor + that produces noise of shape [batch_size, embed_dim_noise, H, W]. + If None, gaussian noise is generated. 
+ """ + + def __init__( + self, + module: nn.Module, + img_shape: tuple[int, int], + embed_dim_noise: int = 256, + embed_dim_pos: int = 0, + embed_dim_labels: int = 0, + ): + super().__init__() + self.module = module + self.embed_dim_noise = embed_dim_noise + self.img_shape = img_shape + self.label_pos_embed: nn.Parameter | None = None + if embed_dim_pos != 0: + self.pos_embed: nn.Parameter | None = nn.Parameter( + torch.zeros( + 1, embed_dim_pos, img_shape[0], img_shape[1], requires_grad=True + ) + ) + nn.init.trunc_normal_(self.pos_embed, std=0.02) + if embed_dim_labels > 0: + self.label_pos_embed = nn.Parameter( + torch.zeros( + embed_dim_labels, + embed_dim_pos, + img_shape[0], + img_shape[1], + requires_grad=True, + ) + ) + nn.init.trunc_normal_(self.label_pos_embed, std=0.02) + else: + self.pos_embed = None + + def _generate_noise(self, x: torch.Tensor) -> torch.Tensor: + """Generate noise. Subclasses can override for custom noise types.""" + return torch.randn( + [x.shape[0], self.embed_dim_noise, *x.shape[-2:]], + device=x.device, + dtype=x.dtype, + ) + + def forward( + self, x: torch.Tensor, labels: torch.Tensor | None = None + ) -> torch.Tensor: + x = x.reshape(-1, *x.shape[-3:]) + noise = self._generate_noise(x) + + h_slice, w_slice = Distributed.get_instance().get_local_slices(self.img_shape) + + embedding_pos: torch.Tensor | None = None + if self.pos_embed is not None: + pos_local = self.pos_embed[..., h_slice, w_slice] + embedding_pos = pos_local.repeat(x.shape[0], 1, 1, 1) + if self.label_pos_embed is not None and labels is not None: + label_local = self.label_pos_embed[..., h_slice, w_slice] + label_embedding_pos = torch.einsum( + "bl, lpxy -> bpxy", labels, label_local + ) + embedding_pos = embedding_pos + label_embedding_pos + + return self.module( + x, + Context( + embedding_scalar=None, + embedding_pos=embedding_pos, + labels=labels, + noise=noise, + ), + ) diff --git a/fme/ace/registry/stochastic_sfno.py b/fme/ace/registry/stochastic_sfno.py 
index 99bb7f75f..d2db9c16c 100644 --- a/fme/ace/registry/stochastic_sfno.py +++ b/fme/ace/registry/stochastic_sfno.py @@ -5,11 +5,11 @@ import torch +from fme.ace.registry.noise_conditioned import NoiseConditionedModule from fme.ace.registry.registry import ModuleConfig, ModuleSelector from fme.core.dataset_info import DatasetInfo from fme.core.distributed.distributed import Distributed from fme.core.models.conditional_sfno.sfnonet import ( - Context, ContextConfig, SFNONetConfig, get_lat_lon_sfnonet, @@ -48,7 +48,13 @@ def isotropic_noise( return isht(alm) -class NoiseConditionedSFNO(torch.nn.Module): +class NoiseConditionedSFNO(NoiseConditionedModule): + """Noise-conditioned SFNO with support for isotropic noise. + + Extends NoiseConditionedModule with isotropic noise generation that uses + the SFNO's inverse spherical harmonic transform. + """ + def __init__( self, conditional_model: ConditionalSFNO, @@ -58,82 +64,31 @@ def __init__( embed_dim_pos: int = 0, embed_dim_labels: int = 0, ): - super().__init__() - self.conditional_model = conditional_model - self.embed_dim = embed_dim_noise + super().__init__( + module=conditional_model, + img_shape=img_shape, + embed_dim_noise=embed_dim_noise, + embed_dim_pos=embed_dim_pos, + embed_dim_labels=embed_dim_labels, + ) self.noise_type = noise_type - self.img_shape = img_shape - self.label_pos_embed: torch.nn.Parameter | None = None - # register pos embed if pos_embed_dim != 0 - if embed_dim_pos != 0: - self.pos_embed = torch.nn.Parameter( - torch.zeros( - 1, embed_dim_pos, img_shape[0], img_shape[1], requires_grad=True - ) - ) - # initialize pos embed with std=0.02 - torch.nn.init.trunc_normal_(self.pos_embed, std=0.02) - if embed_dim_labels > 0: - self.label_pos_embed = torch.nn.Parameter( - torch.zeros( - embed_dim_labels, - embed_dim_pos, - img_shape[0], - img_shape[1], - requires_grad=True, - ) - ) - torch.nn.init.trunc_normal_(self.label_pos_embed, std=0.02) - else: - self.pos_embed = None - def forward( - self, x: 
torch.Tensor, labels: torch.Tensor | None = None - ) -> torch.Tensor: - x = x.reshape(-1, *x.shape[-3:]) + def _generate_noise(self, x: torch.Tensor) -> torch.Tensor: if self.noise_type == "isotropic": - lmax = self.conditional_model.itrans_up.lmax - mmax = self.conditional_model.itrans_up.mmax - noise = isotropic_noise( - (x.shape[0], self.embed_dim), + lmax = self.module.itrans_up.lmax + mmax = self.module.itrans_up.mmax + return isotropic_noise( + (x.shape[0], self.embed_dim_noise), lmax, mmax, - self.conditional_model.itrans_up, + self.module.itrans_up, device=x.device, ) elif self.noise_type == "gaussian": - noise = torch.randn( - [x.shape[0], self.embed_dim, *x.shape[-2:]], - device=x.device, - dtype=x.dtype, - ) + return super()._generate_noise(x) else: raise ValueError(f"Invalid noise type: {self.noise_type}") - h_slice, w_slice = Distributed.get_instance().get_local_slices(self.img_shape) - - if self.pos_embed is not None: - pos_local = self.pos_embed[..., h_slice, w_slice] - embedding_pos = pos_local.repeat(noise.shape[0], 1, 1, 1) - if self.label_pos_embed is not None and labels is not None: - label_local = self.label_pos_embed[..., h_slice, w_slice] - label_embedding_pos = torch.einsum( - "bl, lpxy -> bpxy", labels, label_local - ) - embedding_pos = embedding_pos + label_embedding_pos - else: - embedding_pos = None - - return self.conditional_model( - x, - Context( - embedding_scalar=None, - embedding_pos=embedding_pos, - labels=labels, - noise=noise, - ), - ) - # this is based on the call signature of SphericalFourierNeuralOperatorNet at # https://github.com/NVIDIA/modulus/blob/b8e27c5c4ebc409e53adaba9832138743ede2785/modulus/models/sfno/sfnonet.py#L292 # noqa: E501 diff --git a/fme/ace/registry/test_stochastic_sfno.py b/fme/ace/registry/test_stochastic_sfno.py index 191cc72d7..453b1c761 100644 --- a/fme/ace/registry/test_stochastic_sfno.py +++ b/fme/ace/registry/test_stochastic_sfno.py @@ -4,12 +4,9 @@ import torch from torch_harmonics import 
InverseRealSHT -from fme.ace.registry.stochastic_sfno import ( - Context, - NoiseConditionedSFNO, - isotropic_noise, -) +from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNO, isotropic_noise from fme.core.device import get_device +from fme.core.models.conditional_sfno.layers import Context @pytest.mark.parametrize("nlat, nlon", [(8, 16), (64, 128)]) From 34abe3ad187670b3ce292e3024103b6ef5fd3cc8 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 24 Mar 2026 17:07:27 +0000 Subject: [PATCH 11/31] Use composition instead of inheritance for noise conditioning Replace NoiseConditionedSFNO subclass with a factory function that passes a noise_generator callable to NoiseConditionedModule. The base class accepts a NoiseGenerator callback (defaults to gaussian_noise), and the SFNO factory creates an isotropic noise generator that captures the model's inverse SHT. Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/ace/registry/noise_conditioned.py | 39 ++++++++------ fme/ace/registry/stochastic_sfno.py | 77 +++++++++++++++------------ 2 files changed, 68 insertions(+), 48 deletions(-) diff --git a/fme/ace/registry/noise_conditioned.py b/fme/ace/registry/noise_conditioned.py index 420077e02..8bdc37718 100644 --- a/fme/ace/registry/noise_conditioned.py +++ b/fme/ace/registry/noise_conditioned.py @@ -1,18 +1,34 @@ """Noise-conditioned wrapper for modules with a Context-based forward signature.""" +from collections.abc import Callable + import torch from torch import nn from fme.core.distributed import Distributed from fme.core.models.conditional_sfno.layers import Context +NoiseGenerator = Callable[[torch.Tensor, int], torch.Tensor] +"""Callable (x, embed_dim_noise) -> noise_tensor. + +Takes the input tensor (for shape/device/dtype) and noise embedding dimension, +returns noise of shape [x.shape[0], embed_dim_noise, *x.shape[-2:]]. 
+""" + + +def gaussian_noise(x: torch.Tensor, embed_dim_noise: int) -> torch.Tensor: + return torch.randn( + [x.shape[0], embed_dim_noise, *x.shape[-2:]], + device=x.device, + dtype=x.dtype, + ) + class NoiseConditionedModule(nn.Module): """Wraps a context-based module with noise conditioning. - Generates noise (gaussian by default, or via a custom factory) and - optional positional embeddings (with label-position interaction), - then calls the wrapped module with a fully populated Context. + Generates noise and optional positional embeddings (with label-position + interaction), then calls the wrapped module with a fully populated Context. Args: module: An nn.Module with forward signature (x, context: Context). @@ -20,9 +36,8 @@ class NoiseConditionedModule(nn.Module): embed_dim_noise: Dimension of noise channels. embed_dim_pos: Dimension of learned positional embedding. 0 disables. embed_dim_labels: Dimension of label embeddings. 0 disables. - noise_factory: Optional callable (batch_size, device, dtype) -> Tensor - that produces noise of shape [batch_size, embed_dim_noise, H, W]. - If None, gaussian noise is generated. + noise_generator: Callable that produces noise given the input tensor + and noise embedding dimension. Defaults to gaussian noise. """ def __init__( @@ -32,11 +47,13 @@ def __init__( embed_dim_noise: int = 256, embed_dim_pos: int = 0, embed_dim_labels: int = 0, + noise_generator: NoiseGenerator = gaussian_noise, ): super().__init__() self.module = module self.embed_dim_noise = embed_dim_noise self.img_shape = img_shape + self._noise_generator = noise_generator self.label_pos_embed: nn.Parameter | None = None if embed_dim_pos != 0: self.pos_embed: nn.Parameter | None = nn.Parameter( @@ -59,19 +76,11 @@ def __init__( else: self.pos_embed = None - def _generate_noise(self, x: torch.Tensor) -> torch.Tensor: - """Generate noise. 
Subclasses can override for custom noise types.""" - return torch.randn( - [x.shape[0], self.embed_dim_noise, *x.shape[-2:]], - device=x.device, - dtype=x.dtype, - ) - def forward( self, x: torch.Tensor, labels: torch.Tensor | None = None ) -> torch.Tensor: x = x.reshape(-1, *x.shape[-3:]) - noise = self._generate_noise(x) + noise = self._noise_generator(x, self.embed_dim_noise) h_slice, w_slice = Distributed.get_instance().get_local_slices(self.img_shape) diff --git a/fme/ace/registry/stochastic_sfno.py b/fme/ace/registry/stochastic_sfno.py index d2db9c16c..473cd93e0 100644 --- a/fme/ace/registry/stochastic_sfno.py +++ b/fme/ace/registry/stochastic_sfno.py @@ -5,7 +5,11 @@ import torch -from fme.ace.registry.noise_conditioned import NoiseConditionedModule +from fme.ace.registry.noise_conditioned import ( + NoiseConditionedModule, + NoiseGenerator, + gaussian_noise, +) from fme.ace.registry.registry import ModuleConfig, ModuleSelector from fme.core.dataset_info import DatasetInfo from fme.core.distributed.distributed import Distributed @@ -48,46 +52,53 @@ def isotropic_noise( return isht(alm) -class NoiseConditionedSFNO(NoiseConditionedModule): - """Noise-conditioned SFNO with support for isotropic noise. +def _make_sfno_noise_generator( + noise_type: Literal["isotropic", "gaussian"], + conditional_model: ConditionalSFNO, +) -> NoiseGenerator: + """Create a noise generator for an SFNO model. - Extends NoiseConditionedModule with isotropic noise generation that uses - the SFNO's inverse spherical harmonic transform. + For gaussian noise, returns the default generator. For isotropic noise, + returns a generator that uses the SFNO's inverse spherical harmonic + transform. 
""" + if noise_type == "gaussian": + return gaussian_noise + elif noise_type == "isotropic": - def __init__( - self, - conditional_model: ConditionalSFNO, - img_shape: tuple[int, int], - noise_type: Literal["isotropic", "gaussian"] = "gaussian", - embed_dim_noise: int = 256, - embed_dim_pos: int = 0, - embed_dim_labels: int = 0, - ): - super().__init__( - module=conditional_model, - img_shape=img_shape, - embed_dim_noise=embed_dim_noise, - embed_dim_pos=embed_dim_pos, - embed_dim_labels=embed_dim_labels, - ) - self.noise_type = noise_type - - def _generate_noise(self, x: torch.Tensor) -> torch.Tensor: - if self.noise_type == "isotropic": - lmax = self.module.itrans_up.lmax - mmax = self.module.itrans_up.mmax + def _isotropic(x: torch.Tensor, embed_dim_noise: int) -> torch.Tensor: + lmax = conditional_model.itrans_up.lmax + mmax = conditional_model.itrans_up.mmax return isotropic_noise( - (x.shape[0], self.embed_dim_noise), + (x.shape[0], embed_dim_noise), lmax, mmax, - self.module.itrans_up, + conditional_model.itrans_up, device=x.device, ) - elif self.noise_type == "gaussian": - return super()._generate_noise(x) - else: - raise ValueError(f"Invalid noise type: {self.noise_type}") + + return _isotropic + else: + raise ValueError(f"Invalid noise type: {noise_type}") + + +def NoiseConditionedSFNO( + conditional_model: ConditionalSFNO, + img_shape: tuple[int, int], + noise_type: Literal["isotropic", "gaussian"] = "gaussian", + embed_dim_noise: int = 256, + embed_dim_pos: int = 0, + embed_dim_labels: int = 0, +) -> NoiseConditionedModule: + """Create a noise-conditioned SFNO with support for isotropic noise.""" + return NoiseConditionedModule( + module=conditional_model, + img_shape=img_shape, + embed_dim_noise=embed_dim_noise, + embed_dim_pos=embed_dim_pos, + embed_dim_labels=embed_dim_labels, + noise_generator=_make_sfno_noise_generator(noise_type, conditional_model), + ) # this is based on the call signature of SphericalFourierNeuralOperatorNet at From 
b5e451cd3ba945fac5b06ea93b885e587011a46a Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 24 Mar 2026 18:16:09 +0000 Subject: [PATCH 12/31] Add regression tests for AnkurLocalNet and LocalNet Replace the regression coverage lost when HybridNet was removed. Each test uses validate_tensor to compare output against a saved baseline, following the same pattern as the SFNO regression tests. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../models/conditional_sfno/test_ankur.py | 40 ++++++++++++++++ .../models/conditional_sfno/test_localnet.py | 43 ++++++++++++++++++ .../testdata/test_ankur_localnet_output.pt | Bin 0 -> 9582 bytes .../testdata/test_localnet_output.pt | Bin 0 -> 9412 bytes 4 files changed, 83 insertions(+) create mode 100644 fme/core/models/conditional_sfno/testdata/test_ankur_localnet_output.pt create mode 100644 fme/core/models/conditional_sfno/testdata/test_localnet_output.pt diff --git a/fme/core/models/conditional_sfno/test_ankur.py b/fme/core/models/conditional_sfno/test_ankur.py index 58b042953..7f99fbd08 100644 --- a/fme/core/models/conditional_sfno/test_ankur.py +++ b/fme/core/models/conditional_sfno/test_ankur.py @@ -1,11 +1,16 @@ +import os + import pytest import torch from fme.core.device import get_device +from fme.core.testing.regression import validate_tensor from .ankur import AnkurLocalNetConfig, get_lat_lon_ankur_localnet from .layers import Context, ContextConfig +DIR = os.path.dirname(os.path.abspath(__file__)) + @pytest.mark.parametrize("use_disco_encoder", [True, False]) @pytest.mark.parametrize("pos_embed", [True, False]) @@ -157,3 +162,38 @@ def test_ankur_localnet_disco_kernel_size(): ) output = model(x, context) assert output.shape == (n_samples, output_channels, *img_shape) + + +def setup_ankur_localnet(): + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 4 + device = get_device() + params = AnkurLocalNetConfig(embed_dim=16, use_disco_encoder=True) + model = get_lat_lon_ankur_localnet( + 
params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ).to(device) + # must initialize on CPU to get the same results on GPU + x = torch.randn(n_samples, input_channels, *img_shape).to(device) + context = Context( + embedding_scalar=torch.randn(n_samples, 0).to(device), + labels=torch.randn(n_samples, 0).to(device), + noise=torch.randn(n_samples, 0, *img_shape).to(device), + embedding_pos=torch.randn(n_samples, 0, *img_shape).to(device), + ) + return model, x, context + + +def test_ankur_localnet_output_is_unchanged(): + torch.manual_seed(0) + model, x, context = setup_ankur_localnet() + with torch.no_grad(): + output = model(x, context) + validate_tensor( + output, + os.path.join(DIR, "testdata/test_ankur_localnet_output.pt"), + ) diff --git a/fme/core/models/conditional_sfno/test_localnet.py b/fme/core/models/conditional_sfno/test_localnet.py index 886444d53..307b6cad2 100644 --- a/fme/core/models/conditional_sfno/test_localnet.py +++ b/fme/core/models/conditional_sfno/test_localnet.py @@ -1,12 +1,17 @@ +import os + import pytest import torch from torch import nn from fme.core.device import get_device +from fme.core.testing.regression import validate_tensor from .layers import Context, ContextConfig from .localnet import LocalNetConfig, get_lat_lon_localnet +DIR = os.path.dirname(os.path.abspath(__file__)) + @pytest.mark.parametrize( "conditional_embed_dim_scalar, conditional_embed_dim_labels, " @@ -255,3 +260,41 @@ def test_backward_pass(): loss.backward() for name, param in model.named_parameters(): assert param.grad is not None, f"No gradient for {name}" + + +def setup_localnet(): + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 4 + device = get_device() + params = LocalNetConfig( + embed_dim=16, + block_types=["disco", "disco"], + ) + model = get_lat_lon_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ).to(device) + # must 
initialize on CPU to get the same results on GPU + x = torch.randn(n_samples, input_channels, *img_shape).to(device) + context = Context( + embedding_scalar=torch.randn(n_samples, 0).to(device), + labels=torch.randn(n_samples, 0).to(device), + noise=torch.randn(n_samples, 0, *img_shape).to(device), + embedding_pos=torch.randn(n_samples, 0, *img_shape).to(device), + ) + return model, x, context + + +def test_localnet_output_is_unchanged(): + torch.manual_seed(0) + model, x, context = setup_localnet() + with torch.no_grad(): + output = model(x, context) + validate_tensor( + output, + os.path.join(DIR, "testdata/test_localnet_output.pt"), + ) diff --git a/fme/core/models/conditional_sfno/testdata/test_ankur_localnet_output.pt b/fme/core/models/conditional_sfno/testdata/test_ankur_localnet_output.pt new file mode 100644 index 0000000000000000000000000000000000000000..6f2db9106ebd0314dc488a8ac5df7eedf7b2c5dd GIT binary patch literal 9582 zcmbt)30RI>+xF8u3KdE_8Vu1SGW4u-kx)n^l_8|LNVCjiB_gFnhA5p0eNAIE*K`&`$$?(1Ib8csuf9$p4RP>}II znj&Kr!^6hi!`#B@z#%ttM`tSwM<*K(bLT@IE{8nStt~t()LafY>go{2uK)LThTF=+ z+0Dvc&HRvugQL5!xtopUAqPima}OIQcV{v5-NuO7rM_))^+{2Y$$>2Sg z{(XIMeJ&4Ifs0mF&W?^YRvr${PVT~6+^lWfY^>KhSa}G$3NA2m74mQuHgXkFGt?Il zG=j$PBM$PfZul>?wveIocUn`kb#^;w;bHD+i=3aRmPmZ744tuHMekduyZmhqx5g*{V$7ibnx(SwD~224{ev`=gj}B|7#iK8Onbr!!N5=U!k$ zOo2a^9J~Rcq$em7szKWEE=Xt(;$v_Ru823HZ*c`$R+m7!un+~)its+53I^IOXq4>2 zj)pe~wPwTUS{Mv0{4wkhjwQ^iU|NM^^2aiSTQ(uBy9XlGe6+Jng!;xz;HJ}C{GL#X zxsSVHswzT9+9x0zS^#dN7xlB&gpKlr zSpANh_8SONt+TxJV8a7k^Lq{l#|hM|5TN&1)vyd}f~CnK34y=ghrwNYXs$^?lF?!ap|!ilDpb)<$f*t10pd#5d(waYlymd0}WR*uhUSa zG2y3~$4J+M|yqw-DQKjL>ydd+Oh12gr z6|$q@KQj||c6KAL@EzHsSB)$BF?gr22h&)D&YPYOnY+=*(The`?o9gIJ|P-Zqe*R@ zny|*C31Sw>*!Fr7qCy@bHDUn#eFMn7`3aW-rqBk#4A@OXgKn8b7y7GFt>dC}3vU`O zu0D$s<2GRQ06L&UsPY?yK$_To3O=57KU`mdtx1q+4T!zi?8!=U{M+IrIA z<&=vs#S**|D8@_CJg7Pp)GV);D)p2fPHd!V0`dpVAOisA9G> zIg;ch7YW<&88i5#GfT3bi=A<`kyWici8%OGlG^>DWXkz`<{~~8GtTufS>iqywlBQ# z*rt<$=6tg 
z$4U|~yW&CwJ}V%N3*dE#C*vkHXzjsN&hxksD7AKO%EjOloTSQ9Y-EN z-$gd_H!)v--bPwSxiOY+f`-eUi16Huvs;CszrKoS?-i%VrijtCHiGorZUz<(PbF{0LAC88gcQW7|MfDQoJfF>b1{x-N>doMpsv6f zI_VG45n724yV?=AgPXbt%%W$K22uEW98oVm;R{r#V{jghioZb<_akgf3diHoLM)a~ zM{MH)nqr)fMa!m9sjVisawreQuO1;>US7h>#MaWPU#!b3ftP_Xs>Dbr?#G!SK3a^j{dpj=C7= z81vC-<0Z(p`-o50-H6G4kBRzb?D_r@rY*N{h4&HXpDexL17&cS)`m_MKH9!YjP8i#rbU&rXWgw5b z(7t*I;%0TQeVquCt`Ha}T|&#YD-b+)72E6mAeG^R{fC2aZ&o_G6dSNUq8S^>Q&g>d ziy-3>RGytcs?8`I2g=YbA%j5@$IAV=fbff{Lpm`M4%a0y`Xh=MXZYbt-9=n$2!^8H zRj53Rz&>X&80$yk;e|x_x-#f2xyjW0ngpF+9ERymZCKykiLF<2AakV=3o>~rzra&m zXI#Y6+vibeGKS&fchS20H2YS20Y^7au|>rX!lr4+2=9kvRU6Dk)gh#t3bJMf-CwOq z!;|l$V_zsf40&Om;1pO#(^v+b^_SqFQ-Qh;0ovE= zjePwU_=X3N129tr6r~nS7Ukicf2%uh?DCQAr%vdsR@VCJ5&zw zISB|!Ho$3J1BBFR(kTPdRAOHSj#xZ{rv5`Pt>0s>(Q+ExEkYf)=c4NC0W4`B#8~HA z*dAxl#J*VQGnv?7orLV4$*7tg57nQk*cxAhnwiQpy7?3A*cl{8vj}=aQ>ocfS*lon z6Wr>vF{xjOemClbw|Xb?esMPaEGQFGBHL*KNqt2Gflr+uNE&xP|l9}@4+%P}X>6K6UMk*ec_FLDR* zpx_L+9tGliY7p{GZ4fto8xkZ>61fx?)=wb`_D+*-Vr=7phy^F0c5f@jg`Kr@TV@bX z&2bV`EQPK+9pr7CI6LR_SY~3`06aP8>Ph}nC>V@KhT#G*WfU`Qnp0WE@rMvIq=NoB zb>^!18@1l#r?5&AA7*wx_a)JL1Cg}%7BWXnac2F}j0S3kHzD`QllKAS$k$}1?81Yj zVeW47syvt#xg`GtT5W?^$HsY%Uo8JdRWOzRYrq;!K}UJ!H&%3u=92$yn!Xw&0meR$5{; zOaJa{5@nc9(##?i3~KGpzPyLC0H+=4_Q5FvO_Z}Q#f>p^-c9O zNinY?EHxX}*<~_p(_St19YcGv{#zu}&-jp4lcllKcO8~HEr!+Xb`o04LNx3iGc4;CiORc2Mn)x=`Iik>w6!Rre3k~DyUoJ2x)nHjRG-{C zb(pxgJR(VBw~4d!4`!ThQRa=0Lag$i-dYQE&52u7G~xRF0TGK%B?C?>@D&V0ylys> zNAF^;{#x)yWD=XyBs?_b_=);eU~hVhvpVlkU)T!WCCv~{s{_O6In;8qU{;n1tw#mu zZ?1)8#(OBN>&3dnR%jl5j@bh#;OTr1?fqY&wzdOOPGxu}*$9gqPK;{8IaA}~aX^@j zgb~gjT;2ycoi8{-I-$Wi148H3;+*XZc*L+UYES^Ne4o(mlN|x?R6E z+=D=|4n%F-1;>fK2;fVG!AVcfp3sIZ_Z8`BVNt4jx&=mwmtn~94Zf+*rKyJ)wBFLu>GlsTboI=k6Mx=*1yr_%Xbj!)rG~|(!m@`#TJDaIOV>^o)C`jAlHZkhi<|v zItojbE0N+LPp9}XDC0CA-DG$Zay7pAsrUgG3ST1f7U%59`iLLuwODZH8u&Q-en-J4 zFitRNR#*u{-bm5zUt<}~??It-1MUXp zqhQW6>|dP^^8@v;tLo?YIE7H1*?{GA2#Uh*k(l!a(yUq-m5;z9S(47TXVBJxSEx-Y z$NNdOcw*lTB07$=J}&zF@^?7z>q5-xdYCDd!*h87e1_Pt-}?-mkDFoU^A&pRQFt{E 
zVs6YkOgmc*zr97s5M`mvE)^$MW+14r8V>rskch2=HCHqqwuGSO_b8Z$JcrqxPO#SW zV&(HU2oB4`E}=xs6->e4U;!BQT-3HmmDVZFqhV)r(aG?Gx}7gJdS_z~S1C$vm11|^ z2RNVPqhnLPz)YYTrD{WP;S!|>@4n~wRs~3O=d9^)9qG3xsCJd0Lk_ld(RCTx&5578 z-lRgog%f9ue8LZXUb=ha5iEQjU<+rhw%J{X7Zjl@mv5wJ_X*Rkkyf}~x5rk!+mN7# zF+1;fc!o@&&&1AQ*?>03=F73TDhS=9VIXd2pc9;lq=ZWdsAQu^<{Wwx1gTY8AhrJd z1k>Lh#&`(ypAPVij+h?n#diq;2V>)Yvc zSfP}xTPukBdp&4MwicZ%uo4?;&SQ^J3Jw`R!A$81xcbV_L8pGC9s*LiZNYza2q(=a zQQ0OHdZlYCLN?vNxWyaH9Phx|l?*EG^#S_aub}#li|*nI!&8gfIJ#{T9iA>mcSla6 zYm6o6&aPYN*Hp!{!)_>-lBfP{i|L8JS1>y!Mk7UxaI5qX>}%7YVtWs}WafrLSw-N>`Azf9X2Vr#)-f0UE}21GV+kNs1m8djY!(wgwl2Ipr1Md z=R6_WEzQZf2=^dl;ua~%9U=1ZvM~Ryh;{v9C||>a`x7&u^jr-~=FP?m)o$YH7f3w& zcN4XDF(igh2!cv;F{gJrr1&byaPu*8Y(_k}ZJNf6Ibp<1Z-^tWPHCXaUIWhzKM~JE zz2t%FOOm}MpPcznMygytlcxjEN!Kl6-IoOEQ#oa?3H=BY@>6c z>>aJwGxI*%K=b2!@=@jnwyau<_R!jl`(LDR`Gho%(o=GP@D(?*l%DY*yiXD_g{o-a#2I^X3vm87 zg0s`3w1RBpQLB)N-0OPCTdK+)cB&(Mi|twZR4^<0@JIG&z*8uDC?Rqy69^Q+glr@k z6%}VLS8mVPwEcG|SKh%CM_ZQFeJ*xUSQaaG>NEWC(8hXoWjH_n&f25Bi+#v%GMn8g z2fa3NRz+wTvHS}7-&#i|XHU+Q5Y=Pe+suXO%BN9yrxKg=lVQ2PnOP>pU^(_uQok&P zH9d5>7O%otq8;W7oegtf!WO{Z+z-rz_wt#|Pc~;Bd*;OIGb_pT9nT7m)H$dTG3yH?*}I7-x6D9<%RC%ToB^3c zUW9XfByW9+32yk4x69`fQ$BtYGv9$6j(bM*F0>OJn=G<(tvi{wXf|2GSCJ#N>I)#4b#MM#4 zV_53*9?BLxbkiq(TIG?5okwC}zM}`X8#wTzDE4UNX#@YZUCc6!z46YEg(tI0U{)$d=jYTzabq6hEyvL_vk!umh9su-GQP?yQjufP(3??? z&eyMSmZgg%ov*kPu?xzvuVFye(6PG_ATQOZc$^YVnEC?gl08sQd5ek_!8m4@k4X)$ zQTASu#-|o=a?Ll9>HP)wj^tw^PnlY5NP*jLUMT0DODiu4(i69NskYX0SRH1eYN{L! 
z?vkVF^FH8>OFd@uO{ETrGPFIY6vLdee@A`=B#(2^f~Pf@P|3%WN+GJR*o-HcZBPt| zh11b%Xjpy?Q@2+kC%GSDis^9NQHbludtjK}4y9v7pgy@UALC!ShQkye4yep71vGe)@=d zyFtvb|9}R^dPEn!z~&Qr-HIWc3y2H-5qr#SR=iRf`sj3N$GmMStxX)VO9q zMQ|$BkX}M_f2vTg7(Tj0_%7Vb<{(2~2#vpO#5<0!u`4GI@mg`Hm=}we&CzHpzYPa3 zPENY59g$0uQNGU~9*O}d-w=uAifmW}@Y0};$+SY4m&!T~W96Aw*mTh!B8M(Rj-!uT zJ<(7vih{T2O%$2FKuLQf+5>X1!T&5o1N_ixk%ob7UtuH@2GiyqoZrQ%uk4qn2P|h% zE6-}IWAwst=Oe@_Pod|o55Q#cZsb2&$;roWr9qooa0-_&V}An*k~n$ApZZkYS&3S4 z{wcnQ9xUXXL*vKh(ftb9n6mLUHa&lgHM851yqoin=)^+9|2`+@zLmx<6QcvSE}^rL zhfcoCMSHnZvFPAYUj@3Hn5W&vH zCDAl^cXi^BXewrB6>{QTNh-A7hskCHKz9D?jOb6kcydXSwkTD@LvR6QteHetY*<0- z7M8=9Q^z2JNeDB20>9sC5&a<_y^-D|Z|xL>*OuW#@HqNUhvH$52rU~-#LxYjD4Dhg z0poHs!Epe!_TR9B$AH$TEumIHFL8<%K)9+8Y?TktD9(XMUpmYTGI3-@K7wzXapdd)sV8Al zV}(s+B8W^MWNv%8GILWwB{M#$iFoH*;mh{B*m(0cplOTST4GS(@s9v=dLoBoVuepA&<3X&n)kWJ%jS)4H=ekT`W~!}j_; z=%^>6rtT(Q^6o&)zON)`?+nZ`n~4o_4+#IrG_5sj4`SK)RcM;&p@;+$;cQi8IWVBE zp$k0;FOpGwmsPYUfykZCCnJ3e;4ogeV*hJ1(*LBBygn#}?|TljkL*51?)z|jD8UU7 z(OZvKOa=CgnHFsEy&kOHNd~MNt3Vz+8D%9pOpqGA-8fNv9ZLn3@HRq~wXMjAJXphp zD%S_>E^?l|W?=+IE)0<=$DEmd*&9e}%vE-$VOhpehYM_%lFv+O$4C?&-AtnEJBd+A zG>SO;uJ1-QWMtG6Vg6C3<%6Fjtk9KpGt>>w*BmBgUskcV$|VsxBahY5WelP5SH#tR zoE1Op2ha7Jm^7V{2hd`^hHuL=aWU(wsArAFGxH6R98N83zL+ z$=7xR{D>1mh+QE$SpJN(I;N8&i=s)Dcrt0#XeKhPf{=5YjAtD`$=bcm5Esz z$fF$)khj5d+YK13mcSLO8uICN7Wq}#zO#SAV&5zV<5&OJm2Cyal)tNN|2N}6^iQSj zCBnaEczM^(@v=oL7HMj6?rCW(TdK84W9gzrEB=$_P>}nrR*6&H|J4ok zxqb=!-^`UivNIHaJ@`}mU+`G|3iwA(#lk;jv`M}yG#6P`E+{%`R9@1|;~&(E*&PncOe|L5cU;vW5K|8+Iyto>I@|K}Pw@BQZ^ b(YgP)&-soV`{esYMB%(6!1y)(A7lRmn45h| literal 0 HcmV?d00001 diff --git a/fme/core/models/conditional_sfno/testdata/test_localnet_output.pt b/fme/core/models/conditional_sfno/testdata/test_localnet_output.pt new file mode 100644 index 0000000000000000000000000000000000000000..70c5cc8e2a55c2b4615b37dba280cf1755f1c8ca GIT binary patch literal 9412 zcmb7q30RKJ`*ul6g_0$tB$bLbEvjeEsidOPMyvKs+E=Nhg%Vm&$Sx#H*;9IE#A^>B 
zOL()dWy_xRd;GTJ{k{M9{eRy#$1%qo_uO-z*Lh#}b<908*Ximg(OXPRT3YOXl#ybR zVkzOtDSBQka|@xdt`|Ll}W?By*k z8e@L?}cE-?I1EkgcPUH|s2 z$B0S)^;_#j#3#iDrv#>jCnbxF+1d0K^A#1tU(knN|DOMg{IB}=&{|?Te*^t5#{>q) zL`BBMhR3DY*=YU){g;LRy_g#RB0%RPKJO(m{;U2y0TnT$za|heHzhp&m+$Os^!^|1 z{+m@yR7y%r_%99%$-`lNMD~BxzvnPQ%<6tN&xSy zvLU^<|3-pN&Onn!2eJ5+A#`dg<0El`9rmBatzXFVdX-|b=(rB=S+C9a&6R?MiwZRT zlDWl|TkO**f+3#$IsdAHl?mBcuOExUPd4)%0|R(}Ye~$|ux1JAsa&ObJ8@qfM5snF zJ8!U*S1XU_vEH&6+4ZSuRKIS)`G}po$2x;N5mfM~MWgr|+YAyBHVSVC7xISa+w9}n z<$`-TZoH!37nX3iKkpqbi{4)c3udcYAu&)5Rr^axW%nBX@}eqn%ue9hMdO5x_ZPzL zl^s8N+ZX4KjpV@!aya*MIVai++3iVJNN2Vs)AzOI1M1}nGjik)1X9Arvfo**{!~kY z(*C^3=R4afHi6fi(8E0)C1FQqF+c9?iZSEF`O0UDdE)6?B>T}Dl6+eVAu}gI+-5!hF*FZ1pRMEF=}|oMMufok zU>vcv&?nymhGJK3ECL{8LB%84qpD)=CK1Q|){G*@J}pAaEp@@GgHzdsUSEk$Z++r) zyoHP^bml&}Yk0>aJIECjahHbmLd#A)$d@GHNY_C!F{XqcI{$;XiKnyvdyDvoxO&0G z-EmyEHjf_~xPY4~Z6_5b`LKDqmpBJaL^>VMGE!$+<`9Z3sj09#)OEy9EEjx}2U0EN zner12!JaFULZhvB*`}WH=zU3v{V2M~MuaUVt7=vVEV_>~!|$rhytog^b=G6a&&v=V zl8j~IyU?|uiVVqo#5vI}cA!1kO1JK3?4&GB=Dy2e~`;()PG z%K10Jx5sMO(rAgq-zwPk<^p!LP#wp0s@UM}J?wLR0^3vKMkd%U7yNc5oMmJQg_qYV zv7~9W*b`nu=I<=PrOaHM@^~V6^YMJsD)UUZPY5USc?!5Za0NVLw+fc#EkjbtZWi8o zzcJh+g-wpz&nnlnF)7VFJhS_a)eP{!XkKY?qv2kY$%rgy+Bh;(#UfO`-pveWca!7W zD0{ymw<)b?6;l&4gZ+ui%pk;=>{)+Gs~;oP=|zz8Scs*H zu1wKs8`FB6%bwgi%?7RBNDPCgvz_+CQQITW0^HW&xW{sgPW~jkX%|P99nlqDNs@J+x-SUbBu`7G-!Dq`*8!vJ(!JX)uP-;b#v!bX~H-m{uLJ@b_5J+osS%7*N#OE4Kz zcZt27ya9__!{D*IhlKUp*wlT2l0mmF3d2SpW5Yn5M-O1xnGKklD@`jL zbKpL2IkeTKX!+`z0SZg~85py(<)oOO*`t zrPUP1TdEeh=P=;K;*G+583Vr`xH zEcO#@@o2hPX9rI1wWANk1>pN~aN%MCT=N|8!hA3l#9QL@G7Ea_v_khuyH?8R-?=(j^-2`(I@3wX7g%7VEDzJ$$P-lc z`JT2&^cm>P*Y><5WdX@JP_iP+cS^o8_MRi>e0&Wkf6_=Jq!d{*~6W?6rUjlJT=(HdyLA9+%3BV4Lp1ryl+xJiI4_D=B;@ZVQW=?sa?M z{*y{)nAH)sm?E}q4%`K!U;t_wWAXQ27zMwTAzPr~zapf-59ATIV0 z`DqY`mf#$Iy`MiZDW}-|HJYsWS;@6cxAXM0PbBQ|8fGwc0YB(g!riKT*wC(}{Fy~G z*H?ValHV7iwQ(NatqK)tof*nczCFbz)y#s**g<@csTUukZ;BV=1_+i2M&hZiF7}V` zVqKcQiF`Uw@MPR_^8ImtoHg1AxfTKAU+2JUt`zCILWs0VD0-G_B4m6KZeG+ws_s?c 
zj2?A%t8+4Ia9@M(?vt4~X=1(8WSN}pTA}+xPodfeFT`%Wf8@%O@pw>Q1&y7|a^2yx zOd+{mpi*{|)lC}=r|>vo)bTCk+n0_bXR!&d?ii49+mHlVP=hd5G;JFCL#y7L_iYC??dqfz0MojRdOb2Az z9Wq}#7b{-hWV)Yv5sf4*cIALN-aJfTt99mJ%lZn`_k0tkOS(cUd^5Wos>z117COAg()#Q+MA%8fIYZkW<25 zvn<(0r_E$x>l{3Lm4e_evB)2$$*caf6rNjNf7!u=*$dgA zE_+72stE}{-83V2BlLc|%3^X;an$W%(^L5i%t*5yIyQy zvg^m;sP9fh*GI8W2`_~m?Tv_N(1BzXCC#SGan0f*`SigD7u&l8y1Pya4Su+@TT(*7 zxaG6iV(GrD$e^3(d|e86+HNs#x;I|W@T93_c63<7NrW5ZLba_OW7b{7In7A;M;p?% z8Jg&joB6T=l{RWk#&k7RDpOjFmO*Y*yFd?T_79`=mJX=hmXAth4-|Y(LB5SK zYF8e?EBAr)>Iy=eM$F zK|5(GE;lcLUQIcM51T@J)z1Sn3!-5UZ0X#K=CHqPkA81N^TV#A7_s94diLnx{Ao9| z-v>Q4EEPA{EVSmy)9|#(a5@o0^|q?hCFTiK>&j_3rfE{WQTjNdGMkoedII^?pqt7~ z=;~MFG5S;}J*J+I>C^xZtRiX7DJS?z z)2Qdt)yNi>;+mo!>^f)S`%lov2h?!wOBgg@6TR9UowpEB!EQ8tr z9o&9zfwXTKC^OiIHSe@wsAK}KHhXBw-N&sFr_ehvk|y|Q;Ehu?8tu#=Z*maj`&NiJ zJ;scpW0+uJLV0Wp9^O$vak47pEmZKX#019kG9Wp~5O>wj!cIB|o8vt2?57`9P(1|8 z16I%&u>f1G73uGL`qA9q%~5Cc0`g8$H2H!d4GVsZ=GTsNa}cFIN7M1n$Q!2u1fq39 z1NyCcfPuz8v8UVwhZk32y?F$^GgBWApPgV<5h|og<2G|xn8Vr&7V+h2`$*WGI4&k{ z&J7NwV&BOQ@~)Q?WWBUe8`Uj%yWlH(YNLdG9$9c18Hv_31x$G-`cG9H734dm^ZVa+ z3BtSQLEjmCOj4Yd_pR@Vb=@?yqKLl;P39xJ3{XG1jkvEYK-Pq4UZPuTd3=x^Ulgy73PCKZ-*Au} z)_3A#t}&wJ?7&^GL!jnr3s}YTj1nVcpWVmj_Oi!_kJ3C%x?JGz8UmcmBE!Ds@Urvj zSbMpg={|pJac{XQe-&U)x{dRBYqCCzQa6OK!wnw|X7G;INIrS#JLcYgf`m<*N0$1G zAfb{GyrYjSYAmw(a(@A|Ugt9#EiGhqz7h^J^ye{~s+q5=I8s#Nxc~GDvg2wd%pINh zmW}E#tGCC78-C~*{gjkUtKbulmk^8E0qkDBE6Yie-~%p?BMq&}eADeqWGWZUwXf&# zLk8*SnqPp0VnSwk!xN>!(cI&WJRVgDn0nzN(oRnCr{d%J$d{QY*U$q=m>~R+w1;Om z?jmh%a#(XZA4L~(xbIY3Uj1|dzq4Qw3^eU{js0R?xy6}l%B&!T?bSr`yfLdQ%;j@q z-T4Z;eeC@CaXegar)BJ`aJ-D(CVcSYE(zH^f-~P&;eV?2uyToGV`eKTm4>2Ayh>ekKXVufxShP|)3vPSJ(m$EN-gg@l zHq;8W@}@8+&q`*Z^^Lv!euf>R)8O3chfiTkgwqzj7gUegiFXr=aCSgBJ2Bx@(}uh% zM$9>j?-jugxKC%3hZ_o}&si>fdwUz{srW3M)~wnjZS;X$7_@+0RT7=05P7_e-NstP znwf@!z91}YH&zIROet>>Iu~wZ<$C+rCpAYlSG0CqQ+m2zy8W0qn-ybL82M4a9Qj zJ)utjE>b47fN7MNiuNBLgwo0-q-F2#Ox;SBWxk#XYvqonHP_?>3-TNh@_amN?li_c 
z(fV87{|gzIQOyFzZ-?=Nk??b<7G6tU$Yx)C!(MDzh4KgU*`1+vjU!@1gg;(4;gO{U zN_50=TwaOPI87#rKTncAqWNw7i%-llv5R$2I?lYEcVJeA49WbY$L#&;n>MFRLFv&a z!pA-R1?MO4W{1^`$doX7Yzh61jWXPdspmqFmtTtI4vX-ixruq431rL0;nXpJHD9XNQ6JK9k)QgE(a2I zGilThFO&@%hmaYg>DklG*eBsbQ#LL}X?h1D4sJ!wvpg6F9l|b)OZaegB4m3j(owhG zBeansS33=Ad#%veY=}}1Yr3_!9vxT_0I?klabCe1{$3Yw+kPvY`>%paFE60u2plAL zz1KrqB9RLH4)}5+?7&u4{_) zadjF@s<(pI)j~Z@l0LAkf=^>VDki=L{miD|T>!zk>*2J|aeex2cotf(yQ5`&GEQji zfcYmw`h189Ha}FN$9%>hE65Lh1%UFEeb~`ZhSOEE;F9GA!>cvOuCT%jF>5HF*QX5& z?SKkN?AoIO$-`zi_(Xv|tu};9qcxb%A?$t;07K;@cs-D$b^A7<)XErbC-mqyOg zxj=Gm61<5yG(A1(C0S?qKLr9_RDeEFrr~8Lak^UzvU3HfQ{IBB@6D+7Bui|_>W}Zg z>Ee2U3RSlXpl=6K>NqJ79fMTq3O7UANB9+$JIe9w-6AXqYrx)!k@V2BQS^LEe_C9o zNH0%!!1zja3_pAqr$zt41mlmeed~op**O^Su^%06TnU@_c`&*hL?6d4h3DgBsLu_c zJ;A9M+_)OWv4!Y&Rh1r^*?^$iFne&kkT+z1;g<|a=0v2@Hz7>-jm&)Det+e}cZ z$fccj@N;{^=Z)6wDQjh%DrrT^Di9 z@+eGdOJcLShazMAYclWRRG$6Rl!ru>k-#D=ZY4SsO%0I7kycYkZ;Iq$iv?USGaHMK zK4Ee65}9AHj%W$;&=4^iY91>HQy7E7NuhjQ)+44UErZo^dHB7=oIm^N$Ti003FPm_ zke$` z@jXnJ8wRaF=Zt(lCtZ^Fz0}0b=6myPc40h;uI4RY`*RKJc>YO$0eP`&Irr}9#bo^Q zxlk~KbqI8jdrE-^tyo6FUyX(Kgebnh*K~d_T!$SUw3DAMHsaq+qqyy+Lg7A*vCu93 zKwd}8;xlcHxLw0zB6}#8@5voQxK9;N?%Ge})b)92ekJdIokxzUOy%bz7vh>zId_wZ z<0WJTt5opgvoqz%^+7T)D2PV6uM>}ca+@veRW3{)I+I5`&qwN+#XR1?3~$4iV25!j zf3esLHp?uytfU_ge>R)R30(NJgaH2fXecjlEZ|rh#g9JC=ik=%!Gr)8f!ekuT#%r| z+b5l4W88yye0DuKI<;63QRu|g#u=i|XGJ7MFGE>L5U+}JN5-j6((=*_pT55$M(h=l zG-u?*s`-5J7FpbVF3T&pfUCJ2W{aA0x!#>Tb~QB-I{wmpf@C?7vkc?jodM*uXs=93 zF-v&SpJyHS<=!EsmQn}n1QAb;2(6l=aXnoV$4p)EJh&FKWjuu@a|m)TWRk11*5mAV z8^PRV6JZ@MfdTzD5_`)n?9852EW4P2(4$k?*2WF&eP9^tC|}F&gsp~VNGWFYYG5C& zWDw~S0PmDw;*eX$IwO1$Bv~(rEgnbaZ!v&#$Z)}yb{kguerZ!z&@Dl1@Do;io-scQ z3jGI{ELA+i$jWd@CbNHqU|0TpJUlU<&DMz)s0I~an@$A_J|W3%yWR;?1}?)oomzJP zxEo4E-=*-!M&Y!@vP9YGG`o05L11q1PHLlAn$N8-Wk!>o_z22`vHA}{aufs@e&_T&6Cw)IXwyc7A^WcmWe#J>_P ziKWm`kz`NgxZ$QI@MB$*_#V{t9aMoLH6p1OqO zUvkIzNdwTLbDWe;6k~6!c9Z%#N9NeiPgrv$3gxfESV^w|#PmQuyn3XWaON&jU>%I3 
zRVAb@(H=1qc0x&5j6=R*tZm>f*4;o6V;u&5fxxIm7Kv6bnRw+rc4f|GfooJMi&IY$ zPND$<&!dKx@-13|NAcaHF5iSz>zgrK({s#S=_o5QbAjGcA(@u$C-gNLfYCRdS+aIB z@$DOl{=5|3>)cH2z=(GTup6Nn2 ztR|TIq7EV7W1%RUhBv(qakDi&tPc1KtYQBtkgoFRg4`%)q&_Y| zs+KJb^^0I6ZGwW|Rneoq9!olbLqXp3T)s3d(r|^__KBE2TaJ!3kEiZ3i$yW!ikHL2 zQ%mJatV~;m&Yzz!)4c!&3nyUPN?AG}v~Ji9O!o+Nt$MFK{HF8sAt+Bx+GMKZVM?ulaU)$ySx%DcO9r-uWppeG@b*!>}iFQYnAB{lVEJ#0X(cUrbeS(U|OwCr;Qs- z`xW=0&!-)MPDTV>YpP4_I{j$kqkI&48&Iq6B`}aXh)vVHsk-QVf-B~rt#yj%-=2%? znhNZ)Q=n4WGm#eVhM|)^P+}uN3soGDf7=Sza>pV;X)4CI%!1O*Y^-*(rqY|WF-5;G zO&_I7yN7;>4>fdL!mBohqb!PiNih+TDk=mXh^UK2J4Usec@A-ROF^#{rp!bjDd2)DC zlt@945&TOG5f~L_XQTEP2LGM7|5s~hVrDea(A?a_&~T!Ov8nL{10z#|i3Y~zqCPWI z(}^Z#qDD047t6vHrH-+p?EbIn>L~t;@PCvee^ghD{B`5M%Kxe(`7_iX)f8j?3+AKf z>Mx=GB+mFV;veM^TK|RUCX#IYHRAu!Wc(TQk9r4DVEmz6`8&}6=q&z>^~VSs@dqR^ zF;9_~|B~Q8^%8#u`ePs|{}+&f+TQ^ED?I;<^v92_`(H@Q{|@P|!_d{SPoHuB0#K0{ z{~xGdilATR->uF+;z>+w;^=?&i=O+(t&o4kfXJz$sFVC9FA+T>CH8ClKga$bf(voi literal 0 HcmV?d00001 From 97e70a7249ca9874f7842c149eba8e16cf92c8fd Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 24 Mar 2026 18:38:42 +0000 Subject: [PATCH 13/31] Split secondary network outputs into full-field and residual lists Replace the single secondary_residual_out_names with two lists: - secondary_out_names: full-field outputs used directly (must not overlap with out_names) - secondary_residual_out_names: residual corrections added to the backbone output (if name is in out_names) or to the normalized input (if name is only in in_names) Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/core/step/single_module.py | 83 +++++++++++++++++++--------------- fme/core/step/test_step.py | 70 ++++++++++------------------ 2 files changed, 71 insertions(+), 82 deletions(-) diff --git a/fme/core/step/single_module.py b/fme/core/step/single_module.py index 09c330674..cefe87739 100644 --- a/fme/core/step/single_module.py +++ b/fme/core/step/single_module.py @@ -56,12 +56,13 @@ class SingleModuleStepConfig(StepConfigABC): 
residual_prediction: Whether to use residual prediction. secondary_builder: Optional builder for a secondary network that receives the same input as the primary module. - secondary_out_names: Names of variables output by the secondary network. - Names that overlap with out_names must appear in - secondary_residual_names. - secondary_residual_names: Names of variables (a subset of both out_names - and secondary_out_names) for which the secondary network's output is - added as a residual to the primary module's output. + secondary_out_names: Names of variables output by the secondary network + as full fields (used directly as output). Must not overlap with + out_names. + secondary_residual_out_names: Names of variables for which the secondary + network predicts a residual correction. If the name is also in + out_names, the residual is added to the backbone's output; + otherwise it is added to the (normalized) input value. """ builder: ModuleSelector @@ -78,7 +79,7 @@ class SingleModuleStepConfig(StepConfigABC): residual_prediction: bool = False secondary_builder: ModuleSelector | None = None secondary_out_names: list[str] = dataclasses.field(default_factory=list) - secondary_residual_names: list[str] = dataclasses.field(default_factory=list) + secondary_residual_out_names: list[str] = dataclasses.field(default_factory=list) def __post_init__(self): self.crps_training = None # unused, kept for backwards compatibility @@ -97,14 +98,16 @@ def __post_init__(self): raise ValueError( f"next_step_forcing_name is an output variable: '{name}'" ) - all_out_names = set(self.out_names) | set(self.secondary_out_names) + all_secondary_names = set(self.secondary_out_names) | set( + self.secondary_residual_out_names + ) if self.secondary_decoder is not None: for name in self.secondary_decoder.secondary_diagnostic_names: if name in self.in_names: raise ValueError( f"secondary_diagnostic_name is an input variable: '{name}'" ) - if name in all_out_names: + if name in 
set(self.out_names) | all_secondary_names: raise ValueError( f"secondary_diagnostic_name is an output variable: '{name}'" ) @@ -114,36 +117,38 @@ def __post_init__(self): "secondary_out_names must be empty when " "secondary_builder is not provided" ) - if self.secondary_residual_names: + if self.secondary_residual_out_names: raise ValueError( - "secondary_residual_names must be empty when " + "secondary_residual_out_names must be empty when " "secondary_builder is not provided" ) else: - if not self.secondary_out_names: + if not self.secondary_out_names and not self.secondary_residual_out_names: raise ValueError( - "secondary_out_names must not be empty when " + "at least one of secondary_out_names or " + "secondary_residual_out_names must be non-empty when " "secondary_builder is provided" ) - for name in self.secondary_residual_names: - if name not in self.secondary_out_names: - raise ValueError( - f"secondary_residual_name '{name}' must be in " - f"secondary_out_names: {self.secondary_out_names}" - ) - if name not in self.out_names: - raise ValueError( - f"secondary_residual_name '{name}' must be in " - f"out_names: {self.out_names}" - ) overlap = set(self.secondary_out_names) & set(self.out_names) - if overlap != set(self.secondary_residual_names): + if overlap: + raise ValueError( + f"secondary_out_names must not overlap with out_names. " + f"Overlap: {overlap}" + ) + overlap = set(self.secondary_out_names) & set( + self.secondary_residual_out_names + ) + if overlap: raise ValueError( - f"Names appearing in both out_names and secondary_out_names " - f"must be listed in secondary_residual_names. " - f"Overlap: {overlap}, " - f"secondary_residual_names: {set(self.secondary_residual_names)}" + f"secondary_out_names must not overlap with " + f"secondary_residual_out_names. 
Overlap: {overlap}" ) + for name in self.secondary_residual_out_names: + if name not in self.out_names and name not in self.in_names: + raise ValueError( + f"secondary_residual_out_name '{name}' must be in " + f"out_names or in_names: {self.out_names}, {self.in_names}" + ) @property def n_ic_timesteps(self) -> int: @@ -209,6 +214,7 @@ def output_names(self) -> list[str]: set(self.out_names) .union(secondary_decoder_names) .union(self.secondary_out_names) + .union(self.secondary_residual_out_names) ) @property @@ -321,15 +327,16 @@ def __init__( dist = Distributed.get_instance() if config.secondary_builder is not None: + all_secondary_names = ( + config.secondary_out_names + config.secondary_residual_out_names + ) secondary_module = config.secondary_builder.build( n_in_channels=n_in_channels, - n_out_channels=len(config.secondary_out_names), + n_out_channels=len(all_secondary_names), dataset_info=dataset_info, ) self.secondary_module: Module | None = secondary_module.to(get_device()) - self.secondary_out_packer: Packer | None = Packer( - config.secondary_out_names - ) + self.secondary_out_packer: Packer | None = Packer(all_secondary_names) else: self.secondary_module = None self.secondary_out_packer = None @@ -436,11 +443,13 @@ def network_call(input_norm: TensorDict) -> TensorDict: secondary_dict = self.secondary_out_packer.unpack( secondary_tensor, axis=self.CHANNEL_DIM ) - for name in self._config.secondary_residual_names: - output_dict[name] = output_dict[name] + secondary_dict[name] for name in self._config.secondary_out_names: - if name not in self._config.secondary_residual_names: - output_dict[name] = secondary_dict[name] + output_dict[name] = secondary_dict[name] + for name in self._config.secondary_residual_out_names: + if name in output_dict: + output_dict[name] = output_dict[name] + secondary_dict[name] + else: + output_dict[name] = input_norm[name] + secondary_dict[name] secondary_output_dict = self.secondary_decoder.wrap_module(wrapper)( 
output_tensor.detach() # detach avoids changing base outputs ) diff --git a/fme/core/step/test_step.py b/fme/core/step/test_step.py index 8b281265f..90c0a7424 100644 --- a/fme/core/step/test_step.py +++ b/fme/core/step/test_step.py @@ -402,8 +402,8 @@ def get_single_module_with_secondary_selector( out_names=["prog_a", "prog_b"], normalization=normalization, secondary_builder=ModuleSelector(type="MLP", config={}), - secondary_out_names=["prog_a", "diag_a"], - secondary_residual_names=["prog_a"], + secondary_out_names=["diag_a"], + secondary_residual_out_names=["prog_a"], ), ), ) @@ -716,56 +716,56 @@ def test_secondary_builder_none_with_out_names_raises(): ) -def test_secondary_builder_none_with_residual_names_raises(): +def test_secondary_builder_none_with_residual_out_names_raises(): normalization = get_network_and_loss_normalization_config( names=["a", "b"], ) - with pytest.raises(ValueError, match="secondary_residual_names must be empty"): + with pytest.raises(ValueError, match="secondary_residual_out_names must be empty"): SingleModuleStepConfig( builder=ModuleSelector(type="MLP", config={}), in_names=["a"], out_names=["b"], normalization=normalization, - secondary_residual_names=["b"], + secondary_residual_out_names=["b"], ) -def test_secondary_builder_with_empty_out_names_raises(): +def test_secondary_builder_with_empty_names_raises(): normalization = get_network_and_loss_normalization_config( names=["a", "b"], ) - with pytest.raises(ValueError, match="secondary_out_names must not be empty"): + with pytest.raises(ValueError, match="at least one of"): SingleModuleStepConfig( builder=ModuleSelector(type="MLP", config={}), in_names=["a"], out_names=["b"], normalization=normalization, secondary_builder=ModuleSelector(type="MLP", config={}), - secondary_out_names=[], ) -def test_secondary_residual_name_not_in_secondary_out_names_raises(): +def test_secondary_out_name_overlaps_out_names_raises(): normalization = get_network_and_loss_normalization_config( - 
names=["a", "b", "c"], + names=["a", "b"], ) - with pytest.raises(ValueError, match="secondary_residual_name 'b'.*secondary_out"): + with pytest.raises(ValueError, match="secondary_out_names must not overlap"): SingleModuleStepConfig( builder=ModuleSelector(type="MLP", config={}), in_names=["a"], out_names=["b"], normalization=normalization, secondary_builder=ModuleSelector(type="MLP", config={}), - secondary_out_names=["c"], - secondary_residual_names=["b"], + secondary_out_names=["b"], ) -def test_secondary_residual_name_not_in_out_names_raises(): +def test_secondary_out_name_overlaps_residual_out_names_raises(): normalization = get_network_and_loss_normalization_config( names=["a", "b", "c"], ) - with pytest.raises(ValueError, match="secondary_residual_name 'c'.*out_names"): + with pytest.raises( + ValueError, match="secondary_out_names must not overlap.*residual" + ): SingleModuleStepConfig( builder=ModuleSelector(type="MLP", config={}), in_names=["a"], @@ -773,48 +773,28 @@ def test_secondary_residual_name_not_in_out_names_raises(): normalization=normalization, secondary_builder=ModuleSelector(type="MLP", config={}), secondary_out_names=["c"], - secondary_residual_names=["c"], + secondary_residual_out_names=["c"], ) -def test_secondary_out_name_overlaps_out_name_without_residual_raises(): +def test_secondary_residual_out_name_not_in_out_or_in_names_raises(): normalization = get_network_and_loss_normalization_config( names=["a", "b", "c"], ) - with pytest.raises(ValueError, match="secondary_residual_names"): + with pytest.raises(ValueError, match="secondary_residual_out_name 'c'"): SingleModuleStepConfig( builder=ModuleSelector(type="MLP", config={}), in_names=["a"], out_names=["b"], normalization=normalization, secondary_builder=ModuleSelector(type="MLP", config={}), - secondary_out_names=["b", "c"], - secondary_residual_names=[], - ) - - -def test_secondary_out_name_overlaps_secondary_decoder_raises(): - normalization = 
get_network_and_loss_normalization_config( - names=["a", "b", "c"], - ) - with pytest.raises(ValueError, match="secondary_diagnostic_name is an output"): - SingleModuleStepConfig( - builder=ModuleSelector(type="MLP", config={}), - in_names=["a"], - out_names=["b"], - normalization=normalization, - secondary_builder=ModuleSelector(type="MLP", config={}), - secondary_out_names=["c"], - secondary_decoder=SecondaryDecoderConfig( - secondary_diagnostic_names=["c"], - network=ModuleSelector(type="MLP", config={}), - ), + secondary_residual_out_names=["c"], ) @pytest.mark.parallel -def test_secondary_network_residual_adds_to_backbone(): - """Test that secondary_residual_names outputs are added to backbone outputs.""" +def test_secondary_network_full_field_and_residual(): + """Test secondary_out_names and secondary_residual_out_names together.""" torch.manual_seed(0) normalization = get_network_and_loss_normalization_config( names=["forcing", "prog", "diag"], @@ -835,8 +815,8 @@ def test_secondary_network_residual_adds_to_backbone(): out_names=["prog"], normalization=normalization, secondary_builder=ModuleSelector(type="MLP", config={}), - secondary_out_names=["prog", "diag"], - secondary_residual_names=["prog"], + secondary_out_names=["diag"], + secondary_residual_out_names=["prog"], ), ), ) @@ -880,8 +860,8 @@ def test_secondary_network_state_round_trip(): out_names=["prog"], normalization=normalization, secondary_builder=ModuleSelector(type="MLP", config={}), - secondary_out_names=["prog", "diag"], - secondary_residual_names=["prog"], + secondary_out_names=["diag"], + secondary_residual_out_names=["prog"], ) img_shape = DEFAULT_IMG_SHAPE step1 = get_step( From 708d634aced2d91547da9fc47f96896c255a3f6d Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 24 Mar 2026 18:42:47 +0000 Subject: [PATCH 14/31] Add isotropic noise support for LocalNet Move isotropic_noise and noise generator factory into noise_conditioned.py as shared infrastructure. 
LocalNetBuilder gains a noise_type field ('gaussian' or 'isotropic'); when isotropic, it creates a standalone InverseRealSHT for noise generation. The SFNO builder uses the same make_noise_generator factory, extracting the iSHT from its model. Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/ace/registry/local_net.py | 30 +++++++-- fme/ace/registry/noise_conditioned.py | 80 ++++++++++++++++++++++++ fme/ace/registry/stochastic_sfno.py | 74 +++------------------- fme/ace/registry/test_local_net.py | 15 +++++ fme/ace/registry/test_stochastic_sfno.py | 3 +- 5 files changed, 130 insertions(+), 72 deletions(-) diff --git a/fme/ace/registry/local_net.py b/fme/ace/registry/local_net.py index 130f015d2..742310b49 100644 --- a/fme/ace/registry/local_net.py +++ b/fme/ace/registry/local_net.py @@ -4,7 +4,10 @@ import torch from torch import nn -from fme.ace.registry.noise_conditioned import NoiseConditionedModule +from fme.ace.registry.noise_conditioned import ( + NoiseConditionedModule, + make_noise_generator, +) from fme.ace.registry.registry import ModuleConfig, ModuleSelector from fme.core.dataset_info import DatasetInfo from fme.core.models.conditional_sfno.ankur import ( @@ -111,7 +114,10 @@ class LocalNetBuilder(ModuleConfig): Attributes: embed_dim: Dimension of the embeddings. - noise_embed_dim: Dimension of the gaussian noise conditioning channels. + noise_embed_dim: Dimension of the noise conditioning channels. + noise_type: Type of noise for conditioning ('gaussian' or 'isotropic'). + Isotropic noise is generated via inverse spherical harmonic + transform. context_pos_embed_dim: Dimension of the learned positional embedding used for conditioning. 0 disables. block_types: List of filter types for each block ('disco', 'conv1x1'). @@ -132,11 +138,12 @@ class LocalNetBuilder(ModuleConfig): lora_alpha: Strength of LoRA adaptations. Defaults to lora_rank if None. data_grid: Grid type for spherical harmonic transforms used by - DISCO convolutions. 
+ DISCO convolutions and isotropic noise generation. """ embed_dim: int = 256 noise_embed_dim: int = 256 + noise_type: Literal["gaussian", "isotropic"] = "gaussian" context_pos_embed_dim: int = 0 block_types: list[BlockType] = dataclasses.field( default_factory=lambda: ["disco"] * 12 @@ -190,10 +197,25 @@ def build( data_grid=self.data_grid, context_config=context_config, ) + img_shape = dataset_info.img_shape + if self.noise_type == "isotropic": + from torch_harmonics import InverseRealSHT + + grid_mapping = { + "legendre-gauss": "legendre-gauss", + "equiangular": "equiangular", + } + isht = InverseRealSHT(*img_shape, grid=grid_mapping[self.data_grid]) + noise_generator = make_noise_generator( + self.noise_type, isht=isht, lmax=isht.lmax, mmax=isht.mmax + ) + else: + noise_generator = make_noise_generator(self.noise_type) return NoiseConditionedModule( net, - img_shape=dataset_info.img_shape, + img_shape=img_shape, embed_dim_noise=self.noise_embed_dim, embed_dim_pos=self.context_pos_embed_dim, embed_dim_labels=embed_dim_labels, + noise_generator=noise_generator, ) diff --git a/fme/ace/registry/noise_conditioned.py b/fme/ace/registry/noise_conditioned.py index 8bdc37718..7052aaa47 100644 --- a/fme/ace/registry/noise_conditioned.py +++ b/fme/ace/registry/noise_conditioned.py @@ -1,6 +1,8 @@ """Noise-conditioned wrapper for modules with a Context-based forward signature.""" +import math from collections.abc import Callable +from typing import Literal import torch from torch import nn @@ -24,6 +26,84 @@ def gaussian_noise(x: torch.Tensor, embed_dim_noise: int) -> torch.Tensor: ) +def isotropic_noise( + leading_shape: tuple[int, ...], + lmax: int, + mmax: int, + isht: Callable[[torch.Tensor], torch.Tensor], + device: torch.device, +) -> torch.Tensor: + """Generate isotropic noise on the sphere via inverse SHT. + + Draws random spherical harmonic coefficients scaled so that Var[f(θ,φ)] = 1. 
+ """ + coeff_shape = (*leading_shape, lmax, mmax) + real = torch.randn(coeff_shape, dtype=torch.float32, device=device) + imag = torch.randn(coeff_shape, dtype=torch.float32, device=device) + imag[..., :, 0] = 0.0 # m = 0 => purely real + + sqrt2 = math.sqrt(2.0) + real[..., :, 1:] /= sqrt2 + imag[..., :, 1:] /= sqrt2 + + scale = math.sqrt(4.0 * math.pi) / lmax + alm = (real + 1j * imag) * scale + + l_slice, m_slice = Distributed.get_instance().get_local_slices((lmax, mmax)) + alm = alm[..., l_slice, m_slice] + + return isht(alm) + + +def make_isotropic_noise_generator( + isht: Callable[[torch.Tensor], torch.Tensor], + lmax: int, + mmax: int, +) -> NoiseGenerator: + """Create an isotropic noise generator from an inverse SHT. + + Args: + isht: Inverse spherical harmonic transform callable. + lmax: Maximum degree of spherical harmonics. + mmax: Maximum order of spherical harmonics. + """ + + def _isotropic(x: torch.Tensor, embed_dim_noise: int) -> torch.Tensor: + return isotropic_noise( + (x.shape[0], embed_dim_noise), + lmax, + mmax, + isht, + device=x.device, + ) + + return _isotropic + + +def make_noise_generator( + noise_type: Literal["isotropic", "gaussian"], + isht: Callable[[torch.Tensor], torch.Tensor] | None = None, + lmax: int = 0, + mmax: int = 0, +) -> NoiseGenerator: + """Create a noise generator from a noise type string. + + Args: + noise_type: Type of noise to generate. + isht: Inverse SHT callable, required for isotropic noise. + lmax: Maximum degree, required for isotropic noise. + mmax: Maximum order, required for isotropic noise. + """ + if noise_type == "gaussian": + return gaussian_noise + elif noise_type == "isotropic": + if isht is None: + raise ValueError("isht is required for isotropic noise") + return make_isotropic_noise_generator(isht, lmax, mmax) + else: + raise ValueError(f"Invalid noise type: {noise_type}") + + class NoiseConditionedModule(nn.Module): """Wraps a context-based module with noise conditioning. 
diff --git a/fme/ace/registry/stochastic_sfno.py b/fme/ace/registry/stochastic_sfno.py index 473cd93e0..5d07d0d91 100644 --- a/fme/ace/registry/stochastic_sfno.py +++ b/fme/ace/registry/stochastic_sfno.py @@ -1,18 +1,12 @@ import dataclasses -import math -from collections.abc import Callable from typing import Literal -import torch - from fme.ace.registry.noise_conditioned import ( NoiseConditionedModule, - NoiseGenerator, - gaussian_noise, + make_noise_generator, ) from fme.ace.registry.registry import ModuleConfig, ModuleSelector from fme.core.dataset_info import DatasetInfo -from fme.core.distributed.distributed import Distributed from fme.core.models.conditional_sfno.sfnonet import ( ContextConfig, SFNONetConfig, @@ -23,65 +17,6 @@ ) -def isotropic_noise( - leading_shape: tuple[int, ...], - lmax: int, # length of the ℓ axis expected by isht (global) - mmax: int, # length of the m axis expected by isht (global) - isht: Callable[[torch.Tensor], torch.Tensor], - device: torch.device, -) -> torch.Tensor: - # --- draw independent N(0,1) parts -------------------------------------- - coeff_shape = (*leading_shape, lmax, mmax) - real = torch.randn(coeff_shape, dtype=torch.float32, device=device) - imag = torch.randn(coeff_shape, dtype=torch.float32, device=device) - imag[..., :, 0] = 0.0 # m = 0 ⇒ purely real - - # m > 0: make Re and Im each N(0,½) → |a_{ℓ m}|² has variance 1 - sqrt2 = math.sqrt(2.0) - real[..., :, 1:] /= sqrt2 - imag[..., :, 1:] /= sqrt2 - - # --- global scale that makes Var[T(θ,φ)] = 1 --------------------------- - scale = math.sqrt(4.0 * math.pi) / lmax # (Unsöld theorem ⇒ L = lmax) - alm = (real + 1j * imag) * scale - - # --- for distributed iSHT, slice to local spectral extent -------------- - l_slice, m_slice = Distributed.get_instance().get_local_slices((lmax, mmax)) - alm = alm[..., l_slice, m_slice] - - return isht(alm) - - -def _make_sfno_noise_generator( - noise_type: Literal["isotropic", "gaussian"], - conditional_model: ConditionalSFNO, 
-) -> NoiseGenerator: - """Create a noise generator for an SFNO model. - - For gaussian noise, returns the default generator. For isotropic noise, - returns a generator that uses the SFNO's inverse spherical harmonic - transform. - """ - if noise_type == "gaussian": - return gaussian_noise - elif noise_type == "isotropic": - - def _isotropic(x: torch.Tensor, embed_dim_noise: int) -> torch.Tensor: - lmax = conditional_model.itrans_up.lmax - mmax = conditional_model.itrans_up.mmax - return isotropic_noise( - (x.shape[0], embed_dim_noise), - lmax, - mmax, - conditional_model.itrans_up, - device=x.device, - ) - - return _isotropic - else: - raise ValueError(f"Invalid noise type: {noise_type}") - - def NoiseConditionedSFNO( conditional_model: ConditionalSFNO, img_shape: tuple[int, int], @@ -97,7 +32,12 @@ def NoiseConditionedSFNO( embed_dim_noise=embed_dim_noise, embed_dim_pos=embed_dim_pos, embed_dim_labels=embed_dim_labels, - noise_generator=_make_sfno_noise_generator(noise_type, conditional_model), + noise_generator=make_noise_generator( + noise_type, + isht=conditional_model.itrans_up, + lmax=conditional_model.itrans_up.lmax, + mmax=conditional_model.itrans_up.mmax, + ), ) diff --git a/fme/ace/registry/test_local_net.py b/fme/ace/registry/test_local_net.py index c3cf78191..3d07d8e62 100644 --- a/fme/ace/registry/test_local_net.py +++ b/fme/ace/registry/test_local_net.py @@ -155,6 +155,21 @@ def test_local_net_noise_produces_stochastic_output(): assert not torch.allclose(out1, out2) +def test_local_net_isotropic_noise(): + n_in, n_out = 3, 2 + dataset_info = _get_dataset_info() + builder = LocalNetBuilder( + embed_dim=16, + noise_embed_dim=8, + noise_type="isotropic", + block_types=["disco", "disco"], + ) + module = builder.build(n_in, n_out, dataset_info).to(fme.get_device()) + x = torch.randn(2, n_in, *IMG_SHAPE, device=fme.get_device()) + out = module(x) + assert out.shape == (2, n_out, *IMG_SHAPE) + + def test_local_net_with_context_pos_embed(): n_in, n_out = 3, 
2 dataset_info = _get_dataset_info() diff --git a/fme/ace/registry/test_stochastic_sfno.py b/fme/ace/registry/test_stochastic_sfno.py index 453b1c761..18a235f8b 100644 --- a/fme/ace/registry/test_stochastic_sfno.py +++ b/fme/ace/registry/test_stochastic_sfno.py @@ -4,7 +4,8 @@ import torch from torch_harmonics import InverseRealSHT -from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNO, isotropic_noise +from fme.ace.registry.noise_conditioned import isotropic_noise +from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNO from fme.core.device import get_device from fme.core.models.conditional_sfno.layers import Context From a525a64554535e0db771296cc68dcb98a77b9429 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 24 Mar 2026 18:58:32 +0000 Subject: [PATCH 15/31] Consolidate noise conditioning into NoiseConditionedModel in stochastic_sfno Address review feedback: - Rename NoiseConditionedSFNO to NoiseConditionedModel with an optional inverse_sht parameter (None = gaussian, provided = isotropic noise). Keep NoiseConditionedSFNO as a backward-compatible alias. - Delete noise_conditioned.py; local_net.py imports from stochastic_sfno. - Remove identity grid_mapping dict in LocalNet builder. - Use dist.get_isht() for creating the inverse SHT in LocalNet builder. - Delete make_noise_generator helper to avoid type-unsafe ValueError. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/ace/registry/local_net.py | 30 ++-- fme/ace/registry/noise_conditioned.py | 186 ----------------------- fme/ace/registry/stochastic_sfno.py | 184 ++++++++++++++++++---- fme/ace/registry/test_stochastic_sfno.py | 6 +- 4 files changed, 168 insertions(+), 238 deletions(-) delete mode 100644 fme/ace/registry/noise_conditioned.py diff --git a/fme/ace/registry/local_net.py b/fme/ace/registry/local_net.py index 742310b49..30bb39332 100644 --- a/fme/ace/registry/local_net.py +++ b/fme/ace/registry/local_net.py @@ -4,12 +4,10 @@ import torch from torch import nn -from fme.ace.registry.noise_conditioned import ( - NoiseConditionedModule, - make_noise_generator, -) from fme.ace.registry.registry import ModuleConfig, ModuleSelector +from fme.ace.registry.stochastic_sfno import NoiseConditionedModel from fme.core.dataset_info import DatasetInfo +from fme.core.distributed import Distributed from fme.core.models.conditional_sfno.ankur import ( AnkurLocalNetConfig, get_lat_lon_ankur_localnet, @@ -199,23 +197,21 @@ def build( ) img_shape = dataset_info.img_shape if self.noise_type == "isotropic": - from torch_harmonics import InverseRealSHT - - grid_mapping = { - "legendre-gauss": "legendre-gauss", - "equiangular": "equiangular", - } - isht = InverseRealSHT(*img_shape, grid=grid_mapping[self.data_grid]) - noise_generator = make_noise_generator( - self.noise_type, isht=isht, lmax=isht.lmax, mmax=isht.mmax - ) + dist = Distributed.get_instance() + inverse_sht = dist.get_isht(*img_shape, grid=self.data_grid) + lmax = inverse_sht.lmax + mmax = inverse_sht.mmax else: - noise_generator = make_noise_generator(self.noise_type) - return NoiseConditionedModule( + inverse_sht = None + lmax = 0 + mmax = 0 + return NoiseConditionedModel( net, img_shape=img_shape, embed_dim_noise=self.noise_embed_dim, embed_dim_pos=self.context_pos_embed_dim, embed_dim_labels=embed_dim_labels, - noise_generator=noise_generator, + inverse_sht=inverse_sht, + 
lmax=lmax, + mmax=mmax, ) diff --git a/fme/ace/registry/noise_conditioned.py b/fme/ace/registry/noise_conditioned.py deleted file mode 100644 index 7052aaa47..000000000 --- a/fme/ace/registry/noise_conditioned.py +++ /dev/null @@ -1,186 +0,0 @@ -"""Noise-conditioned wrapper for modules with a Context-based forward signature.""" - -import math -from collections.abc import Callable -from typing import Literal - -import torch -from torch import nn - -from fme.core.distributed import Distributed -from fme.core.models.conditional_sfno.layers import Context - -NoiseGenerator = Callable[[torch.Tensor, int], torch.Tensor] -"""Callable (x, embed_dim_noise) -> noise_tensor. - -Takes the input tensor (for shape/device/dtype) and noise embedding dimension, -returns noise of shape [x.shape[0], embed_dim_noise, *x.shape[-2:]]. -""" - - -def gaussian_noise(x: torch.Tensor, embed_dim_noise: int) -> torch.Tensor: - return torch.randn( - [x.shape[0], embed_dim_noise, *x.shape[-2:]], - device=x.device, - dtype=x.dtype, - ) - - -def isotropic_noise( - leading_shape: tuple[int, ...], - lmax: int, - mmax: int, - isht: Callable[[torch.Tensor], torch.Tensor], - device: torch.device, -) -> torch.Tensor: - """Generate isotropic noise on the sphere via inverse SHT. - - Draws random spherical harmonic coefficients scaled so that Var[f(θ,φ)] = 1. 
- """ - coeff_shape = (*leading_shape, lmax, mmax) - real = torch.randn(coeff_shape, dtype=torch.float32, device=device) - imag = torch.randn(coeff_shape, dtype=torch.float32, device=device) - imag[..., :, 0] = 0.0 # m = 0 => purely real - - sqrt2 = math.sqrt(2.0) - real[..., :, 1:] /= sqrt2 - imag[..., :, 1:] /= sqrt2 - - scale = math.sqrt(4.0 * math.pi) / lmax - alm = (real + 1j * imag) * scale - - l_slice, m_slice = Distributed.get_instance().get_local_slices((lmax, mmax)) - alm = alm[..., l_slice, m_slice] - - return isht(alm) - - -def make_isotropic_noise_generator( - isht: Callable[[torch.Tensor], torch.Tensor], - lmax: int, - mmax: int, -) -> NoiseGenerator: - """Create an isotropic noise generator from an inverse SHT. - - Args: - isht: Inverse spherical harmonic transform callable. - lmax: Maximum degree of spherical harmonics. - mmax: Maximum order of spherical harmonics. - """ - - def _isotropic(x: torch.Tensor, embed_dim_noise: int) -> torch.Tensor: - return isotropic_noise( - (x.shape[0], embed_dim_noise), - lmax, - mmax, - isht, - device=x.device, - ) - - return _isotropic - - -def make_noise_generator( - noise_type: Literal["isotropic", "gaussian"], - isht: Callable[[torch.Tensor], torch.Tensor] | None = None, - lmax: int = 0, - mmax: int = 0, -) -> NoiseGenerator: - """Create a noise generator from a noise type string. - - Args: - noise_type: Type of noise to generate. - isht: Inverse SHT callable, required for isotropic noise. - lmax: Maximum degree, required for isotropic noise. - mmax: Maximum order, required for isotropic noise. - """ - if noise_type == "gaussian": - return gaussian_noise - elif noise_type == "isotropic": - if isht is None: - raise ValueError("isht is required for isotropic noise") - return make_isotropic_noise_generator(isht, lmax, mmax) - else: - raise ValueError(f"Invalid noise type: {noise_type}") - - -class NoiseConditionedModule(nn.Module): - """Wraps a context-based module with noise conditioning. 
- - Generates noise and optional positional embeddings (with label-position - interaction), then calls the wrapped module with a fully populated Context. - - Args: - module: An nn.Module with forward signature (x, context: Context). - img_shape: Global spatial dimensions (lat, lon) of the input data. - embed_dim_noise: Dimension of noise channels. - embed_dim_pos: Dimension of learned positional embedding. 0 disables. - embed_dim_labels: Dimension of label embeddings. 0 disables. - noise_generator: Callable that produces noise given the input tensor - and noise embedding dimension. Defaults to gaussian noise. - """ - - def __init__( - self, - module: nn.Module, - img_shape: tuple[int, int], - embed_dim_noise: int = 256, - embed_dim_pos: int = 0, - embed_dim_labels: int = 0, - noise_generator: NoiseGenerator = gaussian_noise, - ): - super().__init__() - self.module = module - self.embed_dim_noise = embed_dim_noise - self.img_shape = img_shape - self._noise_generator = noise_generator - self.label_pos_embed: nn.Parameter | None = None - if embed_dim_pos != 0: - self.pos_embed: nn.Parameter | None = nn.Parameter( - torch.zeros( - 1, embed_dim_pos, img_shape[0], img_shape[1], requires_grad=True - ) - ) - nn.init.trunc_normal_(self.pos_embed, std=0.02) - if embed_dim_labels > 0: - self.label_pos_embed = nn.Parameter( - torch.zeros( - embed_dim_labels, - embed_dim_pos, - img_shape[0], - img_shape[1], - requires_grad=True, - ) - ) - nn.init.trunc_normal_(self.label_pos_embed, std=0.02) - else: - self.pos_embed = None - - def forward( - self, x: torch.Tensor, labels: torch.Tensor | None = None - ) -> torch.Tensor: - x = x.reshape(-1, *x.shape[-3:]) - noise = self._noise_generator(x, self.embed_dim_noise) - - h_slice, w_slice = Distributed.get_instance().get_local_slices(self.img_shape) - - embedding_pos: torch.Tensor | None = None - if self.pos_embed is not None: - pos_local = self.pos_embed[..., h_slice, w_slice] - embedding_pos = pos_local.repeat(x.shape[0], 1, 1, 1) - 
if self.label_pos_embed is not None and labels is not None: - label_local = self.label_pos_embed[..., h_slice, w_slice] - label_embedding_pos = torch.einsum( - "bl, lpxy -> bpxy", labels, label_local - ) - embedding_pos = embedding_pos + label_embedding_pos - - return self.module( - x, - Context( - embedding_scalar=None, - embedding_pos=embedding_pos, - labels=labels, - noise=noise, - ), - ) diff --git a/fme/ace/registry/stochastic_sfno.py b/fme/ace/registry/stochastic_sfno.py index 5d07d0d91..3dab556d0 100644 --- a/fme/ace/registry/stochastic_sfno.py +++ b/fme/ace/registry/stochastic_sfno.py @@ -1,44 +1,156 @@ import dataclasses +import math +from collections.abc import Callable from typing import Literal -from fme.ace.registry.noise_conditioned import ( - NoiseConditionedModule, - make_noise_generator, -) +import torch + from fme.ace.registry.registry import ModuleConfig, ModuleSelector from fme.core.dataset_info import DatasetInfo +from fme.core.distributed.distributed import Distributed from fme.core.models.conditional_sfno.sfnonet import ( + Context, ContextConfig, SFNONetConfig, get_lat_lon_sfnonet, ) -from fme.core.models.conditional_sfno.sfnonet import ( - SphericalFourierNeuralOperatorNet as ConditionalSFNO, -) -def NoiseConditionedSFNO( - conditional_model: ConditionalSFNO, - img_shape: tuple[int, int], - noise_type: Literal["isotropic", "gaussian"] = "gaussian", - embed_dim_noise: int = 256, - embed_dim_pos: int = 0, - embed_dim_labels: int = 0, -) -> NoiseConditionedModule: - """Create a noise-conditioned SFNO with support for isotropic noise.""" - return NoiseConditionedModule( - module=conditional_model, - img_shape=img_shape, - embed_dim_noise=embed_dim_noise, - embed_dim_pos=embed_dim_pos, - embed_dim_labels=embed_dim_labels, - noise_generator=make_noise_generator( - noise_type, - isht=conditional_model.itrans_up, - lmax=conditional_model.itrans_up.lmax, - mmax=conditional_model.itrans_up.mmax, - ), - ) +def isotropic_noise( + leading_shape: 
tuple[int, ...], + lmax: int, # length of the ℓ axis expected by isht (global) + mmax: int, # length of the m axis expected by isht (global) + isht: Callable[[torch.Tensor], torch.Tensor], + device: torch.device, +) -> torch.Tensor: + # --- draw independent N(0,1) parts -------------------------------------- + coeff_shape = (*leading_shape, lmax, mmax) + real = torch.randn(coeff_shape, dtype=torch.float32, device=device) + imag = torch.randn(coeff_shape, dtype=torch.float32, device=device) + imag[..., :, 0] = 0.0 # m = 0 ⇒ purely real + + # m > 0: make Re and Im each N(0,½) → |a_{ℓ m}|² has variance 1 + sqrt2 = math.sqrt(2.0) + real[..., :, 1:] /= sqrt2 + imag[..., :, 1:] /= sqrt2 + + # --- global scale that makes Var[T(θ,φ)] = 1 --------------------------- + scale = math.sqrt(4.0 * math.pi) / lmax # (Unsöld theorem ⇒ L = lmax) + alm = (real + 1j * imag) * scale + + # --- for distributed iSHT, slice to local spectral extent -------------- + l_slice, m_slice = Distributed.get_instance().get_local_slices((lmax, mmax)) + alm = alm[..., l_slice, m_slice] + + return isht(alm) + + +class NoiseConditionedModel(torch.nn.Module): + """Wraps a context-based module with noise and optional label conditioning. + + Generates noise (gaussian by default, or isotropic via an inverse SHT) + and optional positional embeddings (with label-position interaction), + then calls the wrapped module with a fully populated Context. + + Args: + module: An nn.Module with forward signature (x, context: Context). + img_shape: Global spatial dimensions (lat, lon) of the input data. + embed_dim_noise: Dimension of noise channels. + embed_dim_pos: Dimension of learned positional embedding. 0 disables. + embed_dim_labels: Dimension of label embeddings. 0 disables. + inverse_sht: Optional inverse spherical harmonic transform callable. + If provided, isotropic noise is generated via SHT; otherwise + gaussian noise is used. 
+ """ + + def __init__( + self, + module: torch.nn.Module, + img_shape: tuple[int, int], + embed_dim_noise: int = 256, + embed_dim_pos: int = 0, + embed_dim_labels: int = 0, + inverse_sht: Callable[[torch.Tensor], torch.Tensor] | None = None, + lmax: int = 0, + mmax: int = 0, + ): + super().__init__() + self.module = module + self.embed_dim = embed_dim_noise + self.img_shape = img_shape + self._inverse_sht = inverse_sht + self._lmax = lmax + self._mmax = mmax + self.label_pos_embed: torch.nn.Parameter | None = None + # register pos embed if pos_embed_dim != 0 + if embed_dim_pos != 0: + self.pos_embed = torch.nn.Parameter( + torch.zeros( + 1, embed_dim_pos, img_shape[0], img_shape[1], requires_grad=True + ) + ) + # initialize pos embed with std=0.02 + torch.nn.init.trunc_normal_(self.pos_embed, std=0.02) + if embed_dim_labels > 0: + self.label_pos_embed = torch.nn.Parameter( + torch.zeros( + embed_dim_labels, + embed_dim_pos, + img_shape[0], + img_shape[1], + requires_grad=True, + ) + ) + torch.nn.init.trunc_normal_(self.label_pos_embed, std=0.02) + else: + self.pos_embed = None + + def forward( + self, x: torch.Tensor, labels: torch.Tensor | None = None + ) -> torch.Tensor: + x = x.reshape(-1, *x.shape[-3:]) + if self._inverse_sht is not None: + noise = isotropic_noise( + (x.shape[0], self.embed_dim), + self._lmax, + self._mmax, + self._inverse_sht, + device=x.device, + ) + else: + noise = torch.randn( + [x.shape[0], self.embed_dim, *x.shape[-2:]], + device=x.device, + dtype=x.dtype, + ) + + h_slice, w_slice = Distributed.get_instance().get_local_slices(self.img_shape) + + if self.pos_embed is not None: + pos_local = self.pos_embed[..., h_slice, w_slice] + embedding_pos = pos_local.repeat(noise.shape[0], 1, 1, 1) + if self.label_pos_embed is not None and labels is not None: + label_local = self.label_pos_embed[..., h_slice, w_slice] + label_embedding_pos = torch.einsum( + "bl, lpxy -> bpxy", labels, label_local + ) + embedding_pos = embedding_pos + 
label_embedding_pos + else: + embedding_pos = None + + return self.module( + x, + Context( + embedding_scalar=None, + embedding_pos=embedding_pos, + labels=labels, + noise=noise, + ), + ) + + +# Backward-compatible alias +NoiseConditionedSFNO = NoiseConditionedModel # this is based on the call signature of SphericalFourierNeuralOperatorNet at @@ -199,11 +311,21 @@ def build( embed_dim_labels=len(dataset_info.all_labels), ), ) - return NoiseConditionedSFNO( + if self.noise_type == "isotropic": + inverse_sht = sfno_net.itrans_up + lmax = inverse_sht.lmax + mmax = inverse_sht.mmax + else: + inverse_sht = None + lmax = 0 + mmax = 0 + return NoiseConditionedModel( sfno_net, - noise_type=self.noise_type, embed_dim_noise=self.noise_embed_dim, embed_dim_pos=self.context_pos_embed_dim, embed_dim_labels=len(dataset_info.all_labels), img_shape=dataset_info.img_shape, + inverse_sht=inverse_sht, + lmax=lmax, + mmax=mmax, ) diff --git a/fme/ace/registry/test_stochastic_sfno.py b/fme/ace/registry/test_stochastic_sfno.py index 18a235f8b..e8d23743a 100644 --- a/fme/ace/registry/test_stochastic_sfno.py +++ b/fme/ace/registry/test_stochastic_sfno.py @@ -4,8 +4,7 @@ import torch from torch_harmonics import InverseRealSHT -from fme.ace.registry.noise_conditioned import isotropic_noise -from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNO +from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNO, isotropic_noise from fme.core.device import get_device from fme.core.models.conditional_sfno.layers import Context @@ -37,9 +36,8 @@ def test_noise_conditioned_sfno_conditioning(): n_pos = 8 n_labels = 4 model = NoiseConditionedSFNO( - conditional_model=mock_sfno, + module=mock_sfno, img_shape=img_shape, - noise_type="gaussian", # needed so we don't need a SHT in this test embed_dim_noise=n_noise, embed_dim_pos=n_pos, embed_dim_labels=n_labels, From 7ff1be15f3bfa6538594a4f2afe286b2bedd5706 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 24 Mar 2026 19:02:23 +0000 
Subject: [PATCH 16/31] Remove data_grid config from local network builders, hard-code to legendre-gauss Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/ace/registry/local_net.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/fme/ace/registry/local_net.py b/fme/ace/registry/local_net.py index 30bb39332..59481ad9d 100644 --- a/fme/ace/registry/local_net.py +++ b/fme/ace/registry/local_net.py @@ -59,8 +59,6 @@ class AnkurLocalNetBuilder(ModuleConfig): pos_embed: Whether to add a learned positional embedding after the first layer. activation_function: Activation function name ('relu', 'gelu', 'silu'). - data_grid: Grid type for spherical harmonic transforms used by - DISCO convolutions. """ embed_dim: int = 256 @@ -68,7 +66,6 @@ class AnkurLocalNetBuilder(ModuleConfig): disco_kernel_size: int = 3 pos_embed: bool = False activation_function: str = "gelu" - data_grid: Literal["legendre-gauss", "equiangular"] = "equiangular" def build( self, @@ -94,7 +91,7 @@ def build( in_chans=n_in_channels, out_chans=n_out_channels, img_shape=dataset_info.img_shape, - data_grid=self.data_grid, + data_grid="legendre-gauss", context_config=context_config, ) return _ContextWrappedModule(net) @@ -135,8 +132,6 @@ class LocalNetBuilder(ModuleConfig): lora_rank: Rank of LoRA adaptations. 0 disables LoRA. lora_alpha: Strength of LoRA adaptations. Defaults to lora_rank if None. - data_grid: Grid type for spherical harmonic transforms used by - DISCO convolutions and isotropic noise generation. 
""" embed_dim: int = 256 @@ -157,7 +152,6 @@ class LocalNetBuilder(ModuleConfig): affine_norms: bool = False lora_rank: int = 0 lora_alpha: float | None = None - data_grid: Literal["legendre-gauss", "equiangular"] = "equiangular" def build( self, @@ -192,13 +186,13 @@ def build( in_chans=n_in_channels, out_chans=n_out_channels, img_shape=dataset_info.img_shape, - data_grid=self.data_grid, + data_grid="legendre-gauss", context_config=context_config, ) img_shape = dataset_info.img_shape if self.noise_type == "isotropic": dist = Distributed.get_instance() - inverse_sht = dist.get_isht(*img_shape, grid=self.data_grid) + inverse_sht = dist.get_isht(*img_shape, grid="legendre-gauss") lmax = inverse_sht.lmax mmax = inverse_sht.mmax else: From b806b9a5574fbc130d363cf69993617df5792f5d Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 24 Mar 2026 19:07:37 +0000 Subject: [PATCH 17/31] Add __init__ docstring to NoiseConditionedModel Document that passing inverse_sht switches noise from gaussian to isotropic, and that lmax/mmax are required when inverse_sht is provided. Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/ace/registry/stochastic_sfno.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/fme/ace/registry/stochastic_sfno.py b/fme/ace/registry/stochastic_sfno.py index 3dab556d0..479d844fb 100644 --- a/fme/ace/registry/stochastic_sfno.py +++ b/fme/ace/registry/stochastic_sfno.py @@ -74,6 +74,22 @@ def __init__( lmax: int = 0, mmax: int = 0, ): + """ + Args: + module: An nn.Module with forward signature (x, context: Context). + img_shape: Global spatial dimensions (lat, lon) of the input data. + embed_dim_noise: Dimension of noise channels. + embed_dim_pos: Dimension of learned positional embedding. + 0 disables. + embed_dim_labels: Dimension of label embeddings. 0 disables. + inverse_sht: Optional inverse spherical harmonic transform. When + provided, noise conditioning uses isotropic noise generated + via the inverse SHT. 
When None, gaussian noise is used. + lmax: Maximum spherical harmonic degree for isotropic noise. + Required when inverse_sht is provided. + mmax: Maximum spherical harmonic order for isotropic noise. + Required when inverse_sht is provided. + """ super().__init__() self.module = module self.embed_dim = embed_dim_noise From 47c420c9eac44c7fa400418bbe86e0fe7efa1bc6 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 24 Mar 2026 18:16:09 +0000 Subject: [PATCH 18/31] Add regression tests for AnkurLocalNet and LocalNet Replace the regression coverage lost when HybridNet was removed. Each test uses validate_tensor to compare output against a saved baseline, following the same pattern as the SFNO regression tests. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../models/conditional_sfno/test_ankur.py | 40 ++++++++++++++++ .../models/conditional_sfno/test_localnet.py | 43 ++++++++++++++++++ .../testdata/test_ankur_localnet_output.pt | Bin 0 -> 9582 bytes .../testdata/test_localnet_output.pt | Bin 0 -> 9412 bytes 4 files changed, 83 insertions(+) create mode 100644 fme/core/models/conditional_sfno/testdata/test_ankur_localnet_output.pt create mode 100644 fme/core/models/conditional_sfno/testdata/test_localnet_output.pt diff --git a/fme/core/models/conditional_sfno/test_ankur.py b/fme/core/models/conditional_sfno/test_ankur.py index 58b042953..7f99fbd08 100644 --- a/fme/core/models/conditional_sfno/test_ankur.py +++ b/fme/core/models/conditional_sfno/test_ankur.py @@ -1,11 +1,16 @@ +import os + import pytest import torch from fme.core.device import get_device +from fme.core.testing.regression import validate_tensor from .ankur import AnkurLocalNetConfig, get_lat_lon_ankur_localnet from .layers import Context, ContextConfig +DIR = os.path.dirname(os.path.abspath(__file__)) + @pytest.mark.parametrize("use_disco_encoder", [True, False]) @pytest.mark.parametrize("pos_embed", [True, False]) @@ -157,3 +162,38 @@ def test_ankur_localnet_disco_kernel_size(): ) output = 
model(x, context) assert output.shape == (n_samples, output_channels, *img_shape) + + +def setup_ankur_localnet(): + input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 4 + device = get_device() + params = AnkurLocalNetConfig(embed_dim=16, use_disco_encoder=True) + model = get_lat_lon_ankur_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ).to(device) + # must initialize on CPU to get the same results on GPU + x = torch.randn(n_samples, input_channels, *img_shape).to(device) + context = Context( + embedding_scalar=torch.randn(n_samples, 0).to(device), + labels=torch.randn(n_samples, 0).to(device), + noise=torch.randn(n_samples, 0, *img_shape).to(device), + embedding_pos=torch.randn(n_samples, 0, *img_shape).to(device), + ) + return model, x, context + + +def test_ankur_localnet_output_is_unchanged(): + torch.manual_seed(0) + model, x, context = setup_ankur_localnet() + with torch.no_grad(): + output = model(x, context) + validate_tensor( + output, + os.path.join(DIR, "testdata/test_ankur_localnet_output.pt"), + ) diff --git a/fme/core/models/conditional_sfno/test_localnet.py b/fme/core/models/conditional_sfno/test_localnet.py index 886444d53..307b6cad2 100644 --- a/fme/core/models/conditional_sfno/test_localnet.py +++ b/fme/core/models/conditional_sfno/test_localnet.py @@ -1,12 +1,17 @@ +import os + import pytest import torch from torch import nn from fme.core.device import get_device +from fme.core.testing.regression import validate_tensor from .layers import Context, ContextConfig from .localnet import LocalNetConfig, get_lat_lon_localnet +DIR = os.path.dirname(os.path.abspath(__file__)) + @pytest.mark.parametrize( "conditional_embed_dim_scalar, conditional_embed_dim_labels, " @@ -255,3 +260,41 @@ def test_backward_pass(): loss.backward() for name, param in model.named_parameters(): assert param.grad is not None, f"No gradient for {name}" + + +def setup_localnet(): + 
input_channels = 2 + output_channels = 3 + img_shape = (9, 18) + n_samples = 4 + device = get_device() + params = LocalNetConfig( + embed_dim=16, + block_types=["disco", "disco"], + ) + model = get_lat_lon_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ).to(device) + # must initialize on CPU to get the same results on GPU + x = torch.randn(n_samples, input_channels, *img_shape).to(device) + context = Context( + embedding_scalar=torch.randn(n_samples, 0).to(device), + labels=torch.randn(n_samples, 0).to(device), + noise=torch.randn(n_samples, 0, *img_shape).to(device), + embedding_pos=torch.randn(n_samples, 0, *img_shape).to(device), + ) + return model, x, context + + +def test_localnet_output_is_unchanged(): + torch.manual_seed(0) + model, x, context = setup_localnet() + with torch.no_grad(): + output = model(x, context) + validate_tensor( + output, + os.path.join(DIR, "testdata/test_localnet_output.pt"), + ) diff --git a/fme/core/models/conditional_sfno/testdata/test_ankur_localnet_output.pt b/fme/core/models/conditional_sfno/testdata/test_ankur_localnet_output.pt new file mode 100644 index 0000000000000000000000000000000000000000..6f2db9106ebd0314dc488a8ac5df7eedf7b2c5dd GIT binary patch literal 9582 zcmbt)30RI>+xF8u3KdE_8Vu1SGW4u-kx)n^l_8|LNVCjiB_gFnhA5p0eNAIE*K`&`$$?(1Ib8csuf9$p4RP>}II znj&Kr!^6hi!`#B@z#%ttM`tSwM<*K(bLT@IE{8nStt~t()LafY>go{2uK)LThTF=+ z+0Dvc&HRvugQL5!xtopUAqPima}OIQcV{v5-NuO7rM_))^+{2Y$$>2Sg z{(XIMeJ&4Ifs0mF&W?^YRvr${PVT~6+^lWfY^>KhSa}G$3NA2m74mQuHgXkFGt?Il zG=j$PBM$PfZul>?wveIocUn`kb#^;w;bHD+i=3aRmPmZ744tuHMekduyZmhqx5g*{V$7ibnx(SwD~224{ev`=gj}B|7#iK8Onbr!!N5=U!k$ zOo2a^9J~Rcq$em7szKWEE=Xt(;$v_Ru823HZ*c`$R+m7!un+~)its+53I^IOXq4>2 zj)pe~wPwTUS{Mv0{4wkhjwQ^iU|NM^^2aiSTQ(uBy9XlGe6+Jng!;xz;HJ}C{GL#X zxsSVHswzT9+9x0zS^#dN7xlB&gpKlr zSpANh_8SONt+TxJV8a7k^Lq{l#|hM|5TN&1)vyd}f~CnK34y=ghrwNYXs$^?lF?!ap|!ilDpb)<$f*t10pd#5d(waYlymd0}WR*uhUSa zG2y3~$4J+M|yqw-DQKjL>ydd+Oh12gr 
z6|$q@KQj||c6KAL@EzHsSB)$BF?gr22h&)D&YPYOnY+=*(The`?o9gIJ|P-Zqe*R@ zny|*C31Sw>*!Fr7qCy@bHDUn#eFMn7`3aW-rqBk#4A@OXgKn8b7y7GFt>dC}3vU`O zu0D$s<2GRQ06L&UsPY?yK$_To3O=57KU`mdtx1q+4T!zi?8!=U{M+IrIA z<&=vs#S**|D8@_CJg7Pp)GV);D)p2fPHd!V0`dpVAOisA9G> zIg;ch7YW<&88i5#GfT3bi=A<`kyWici8%OGlG^>DWXkz`<{~~8GtTufS>iqywlBQ# z*rt<$=6tg z$4U|~yW&CwJ}V%N3*dE#C*vkHXzjsN&hxksD7AKO%EjOloTSQ9Y-EN z-$gd_H!)v--bPwSxiOY+f`-eUi16Huvs;CszrKoS?-i%VrijtCHiGorZUz<(PbF{0LAC88gcQW7|MfDQoJfF>b1{x-N>doMpsv6f zI_VG45n724yV?=AgPXbt%%W$K22uEW98oVm;R{r#V{jghioZb<_akgf3diHoLM)a~ zM{MH)nqr)fMa!m9sjVisawreQuO1;>US7h>#MaWPU#!b3ftP_Xs>Dbr?#G!SK3a^j{dpj=C7= z81vC-<0Z(p`-o50-H6G4kBRzb?D_r@rY*N{h4&HXpDexL17&cS)`m_MKH9!YjP8i#rbU&rXWgw5b z(7t*I;%0TQeVquCt`Ha}T|&#YD-b+)72E6mAeG^R{fC2aZ&o_G6dSNUq8S^>Q&g>d ziy-3>RGytcs?8`I2g=YbA%j5@$IAV=fbff{Lpm`M4%a0y`Xh=MXZYbt-9=n$2!^8H zRj53Rz&>X&80$yk;e|x_x-#f2xyjW0ngpF+9ERymZCKykiLF<2AakV=3o>~rzra&m zXI#Y6+vibeGKS&fchS20H2YS20Y^7au|>rX!lr4+2=9kvRU6Dk)gh#t3bJMf-CwOq z!;|l$V_zsf40&Om;1pO#(^v+b^_SqFQ-Qh;0ovE= zjePwU_=X3N129tr6r~nS7Ukicf2%uh?DCQAr%vdsR@VCJ5&zw zISB|!Ho$3J1BBFR(kTPdRAOHSj#xZ{rv5`Pt>0s>(Q+ExEkYf)=c4NC0W4`B#8~HA z*dAxl#J*VQGnv?7orLV4$*7tg57nQk*cxAhnwiQpy7?3A*cl{8vj}=aQ>ocfS*lon z6Wr>vF{xjOemClbw|Xb?esMPaEGQFGBHL*KNqt2Gflr+uNE&xP|l9}@4+%P}X>6K6UMk*ec_FLDR* zpx_L+9tGliY7p{GZ4fto8xkZ>61fx?)=wb`_D+*-Vr=7phy^F0c5f@jg`Kr@TV@bX z&2bV`EQPK+9pr7CI6LR_SY~3`06aP8>Ph}nC>V@KhT#G*WfU`Qnp0WE@rMvIq=NoB zb>^!18@1l#r?5&AA7*wx_a)JL1Cg}%7BWXnac2F}j0S3kHzD`QllKAS$k$}1?81Yj zVeW47syvt#xg`GtT5W?^$HsY%Uo8JdRWOzRYrq;!K}UJ!H&%3u=92$yn!Xw&0meR$5{; zOaJa{5@nc9(##?i3~KGpzPyLC0H+=4_Q5FvO_Z}Q#f>p^-c9O zNinY?EHxX}*<~_p(_St19YcGv{#zu}&-jp4lcllKcO8~HEr!+Xb`o04LNx3iGc4;CiORc2Mn)x=`Iik>w6!Rre3k~DyUoJ2x)nHjRG-{C zb(pxgJR(VBw~4d!4`!ThQRa=0Lag$i-dYQE&52u7G~xRF0TGK%B?C?>@D&V0ylys> zNAF^;{#x)yWD=XyBs?_b_=);eU~hVhvpVlkU)T!WCCv~{s{_O6In;8qU{;n1tw#mu zZ?1)8#(OBN>&3dnR%jl5j@bh#;OTr1?fqY&wzdOOPGxu}*$9gqPK;{8IaA}~aX^@j zgb~gjT;2ycoi8{-I-$Wi148H3;+*XZc*L+UYES^Ne4o(mlN|x?R6E 
z+=D=|4n%F-1;>fK2;fVG!AVcfp3sIZ_Z8`BVNt4jx&=mwmtn~94Zf+*rKyJ)wBFLu>GlsTboI=k6Mx=*1yr_%Xbj!)rG~|(!m@`#TJDaIOV>^o)C`jAlHZkhi<|v zItojbE0N+LPp9}XDC0CA-DG$Zay7pAsrUgG3ST1f7U%59`iLLuwODZH8u&Q-en-J4 zFitRNR#*u{-bm5zUt<}~??It-1MUXp zqhQW6>|dP^^8@v;tLo?YIE7H1*?{GA2#Uh*k(l!a(yUq-m5;z9S(47TXVBJxSEx-Y z$NNdOcw*lTB07$=J}&zF@^?7z>q5-xdYCDd!*h87e1_Pt-}?-mkDFoU^A&pRQFt{E zVs6YkOgmc*zr97s5M`mvE)^$MW+14r8V>rskch2=HCHqqwuGSO_b8Z$JcrqxPO#SW zV&(HU2oB4`E}=xs6->e4U;!BQT-3HmmDVZFqhV)r(aG?Gx}7gJdS_z~S1C$vm11|^ z2RNVPqhnLPz)YYTrD{WP;S!|>@4n~wRs~3O=d9^)9qG3xsCJd0Lk_ld(RCTx&5578 z-lRgog%f9ue8LZXUb=ha5iEQjU<+rhw%J{X7Zjl@mv5wJ_X*Rkkyf}~x5rk!+mN7# zF+1;fc!o@&&&1AQ*?>03=F73TDhS=9VIXd2pc9;lq=ZWdsAQu^<{Wwx1gTY8AhrJd z1k>Lh#&`(ypAPVij+h?n#diq;2V>)Yvc zSfP}xTPukBdp&4MwicZ%uo4?;&SQ^J3Jw`R!A$81xcbV_L8pGC9s*LiZNYza2q(=a zQQ0OHdZlYCLN?vNxWyaH9Phx|l?*EG^#S_aub}#li|*nI!&8gfIJ#{T9iA>mcSla6 zYm6o6&aPYN*Hp!{!)_>-lBfP{i|L8JS1>y!Mk7UxaI5qX>}%7YVtWs}WafrLSw-N>`Azf9X2Vr#)-f0UE}21GV+kNs1m8djY!(wgwl2Ipr1Md z=R6_WEzQZf2=^dl;ua~%9U=1ZvM~Ryh;{v9C||>a`x7&u^jr-~=FP?m)o$YH7f3w& zcN4XDF(igh2!cv;F{gJrr1&byaPu*8Y(_k}ZJNf6Ibp<1Z-^tWPHCXaUIWhzKM~JE zz2t%FOOm}MpPcznMygytlcxjEN!Kl6-IoOEQ#oa?3H=BY@>6c z>>aJwGxI*%K=b2!@=@jnwyau<_R!jl`(LDR`Gho%(o=GP@D(?*l%DY*yiXD_g{o-a#2I^X3vm87 zg0s`3w1RBpQLB)N-0OPCTdK+)cB&(Mi|twZR4^<0@JIG&z*8uDC?Rqy69^Q+glr@k z6%}VLS8mVPwEcG|SKh%CM_ZQFeJ*xUSQaaG>NEWC(8hXoWjH_n&f25Bi+#v%GMn8g z2fa3NRz+wTvHS}7-&#i|XHU+Q5Y=Pe+suXO%BN9yrxKg=lVQ2PnOP>pU^(_uQok&P zH9d5>7O%otq8;W7oegtf!WO{Z+z-rz_wt#|Pc~;Bd*;OIGb_pT9nT7m)H$dTG3yH?*}I7-x6D9<%RC%ToB^3c zUW9XfByW9+32yk4x69`fQ$BtYGv9$6j(bM*F0>OJn=G<(tvi{wXf|2GSCJ#N>I)#4b#MM#4 zV_53*9?BLxbkiq(TIG?5okwC}zM}`X8#wTzDE4UNX#@YZUCc6!z46YEg(tI0U{)$d=jYTzabq6hEyvL_vk!umh9su-GQP?yQjufP(3??? z&eyMSmZgg%ov*kPu?xzvuVFye(6PG_ATQOZc$^YVnEC?gl08sQd5ek_!8m4@k4X)$ zQTASu#-|o=a?Ll9>HP)wj^tw^PnlY5NP*jLUMT0DODiu4(i69NskYX0SRH1eYN{L! 
z?vkVF^FH8>OFd@uO{ETrGPFIY6vLdee@A`=B#(2^f~Pf@P|3%WN+GJR*o-HcZBPt| zh11b%Xjpy?Q@2+kC%GSDis^9NQHbludtjK}4y9v7pgy@UALC!ShQkye4yep71vGe)@=d zyFtvb|9}R^dPEn!z~&Qr-HIWc3y2H-5qr#SR=iRf`sj3N$GmMStxX)VO9q zMQ|$BkX}M_f2vTg7(Tj0_%7Vb<{(2~2#vpO#5<0!u`4GI@mg`Hm=}we&CzHpzYPa3 zPENY59g$0uQNGU~9*O}d-w=uAifmW}@Y0};$+SY4m&!T~W96Aw*mTh!B8M(Rj-!uT zJ<(7vih{T2O%$2FKuLQf+5>X1!T&5o1N_ixk%ob7UtuH@2GiyqoZrQ%uk4qn2P|h% zE6-}IWAwst=Oe@_Pod|o55Q#cZsb2&$;roWr9qooa0-_&V}An*k~n$ApZZkYS&3S4 z{wcnQ9xUXXL*vKh(ftb9n6mLUHa&lgHM851yqoin=)^+9|2`+@zLmx<6QcvSE}^rL zhfcoCMSHnZvFPAYUj@3Hn5W&vH zCDAl^cXi^BXewrB6>{QTNh-A7hskCHKz9D?jOb6kcydXSwkTD@LvR6QteHetY*<0- z7M8=9Q^z2JNeDB20>9sC5&a<_y^-D|Z|xL>*OuW#@HqNUhvH$52rU~-#LxYjD4Dhg z0poHs!Epe!_TR9B$AH$TEumIHFL8<%K)9+8Y?TktD9(XMUpmYTGI3-@K7wzXapdd)sV8Al zV}(s+B8W^MWNv%8GILWwB{M#$iFoH*;mh{B*m(0cplOTST4GS(@s9v=dLoBoVuepA&<3X&n)kWJ%jS)4H=ekT`W~!}j_; z=%^>6rtT(Q^6o&)zON)`?+nZ`n~4o_4+#IrG_5sj4`SK)RcM;&p@;+$;cQi8IWVBE zp$k0;FOpGwmsPYUfykZCCnJ3e;4ogeV*hJ1(*LBBygn#}?|TljkL*51?)z|jD8UU7 z(OZvKOa=CgnHFsEy&kOHNd~MNt3Vz+8D%9pOpqGA-8fNv9ZLn3@HRq~wXMjAJXphp zD%S_>E^?l|W?=+IE)0<=$DEmd*&9e}%vE-$VOhpehYM_%lFv+O$4C?&-AtnEJBd+A zG>SO;uJ1-QWMtG6Vg6C3<%6Fjtk9KpGt>>w*BmBgUskcV$|VsxBahY5WelP5SH#tR zoE1Op2ha7Jm^7V{2hd`^hHuL=aWU(wsArAFGxH6R98N83zL+ z$=7xR{D>1mh+QE$SpJN(I;N8&i=s)Dcrt0#XeKhPf{=5YjAtD`$=bcm5Esz z$fF$)khj5d+YK13mcSLO8uICN7Wq}#zO#SAV&5zV<5&OJm2Cyal)tNN|2N}6^iQSj zCBnaEczM^(@v=oL7HMj6?rCW(TdK84W9gzrEB=$_P>}nrR*6&H|J4ok zxqb=!-^`UivNIHaJ@`}mU+`G|3iwA(#lk;jv`M}yG#6P`E+{%`R9@1|;~&(E*&PncOe|L5cU;vW5K|8+Iyto>I@|K}Pw@BQZ^ b(YgP)&-soV`{esYMB%(6!1y)(A7lRmn45h| literal 0 HcmV?d00001 diff --git a/fme/core/models/conditional_sfno/testdata/test_localnet_output.pt b/fme/core/models/conditional_sfno/testdata/test_localnet_output.pt new file mode 100644 index 0000000000000000000000000000000000000000..70c5cc8e2a55c2b4615b37dba280cf1755f1c8ca GIT binary patch literal 9412 zcmb7q30RKJ`*ul6g_0$tB$bLbEvjeEsidOPMyvKs+E=Nhg%Vm&$Sx#H*;9IE#A^>B 
zOL()dWy_xRd;GTJ{k{M9{eRy#$1%qo_uO-z*Lh#}b<908*Ximg(OXPRT3YOXl#ybR zVkzOtDSBQka|@xdt`|Ll}W?By*k z8e@L?}cE-?I1EkgcPUH|s2 z$B0S)^;_#j#3#iDrv#>jCnbxF+1d0K^A#1tU(knN|DOMg{IB}=&{|?Te*^t5#{>q) zL`BBMhR3DY*=YU){g;LRy_g#RB0%RPKJO(m{;U2y0TnT$za|heHzhp&m+$Os^!^|1 z{+m@yR7y%r_%99%$-`lNMD~BxzvnPQ%<6tN&xSy zvLU^<|3-pN&Onn!2eJ5+A#`dg<0El`9rmBatzXFVdX-|b=(rB=S+C9a&6R?MiwZRT zlDWl|TkO**f+3#$IsdAHl?mBcuOExUPd4)%0|R(}Ye~$|ux1JAsa&ObJ8@qfM5snF zJ8!U*S1XU_vEH&6+4ZSuRKIS)`G}po$2x;N5mfM~MWgr|+YAyBHVSVC7xISa+w9}n z<$`-TZoH!37nX3iKkpqbi{4)c3udcYAu&)5Rr^axW%nBX@}eqn%ue9hMdO5x_ZPzL zl^s8N+ZX4KjpV@!aya*MIVai++3iVJNN2Vs)AzOI1M1}nGjik)1X9Arvfo**{!~kY z(*C^3=R4afHi6fi(8E0)C1FQqF+c9?iZSEF`O0UDdE)6?B>T}Dl6+eVAu}gI+-5!hF*FZ1pRMEF=}|oMMufok zU>vcv&?nymhGJK3ECL{8LB%84qpD)=CK1Q|){G*@J}pAaEp@@GgHzdsUSEk$Z++r) zyoHP^bml&}Yk0>aJIECjahHbmLd#A)$d@GHNY_C!F{XqcI{$;XiKnyvdyDvoxO&0G z-EmyEHjf_~xPY4~Z6_5b`LKDqmpBJaL^>VMGE!$+<`9Z3sj09#)OEy9EEjx}2U0EN zner12!JaFULZhvB*`}WH=zU3v{V2M~MuaUVt7=vVEV_>~!|$rhytog^b=G6a&&v=V zl8j~IyU?|uiVVqo#5vI}cA!1kO1JK3?4&GB=Dy2e~`;()PG z%K10Jx5sMO(rAgq-zwPk<^p!LP#wp0s@UM}J?wLR0^3vKMkd%U7yNc5oMmJQg_qYV zv7~9W*b`nu=I<=PrOaHM@^~V6^YMJsD)UUZPY5USc?!5Za0NVLw+fc#EkjbtZWi8o zzcJh+g-wpz&nnlnF)7VFJhS_a)eP{!XkKY?qv2kY$%rgy+Bh;(#UfO`-pveWca!7W zD0{ymw<)b?6;l&4gZ+ui%pk;=>{)+Gs~;oP=|zz8Scs*H zu1wKs8`FB6%bwgi%?7RBNDPCgvz_+CQQITW0^HW&xW{sgPW~jkX%|P99nlqDNs@J+x-SUbBu`7G-!Dq`*8!vJ(!JX)uP-;b#v!bX~H-m{uLJ@b_5J+osS%7*N#OE4Kz zcZt27ya9__!{D*IhlKUp*wlT2l0mmF3d2SpW5Yn5M-O1xnGKklD@`jL zbKpL2IkeTKX!+`z0SZg~85py(<)oOO*`t zrPUP1TdEeh=P=;K;*G+583Vr`xH zEcO#@@o2hPX9rI1wWANk1>pN~aN%MCT=N|8!hA3l#9QL@G7Ea_v_khuyH?8R-?=(j^-2`(I@3wX7g%7VEDzJ$$P-lc z`JT2&^cm>P*Y><5WdX@JP_iP+cS^o8_MRi>e0&Wkf6_=Jq!d{*~6W?6rUjlJT=(HdyLA9+%3BV4Lp1ryl+xJiI4_D=B;@ZVQW=?sa?M z{*y{)nAH)sm?E}q4%`K!U;t_wWAXQ27zMwTAzPr~zapf-59ATIV0 z`DqY`mf#$Iy`MiZDW}-|HJYsWS;@6cxAXM0PbBQ|8fGwc0YB(g!riKT*wC(}{Fy~G z*H?ValHV7iwQ(NatqK)tof*nczCFbz)y#s**g<@csTUukZ;BV=1_+i2M&hZiF7}V` zVqKcQiF`Uw@MPR_^8ImtoHg1AxfTKAU+2JUt`zCILWs0VD0-G_B4m6KZeG+ws_s?c 
zj2?A%t8+4Ia9@M(?vt4~X=1(8WSN}pTA}+xPodfeFT`%Wf8@%O@pw>Q1&y7|a^2yx zOd+{mpi*{|)lC}=r|>vo)bTCk+n0_bXR!&d?ii49+mHlVP=hd5G;JFCL#y7L_iYC??dqfz0MojRdOb2Az z9Wq}#7b{-hWV)Yv5sf4*cIALN-aJfTt99mJ%lZn`_k0tkOS(cUd^5Wos>z117COAg()#Q+MA%8fIYZkW<25 zvn<(0r_E$x>l{3Lm4e_evB)2$$*caf6rNjNf7!u=*$dgA zE_+72stE}{-83V2BlLc|%3^X;an$W%(^L5i%t*5yIyQy zvg^m;sP9fh*GI8W2`_~m?Tv_N(1BzXCC#SGan0f*`SigD7u&l8y1Pya4Su+@TT(*7 zxaG6iV(GrD$e^3(d|e86+HNs#x;I|W@T93_c63<7NrW5ZLba_OW7b{7In7A;M;p?% z8Jg&joB6T=l{RWk#&k7RDpOjFmO*Y*yFd?T_79`=mJX=hmXAth4-|Y(LB5SK zYF8e?EBAr)>Iy=eM$F zK|5(GE;lcLUQIcM51T@J)z1Sn3!-5UZ0X#K=CHqPkA81N^TV#A7_s94diLnx{Ao9| z-v>Q4EEPA{EVSmy)9|#(a5@o0^|q?hCFTiK>&j_3rfE{WQTjNdGMkoedII^?pqt7~ z=;~MFG5S;}J*J+I>C^xZtRiX7DJS?z z)2Qdt)yNi>;+mo!>^f)S`%lov2h?!wOBgg@6TR9UowpEB!EQ8tr z9o&9zfwXTKC^OiIHSe@wsAK}KHhXBw-N&sFr_ehvk|y|Q;Ehu?8tu#=Z*maj`&NiJ zJ;scpW0+uJLV0Wp9^O$vak47pEmZKX#019kG9Wp~5O>wj!cIB|o8vt2?57`9P(1|8 z16I%&u>f1G73uGL`qA9q%~5Cc0`g8$H2H!d4GVsZ=GTsNa}cFIN7M1n$Q!2u1fq39 z1NyCcfPuz8v8UVwhZk32y?F$^GgBWApPgV<5h|og<2G|xn8Vr&7V+h2`$*WGI4&k{ z&J7NwV&BOQ@~)Q?WWBUe8`Uj%yWlH(YNLdG9$9c18Hv_31x$G-`cG9H734dm^ZVa+ z3BtSQLEjmCOj4Yd_pR@Vb=@?yqKLl;P39xJ3{XG1jkvEYK-Pq4UZPuTd3=x^Ulgy73PCKZ-*Au} z)_3A#t}&wJ?7&^GL!jnr3s}YTj1nVcpWVmj_Oi!_kJ3C%x?JGz8UmcmBE!Ds@Urvj zSbMpg={|pJac{XQe-&U)x{dRBYqCCzQa6OK!wnw|X7G;INIrS#JLcYgf`m<*N0$1G zAfb{GyrYjSYAmw(a(@A|Ugt9#EiGhqz7h^J^ye{~s+q5=I8s#Nxc~GDvg2wd%pINh zmW}E#tGCC78-C~*{gjkUtKbulmk^8E0qkDBE6Yie-~%p?BMq&}eADeqWGWZUwXf&# zLk8*SnqPp0VnSwk!xN>!(cI&WJRVgDn0nzN(oRnCr{d%J$d{QY*U$q=m>~R+w1;Om z?jmh%a#(XZA4L~(xbIY3Uj1|dzq4Qw3^eU{js0R?xy6}l%B&!T?bSr`yfLdQ%;j@q z-T4Z;eeC@CaXegar)BJ`aJ-D(CVcSYE(zH^f-~P&;eV?2uyToGV`eKTm4>2Ayh>ekKXVufxShP|)3vPSJ(m$EN-gg@l zHq;8W@}@8+&q`*Z^^Lv!euf>R)8O3chfiTkgwqzj7gUegiFXr=aCSgBJ2Bx@(}uh% zM$9>j?-jugxKC%3hZ_o}&si>fdwUz{srW3M)~wnjZS;X$7_@+0RT7=05P7_e-NstP znwf@!z91}YH&zIROet>>Iu~wZ<$C+rCpAYlSG0CqQ+m2zy8W0qn-ybL82M4a9Qj zJ)utjE>b47fN7MNiuNBLgwo0-q-F2#Ox;SBWxk#XYvqonHP_?>3-TNh@_amN?li_c 
z(fV87{|gzIQOyFzZ-?=Nk??b<7G6tU$Yx)C!(MDzh4KgU*`1+vjU!@1gg;(4;gO{U zN_50=TwaOPI87#rKTncAqWNw7i%-llv5R$2I?lYEcVJeA49WbY$L#&;n>MFRLFv&a z!pA-R1?MO4W{1^`$doX7Yzh61jWXPdspmqFmtTtI4vX-ixruq431rL0;nXpJHD9XNQ6JK9k)QgE(a2I zGilThFO&@%hmaYg>DklG*eBsbQ#LL}X?h1D4sJ!wvpg6F9l|b)OZaegB4m3j(owhG zBeansS33=Ad#%veY=}}1Yr3_!9vxT_0I?klabCe1{$3Yw+kPvY`>%paFE60u2plAL zz1KrqB9RLH4)}5+?7&u4{_) zadjF@s<(pI)j~Z@l0LAkf=^>VDki=L{miD|T>!zk>*2J|aeex2cotf(yQ5`&GEQji zfcYmw`h189Ha}FN$9%>hE65Lh1%UFEeb~`ZhSOEE;F9GA!>cvOuCT%jF>5HF*QX5& z?SKkN?AoIO$-`zi_(Xv|tu};9qcxb%A?$t;07K;@cs-D$b^A7<)XErbC-mqyOg zxj=Gm61<5yG(A1(C0S?qKLr9_RDeEFrr~8Lak^UzvU3HfQ{IBB@6D+7Bui|_>W}Zg z>Ee2U3RSlXpl=6K>NqJ79fMTq3O7UANB9+$JIe9w-6AXqYrx)!k@V2BQS^LEe_C9o zNH0%!!1zja3_pAqr$zt41mlmeed~op**O^Su^%06TnU@_c`&*hL?6d4h3DgBsLu_c zJ;A9M+_)OWv4!Y&Rh1r^*?^$iFne&kkT+z1;g<|a=0v2@Hz7>-jm&)Det+e}cZ z$fccj@N;{^=Z)6wDQjh%DrrT^Di9 z@+eGdOJcLShazMAYclWRRG$6Rl!ru>k-#D=ZY4SsO%0I7kycYkZ;Iq$iv?USGaHMK zK4Ee65}9AHj%W$;&=4^iY91>HQy7E7NuhjQ)+44UErZo^dHB7=oIm^N$Ti003FPm_ zke$` z@jXnJ8wRaF=Zt(lCtZ^Fz0}0b=6myPc40h;uI4RY`*RKJc>YO$0eP`&Irr}9#bo^Q zxlk~KbqI8jdrE-^tyo6FUyX(Kgebnh*K~d_T!$SUw3DAMHsaq+qqyy+Lg7A*vCu93 zKwd}8;xlcHxLw0zB6}#8@5voQxK9;N?%Ge})b)92ekJdIokxzUOy%bz7vh>zId_wZ z<0WJTt5opgvoqz%^+7T)D2PV6uM>}ca+@veRW3{)I+I5`&qwN+#XR1?3~$4iV25!j zf3esLHp?uytfU_ge>R)R30(NJgaH2fXecjlEZ|rh#g9JC=ik=%!Gr)8f!ekuT#%r| z+b5l4W88yye0DuKI<;63QRu|g#u=i|XGJ7MFGE>L5U+}JN5-j6((=*_pT55$M(h=l zG-u?*s`-5J7FpbVF3T&pfUCJ2W{aA0x!#>Tb~QB-I{wmpf@C?7vkc?jodM*uXs=93 zF-v&SpJyHS<=!EsmQn}n1QAb;2(6l=aXnoV$4p)EJh&FKWjuu@a|m)TWRk11*5mAV z8^PRV6JZ@MfdTzD5_`)n?9852EW4P2(4$k?*2WF&eP9^tC|}F&gsp~VNGWFYYG5C& zWDw~S0PmDw;*eX$IwO1$Bv~(rEgnbaZ!v&#$Z)}yb{kguerZ!z&@Dl1@Do;io-scQ z3jGI{ELA+i$jWd@CbNHqU|0TpJUlU<&DMz)s0I~an@$A_J|W3%yWR;?1}?)oomzJP zxEo4E-=*-!M&Y!@vP9YGG`o05L11q1PHLlAn$N8-Wk!>o_z22`vHA}{aufs@e&_T&6Cw)IXwyc7A^WcmWe#J>_P ziKWm`kz`NgxZ$QI@MB$*_#V{t9aMoLH6p1OqO zUvkIzNdwTLbDWe;6k~6!c9Z%#N9NeiPgrv$3gxfESV^w|#PmQuyn3XWaON&jU>%I3 
zRVAb@(H=1qc0x&5j6=R*tZm>f*4;o6V;u&5fxxIm7Kv6bnRw+rc4f|GfooJMi&IY$ zPND$<&!dKx@-13|NAcaHF5iSz>zgrK({s#S=_o5QbAjGcA(@u$C-gNLfYCRdS+aIB z@$DOl{=5|3>)cH2z=(GTup6Nn2 ztR|TIq7EV7W1%RUhBv(qakDi&tPc1KtYQBtkgoFRg4`%)q&_Y| zs+KJb^^0I6ZGwW|Rneoq9!olbLqXp3T)s3d(r|^__KBE2TaJ!3kEiZ3i$yW!ikHL2 zQ%mJatV~;m&Yzz!)4c!&3nyUPN?AG}v~Ji9O!o+Nt$MFK{HF8sAt+Bx+GMKZVM?ulaU)$ySx%DcO9r-uWppeG@b*!>}iFQYnAB{lVEJ#0X(cUrbeS(U|OwCr;Qs- z`xW=0&!-)MPDTV>YpP4_I{j$kqkI&48&Iq6B`}aXh)vVHsk-QVf-B~rt#yj%-=2%? znhNZ)Q=n4WGm#eVhM|)^P+}uN3soGDf7=Sza>pV;X)4CI%!1O*Y^-*(rqY|WF-5;G zO&_I7yN7;>4>fdL!mBohqb!PiNih+TDk=mXh^UK2J4Usec@A-ROF^#{rp!bjDd2)DC zlt@945&TOG5f~L_XQTEP2LGM7|5s~hVrDea(A?a_&~T!Ov8nL{10z#|i3Y~zqCPWI z(}^Z#qDD047t6vHrH-+p?EbIn>L~t;@PCvee^ghD{B`5M%Kxe(`7_iX)f8j?3+AKf z>Mx=GB+mFV;veM^TK|RUCX#IYHRAu!Wc(TQk9r4DVEmz6`8&}6=q&z>^~VSs@dqR^ zF;9_~|B~Q8^%8#u`ePs|{}+&f+TQ^ED?I;<^v92_`(H@Q{|@P|!_d{SPoHuB0#K0{ z{~xGdilATR->uF+;z>+w;^=?&i=O+(t&o4kfXJz$sFVC9FA+T>CH8ClKga$bf(voi literal 0 HcmV?d00001 From 66e9e66853fcf980853cd3fba3b6755081f1017a Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 24 Mar 2026 19:24:57 +0000 Subject: [PATCH 19/31] Remove HybridNet and its tests This functionality is superseded by the secondary network support in SingleModuleStep on the follow-on branch. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/core/models/conditional_sfno/hybrid.py | 168 ------------- .../models/conditional_sfno/test_hybrid.py | 224 ------------------ .../test_hybridnet_diagnostic_output.pt | Bin 12248 -> 0 bytes .../test_hybridnet_prognostic_output.pt | Bin 7064 -> 0 bytes 4 files changed, 392 deletions(-) delete mode 100644 fme/core/models/conditional_sfno/hybrid.py delete mode 100644 fme/core/models/conditional_sfno/test_hybrid.py delete mode 100644 fme/core/models/conditional_sfno/testdata/test_hybridnet_diagnostic_output.pt delete mode 100644 fme/core/models/conditional_sfno/testdata/test_hybridnet_prognostic_output.pt diff --git a/fme/core/models/conditional_sfno/hybrid.py b/fme/core/models/conditional_sfno/hybrid.py deleted file mode 100644 index ba96a85c7..000000000 --- a/fme/core/models/conditional_sfno/hybrid.py +++ /dev/null @@ -1,168 +0,0 @@ -import dataclasses -from typing import Literal - -import torch -import torch.nn as nn - -from .ankur import AnkurLocalNetConfig, get_lat_lon_ankur_localnet -from .layers import Context, ContextConfig -from .localnet import LocalNetConfig, get_lat_lon_localnet -from .sfnonet import ( - SFNONetConfig, - SphericalFourierNeuralOperatorNet, - get_lat_lon_sfnonet, -) - -LocalConfig = AnkurLocalNetConfig | LocalNetConfig - - -@dataclasses.dataclass -class HybridNetConfig: - """Configuration for HybridNet. - - Attributes: - backbone: Configuration for the SFNO backbone that produces - prognostic output. - local: Configuration for the local network that produces - diagnostic output. Use ``AnkurLocalNetConfig`` (type="ankur") - or ``LocalNetConfig`` (type="localnet"). - learn_residual: Whether to add the prognostic input directly - to the prognostic output (identity skip connection). - data_grid: Grid type for spherical harmonic transforms used - by the SFNO backbone. 
- """ - - backbone: SFNONetConfig = dataclasses.field(default_factory=SFNONetConfig) - local: LocalConfig = dataclasses.field(default_factory=AnkurLocalNetConfig) - learn_residual: bool = False - data_grid: Literal["legendre-gauss", "equiangular"] = "equiangular" - - -def get_lat_lon_hybridnet( - params: HybridNetConfig, - n_forcing_channels: int, - n_prognostic_channels: int, - n_diagnostic_channels: int, - img_shape: tuple[int, int], - embed_dim_labels: int = 0, -) -> "HybridNet": - """Factory function to build a HybridNet. - - Args: - params: HybridNet configuration. - n_forcing_channels: Number of input-only (forcing) channels. - n_prognostic_channels: Number of input-output (prognostic) channels. - n_diagnostic_channels: Number of output-only (diagnostic) channels. - img_shape: Spatial dimensions (lat, lon) of the input data. - embed_dim_labels: Dimension of label embeddings for conditional - layer normalization. 0 disables label conditioning. - - Returns: - A configured HybridNet instance. 
- """ - n_in = n_forcing_channels + n_prognostic_channels - - context_config = ContextConfig( - embed_dim_scalar=0, - embed_dim_noise=0, - embed_dim_labels=embed_dim_labels, - embed_dim_pos=0, - ) - - backbone = get_lat_lon_sfnonet( - params=params.backbone, - in_chans=n_in, - out_chans=n_prognostic_channels, - img_shape=img_shape, - data_grid=params.data_grid, - context_config=context_config, - ) - - local_config = params.local - if isinstance(local_config, AnkurLocalNetConfig): - local_net: nn.Module = get_lat_lon_ankur_localnet( - params=local_config, - in_chans=n_in, - out_chans=n_diagnostic_channels, - img_shape=img_shape, - data_grid=params.data_grid, - context_config=context_config, - ) - elif isinstance(local_config, LocalNetConfig): - local_net = get_lat_lon_localnet( - params=local_config, - in_chans=n_in, - out_chans=n_diagnostic_channels, - img_shape=img_shape, - data_grid=params.data_grid, - context_config=context_config, - ) - else: - raise ValueError(f"Unknown local config type: {type(local_config)}") - - return HybridNet( - backbone=backbone, - local_net=local_net, - learn_residual=params.learn_residual, - n_prognostic_channels=n_prognostic_channels, - embed_dim_labels=embed_dim_labels, - ) - - -class HybridNet(nn.Module): - """Hybrid network combining an SFNO backbone with a local diagnostic network. - - Analogous to Ankur's ColumnDiagnosticSphericalFourierNeuralOperatorNet, - but using the conditional SFNO and local networks compositionally. - - The SFNO backbone processes the concatenated forcing and prognostic input - to produce a prognostic output. The local network processes the same - concatenated input to produce a diagnostic output. - - Args: - backbone: SFNO network for prognostic prediction. - local_net: Local network for diagnostic prediction. - learn_residual: Whether to add the prognostic input to the - backbone output (identity residual connection). 
- n_prognostic_channels: Number of prognostic channels, used - to slice the input when learn_residual is True. - embed_dim_labels: Dimension of label embeddings. 0 means no labels. - """ - - def __init__( - self, - backbone: SphericalFourierNeuralOperatorNet, - local_net: nn.Module, - learn_residual: bool = False, - n_prognostic_channels: int = 0, - embed_dim_labels: int = 0, - ): - super().__init__() - self.backbone = backbone - self.local_net = local_net - self.learn_residual = learn_residual - self.n_prognostic_channels = n_prognostic_channels - self.embed_dim_labels = embed_dim_labels - - def forward( - self, - forcing: torch.Tensor, - prognostic: torch.Tensor, - labels: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - x = torch.cat([forcing, prognostic], dim=1) - - context = Context( - embedding_scalar=None, - embedding_pos=None, - labels=labels, - noise=None, - ) - - prognostic_out = self.backbone(x, context) - diagnostic_out = self.local_net(x, context) - - if self.learn_residual: - prognostic_out = prognostic_out + prognostic - - return prognostic_out, diagnostic_out diff --git a/fme/core/models/conditional_sfno/test_hybrid.py b/fme/core/models/conditional_sfno/test_hybrid.py deleted file mode 100644 index fa1a61bb8..000000000 --- a/fme/core/models/conditional_sfno/test_hybrid.py +++ /dev/null @@ -1,224 +0,0 @@ -import os - -import pytest -import torch - -from fme.core.device import get_device -from fme.core.testing.regression import validate_tensor - -from .ankur import AnkurLocalNetConfig -from .hybrid import HybridNetConfig, get_lat_lon_hybridnet -from .localnet import LocalNetConfig -from .sfnonet import SFNONetConfig - -DIR = os.path.abspath(os.path.dirname(__file__)) - - -@pytest.mark.parametrize( - "learn_residual", - [True, False], -) -def test_can_call_hybridnet_with_ankur_local(learn_residual: bool): - n_forcing = 3 - n_prognostic = 2 - n_diagnostic = 4 - img_shape = (9, 18) - n_samples = 4 - device = get_device() - config = 
HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), - local=AnkurLocalNetConfig(embed_dim=16), - learn_residual=learn_residual, - ) - model = get_lat_lon_hybridnet( - params=config, - n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - ).to(device) - forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) - prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) - prog_out, diag_out = model(forcing, prognostic) - assert prog_out.shape == (n_samples, n_prognostic, *img_shape) - assert diag_out.shape == (n_samples, n_diagnostic, *img_shape) - - -def test_can_call_hybridnet_with_localnet(): - n_forcing = 3 - n_prognostic = 2 - n_diagnostic = 4 - img_shape = (9, 18) - n_samples = 4 - device = get_device() - config = HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), - local=LocalNetConfig( - embed_dim=16, - block_types=["disco", "disco"], - ), - ) - model = get_lat_lon_hybridnet( - params=config, - n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - ).to(device) - forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) - prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) - prog_out, diag_out = model(forcing, prognostic) - assert prog_out.shape == (n_samples, n_prognostic, *img_shape) - assert diag_out.shape == (n_samples, n_diagnostic, *img_shape) - - -def test_hybridnet_with_labels(): - n_forcing = 3 - n_prognostic = 2 - n_diagnostic = 4 - img_shape = (9, 18) - n_samples = 4 - embed_dim_labels = 5 - device = get_device() - config = HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), - local=AnkurLocalNetConfig(embed_dim=16), - ) - model = get_lat_lon_hybridnet( - params=config, - 
n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - embed_dim_labels=embed_dim_labels, - ).to(device) - forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) - prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) - labels = torch.randn(n_samples, embed_dim_labels, device=device) - prog_out, diag_out = model(forcing, prognostic, labels=labels) - assert prog_out.shape == (n_samples, n_prognostic, *img_shape) - assert diag_out.shape == (n_samples, n_diagnostic, *img_shape) - - -def test_learn_residual_adds_prognostic_input(): - """Verify learn_residual adds the prognostic input to the backbone output.""" - torch.manual_seed(0) - n_forcing = 2 - n_prognostic = 3 - n_diagnostic = 2 - img_shape = (9, 18) - n_samples = 2 - device = get_device() - - config_no_residual = HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), - local=AnkurLocalNetConfig(embed_dim=16), - learn_residual=False, - ) - config_residual = HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), - local=AnkurLocalNetConfig(embed_dim=16), - learn_residual=True, - ) - - model_no = get_lat_lon_hybridnet( - params=config_no_residual, - n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - ).to(device) - model_yes = get_lat_lon_hybridnet( - params=config_residual, - n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - ).to(device) - - # Copy weights from model_no to model_yes - model_yes.load_state_dict(model_no.state_dict()) - - forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) - prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) - - with torch.no_grad(): - prog_no, diag_no = model_no(forcing, 
prognostic) - prog_yes, diag_yes = model_yes(forcing, prognostic) - - # Diagnostic outputs should be identical - torch.testing.assert_close(diag_no, diag_yes) - # Prognostic output with residual = without residual + prognostic input - torch.testing.assert_close(prog_yes, prog_no + prognostic) - - -def test_backward_pass(): - """Test that gradients flow through both sub-networks.""" - n_forcing = 2 - n_prognostic = 3 - n_diagnostic = 2 - img_shape = (9, 18) - n_samples = 2 - device = get_device() - config = HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="makani-linear"), - local=AnkurLocalNetConfig(embed_dim=16), - learn_residual=True, - ) - model = get_lat_lon_hybridnet( - params=config, - n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - ).to(device) - forcing = torch.randn(n_samples, n_forcing, *img_shape, device=device) - prognostic = torch.randn(n_samples, n_prognostic, *img_shape, device=device) - prog_out, diag_out = model(forcing, prognostic) - loss = prog_out.sum() + diag_out.sum() - loss.backward() - for name, param in model.named_parameters(): - assert param.grad is not None, f"No gradient for {name}" - - -def setup_hybridnet(): - n_forcing = 3 - n_prognostic = 2 - n_diagnostic = 4 - img_shape = (9, 18) - n_samples = 4 - embed_dim_labels = 3 - device = get_device() - config = HybridNetConfig( - backbone=SFNONetConfig(embed_dim=16, num_layers=2, filter_type="linear"), - local=AnkurLocalNetConfig(embed_dim=16), - learn_residual=True, - ) - model = get_lat_lon_hybridnet( - params=config, - n_forcing_channels=n_forcing, - n_prognostic_channels=n_prognostic, - n_diagnostic_channels=n_diagnostic, - img_shape=img_shape, - embed_dim_labels=embed_dim_labels, - ).to(device) - # Initialize on CPU for reproducibility, then move to device - forcing = torch.randn(n_samples, n_forcing, *img_shape).to(device) - prognostic = torch.randn(n_samples, 
n_prognostic, *img_shape).to(device) - labels = torch.randn(n_samples, embed_dim_labels).to(device) - return model, forcing, prognostic, labels - - -def test_hybridnet_output_is_unchanged(): - torch.manual_seed(0) - model, forcing, prognostic, labels = setup_hybridnet() - with torch.no_grad(): - prog_out, diag_out = model(forcing, prognostic, labels=labels) - validate_tensor( - prog_out, - os.path.join(DIR, "testdata/test_hybridnet_prognostic_output.pt"), - ) - validate_tensor( - diag_out, - os.path.join(DIR, "testdata/test_hybridnet_diagnostic_output.pt"), - ) diff --git a/fme/core/models/conditional_sfno/testdata/test_hybridnet_diagnostic_output.pt b/fme/core/models/conditional_sfno/testdata/test_hybridnet_diagnostic_output.pt deleted file mode 100644 index 27205a1e2a057f0c7dcb32f750b99a26d2a25cfc..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12248 zcmbVy2{@MD`tE0*=c3F~<~f<(=aEP%G*HoigiuN(rBnz-g9eo*k(3fCA-=Vgk}?$~ ziIfZ>DG_PF`P%#c|DV10Is5$fb=G&S>wT~HUF%-=_pIky>v^AbFLC7IWw8VWS^r1T zW?8br{I-U91w^b14fNgY7v|*~=X23kwX|D(n^Nw{BZtkgr#m-{!5sp-u`|e zTz~H(JTEBNJIwX(4~0Xx-MKhzj6!(ac{o>}E!#qPomGXyLim97!f zE$0K=j+Y#z9l63n1QL9Ff`fwme8K{QH*XbQ9O~;A>gPK*&?ihdL{Qr$L?|po*d;{7 zV2PuEV8=t@6IVFMf9jHdYc(Sl{I%5#)(3}f@(%L~_Y2+18F-$p#s8o2I0iTVL#^n` zn|*)&{Ga;osm)|5{|D6m%gA2dL4p38H~DQ2n`bNk|6pfk`iH>m?IFF~oWcLpe@{S% zrSaDU)Q@WXY%xMaTZF0SW?3jy5u^HrKgsQqD7MZ;LD~_1ne?BiB6&^6 zi0P9O^0}gq)HK=9O({s-FTW&;K{B*-RF$#~GT50J9|#kEMr zntbXz$&(r)Pw)AXWu8Jbet{afbHt7&`=pbDZzjXE20hh0ljtrlBvVp^=~;DSdUyRrlFlzhgG&=ZC)J9kPnk}Sem9^|>io2#bpy?G zm!mniVoTiXH0i5_Z;Aak5!!6UPv;-x#Z$$W^o!Li@}x_ix%6I+4w@by@n+W4q5BaT zr7T)sc9^Jzj1 zhLDFtpGm4f8M!VbOy6vrN*~oskmYjD@Fc;Ideyj77wZo4{FNd-?tY0>X*?z|<+_xg zcbudNzaVb+XVCV`6$C+^IwUI6N~|I7ZP!V~tXFKoh6d7nd=)6=cv8iq3>j{$AF0yCcmIuuX@|bb`4h$EU2z!Z*QnuQZefrx3I*4i zQvd^kgnN$~L`*pY8Yu~oEZqtz!?bO#%&tgL44V9iJgXUF zmdD?O#}Y%#g4vo7=W`ZT&Huz?yak-Rr3mgNC4kI*X_#}x7L(N)*phB98FNt+Xy2G< zX))#s4)3MV_eeDi_vgd)wkS9##6w+#_mtEY`+;-NcBl~g1$#sqiz^J?Le9=?;xZwQ 
z{X4p?*v2j3DXfng1ES#3odxRZsieqs0;;v@;f7Z{{5)@mXV&cpL6v4pJL?;;e%&N5`FLNfQ?_r;Hg&(<0kwPqz{Ec;yf*Unjnah zO#79n{)z5^if7;n@UZG;gd2;d_#p zTgi>z7G8q@i%r=7h!3YXp98;weAuN_2QtOO@cKd#*zGtB%OA8uR6#XdS#}jx{u%-| z&pu#VC%}zG_3)t96P&n57-+X0(Xl9)*EMofItOQDA&Q=!^LoD#|xNe=MO7w)$nynDyVna$?|1{`@%{ux@+w4NnKl_4lD0yOstg9+n}pk-6}tGRVqIqeD2wu9uZ#p{SM7(LtD3+>=o%dH zOoE=rTOoVTEPOUk41Zd3;m2h>sQa`Zie45$&&e!En*R%C>OO(>`;LNeIUnpaDTlh? zb1=B`3hX^Q86&G3z;?YQHil7n^Ii?Zud-oGDg{KwgmAO}0Wfj2!4YW+m)}>xhsWPQ zMs7YF*!>fHws7OHdJniheGi&RzhV5~5ZI4SMwl}KJAc@~+G&ZPI6M5j?3b`?^<|wpC=y8cRp_{$&~HL4q#+MD%o zYTqX~w_Jc2R~3PZg(f!sJOKr_wqd~?8_3upguGuC;;)Kee1p>1zGgP25*JL&jKIZe zPMFaW3+tCzVYBmZxacH`$xgdbS$P!>elY=3HOwv^wg({3Q-g+j5 z!Uj%QvBC}C7i`1$51cvd^ug61f^g0oHMD!3|MheRKsjz6d}|S805iupEn{VsJ-Z4+LDZh2S(B)Eb;i$Bj3@?fPpew#8Ay`OY}iPl%B7ni`f zdw9^oavjdu?166|Jb+VMUqDNyKXyzx4vjjRSaLxV?H@>BlfOP1=?0^xq8$#i)Ucb| z2G8~Rz~>JNSh!XW2eVrs!h1HZ+&T^If-JB?QJrL|1>$aQWn5hCgP+HI@%(XltQn9- zZMhX7&~A;d6xN~bt%dkLE&@+UY2w2&5u7*bie5uDXx3l|)p1U!!j@*nPg!ANSUU`b zt;CbO+tEA83TG|$Lfw&f5M(_Qy*xEgC&(97GFD+=XBuZccR-bvg(&FejVVs9j9O4b9&S3Yp@Z;`rNQPe;KgWGU444{!4x`DJxDNFivB2mMZAJ7L9TZpwK-HmX18Y$73(%)ajl=|T~nc6 z_D#gUkVQ8Zh|q7xG$@^4Pww=XlMTOI=+HnGGut$mIMz?1pRQ?8Bi9e)qKO;L8NW_e z&Kw}Rdo&4GbuBUf*hS7AbD+FCKas1GZ0P19J}RFr1>C&RB(m3uE|}Rwwks)7vwAJM zz5D}FSgegNzp{zWQ7ziGYASuHltNzH{vBp22r^*Ak95MVDzgPDBt zaPBk0uIMF-T7u-T;{>?^KS*XgA9Wb!qlW6+>C{qXx=F2t)S9J}_r@_qh-;nIFxM1X zXIVwK;`Y)jXV1c=FLHE4r#6-CokYD}8`5o(^T^m9ewwn|omg}xlL4M`GNN&ttmg5c zp~m90P0E%EazC`na-U42r!A&eni_~wax*z`?+TIGWlpV27$W+Whu+o@rt#~w=wN>i zx!L{O>dO39(%jWWmYVRJi#;sbqbyGE?>|a%MwF;~!kH3_`}yppIug{nDv4O2PS_zUXLa;ld7_+24L8}|cmiX~~?ZjAa#%V_zt zm(1%6_aRJI07spg!6!``6~BFex>8?wJ#rVeMH1MzbC}tsRt+m|orYEjWQH zCUq;}`o7zY&39k8u|Nk;C9tsR%P7Quyve@OwGGy-SHq3BE`eOT9o#5QBxM@gf%WVb z98Ww5H(A|`LDqfC=L%xTb{%K$lu*FW^*_M$kTB>?7a+$ry@o+=4JOJ$9Ih8Wfr>wS z#_{(|dFu+W@HhxE9)c+RDxb5SUxmYA3Rn_m3Ka$|%!=JDC7t1^khw`1Kkc{xLYW&F zyNmhE{&}LEFefmYfX za0|*|$MYYAm#hd#tZanzdVlaqJO|UGXFyT<2j;+oQE*RgWXxeZ>@VcRi>qQ7zI#!? 
z7cFab^Jxb&ToMcCRtgc<-)CWC+i`eddmP-CZU9;LaS)%(z?1!-z@|b9HO{OmT?+(Nyera~; z-fBj1i7?rHZ5UQ9Z-N(tc`zv9fFn||a8g?gbF(i%;_^h89URSc46yLn>2h%X6a{ae zN#fN#(zsTq7C!nVf$)bPu){A09;AEWf$VOWE5wZnmlnadY&EO<@=lPWWDL#iLYT~Y z2fr`2fPVfY?BdRcCPozR9LRyX$yG2hdIBzQ%Z8X)DrjUf2~WEJW;Sm=2AdQcAvoLw z1Ow8cWhe{Y1ISm_o zS@?SPWjG1?DE_7aCiUEg*u$40|0Xxi;`$Aj=N^M~7gX@1!Z{FJ+X7w<+3<=bg-_dL zQCGPZOm!_WIlc&{IBKCL&qLT5$b&+c@?g6DLy&ad41VR;p=cosOFN61xD{_f^XOun z_&Eh{P91>C<)@eu;|DO^gbkZlyI{-6ZBQMrg7h|JREgRLXZB1+uF6XIzC9MkuDig# zjg`O`u7cAp_JT=j0XTer2#JPW;5%O&Z6zwfK0+ARtMtIvzHiWx-UiwVmm$DS17$~B zpy-Y=T85WEQfVt3wbj5DI|0lIlf-35Uqcxqj_GaHz;*FEu;gZ|pWF_C>BaD^ zmk(2Jq(M(<9AvEufSvxOaFr8Fix-}S!9-)QUKt5V=Cv@dG6!mZ$HL;`GezTt-Ua3MKx5Gs-jXxB)mQ{2R~h(fj>40A%EUxOpop-I=idkpu%?$PdCEm zY&#qcV^Cwx7M!$56K#WSaahwG$5TZx@3IxHADWC8+fA@>-4dL4bUJ=>Sc^y2y#aQ( zI$l2$gLZY^2<5Bser+y1>tBF@;%+FlZ69Rls$#wSYB-{1fUXW5AEqRj*Mruku}D1!q} z95CpCJ1(s<#+wIkGvaRqaCor`z8nY4oU{~Uvo_$Fl2qn)*b-DcV2_`&gz@`<`RFKJ z&s=Jl3nr)TfSGt4j*aP|zP$nVc}1|bZv~(xvk0%~nBYvN6mSx}&C(4Mx8(K_~5xpr*A3cg&4I)!}iFaF4@V+ZJQYt!b#D z7KfsnICG|U7gq4iLe_d)yco3|BWsUfMtToy_tHjQ_h}ed^$8%t1!GI+VCIbow9HV) zw|eVwSlJHaKdRuNH?eSoF2&;C9wg|(O7i2ADQ(DBq@{a%7)2~3uclVk!5MOi2E)j+90}vUMq+pgYCZw zPiqG|y|tg5sGdT1w^tKGcuHiWU#z9+9@bQ^ zrJQJ$s?n9^qSX8BJbHZf3@SguM=dP4=(ozNB!*|1zSzm=h;OekAt1Z%8he0}ZYgqP>Pgo##QvLlT?`J`lqD&-lFf`>K!q%e$+zVWRiCN9s2YHk$m`?QdTzL2MN z;j`$1r<1AOusA)s;4QmK+=>PSDO0xzpjuIw zNP{8!nJX`Kls75q&XJ)C3)N`FHG7&_<3#+`^r>{&dt#ozqKP-Yuv^smXndSFmDpEJ zEc(*P;rb`UyLCSGRac@X)`u_|oiWVkLksAfPGcHmwt`lxI?#I!_B1u{J(=(+Al@-2 zN#ETP=-8D4@dxfenW_s6I!r+BhDUJX2^TOn{_w555k%hUfgWoT?$W7(E3XD2L42|auvlHIoBDU>z*L8 zuZu~XrHL`W9>QwdS*R922gej%f@7&E78!4Y(+eW4x;9ypfn#EL%}fJ4e|EyVpPAsC zRS!y7H7Aeb+t30I7pE4nyX~;YY5)#yn~&^^ z(ID$5i!~+sDB5_0xoOY`F-sKaLK?&#@#%!<3~Njny9|z8Pnpn4Me=Co38?)dhEs1= zuqFN80y|_fW@Z#HBRdW-!w=69SD7TxE7OIAoW1wh*Gy>l`OJJf^#Ha%A0ina6|qj8 z%`A7h1o2@@A#2(TczJVC@w$M!?3~I2K<}|YbqyD@`Pl~IYJV9v8ERt1q2o-?X2O6A zkXySw*s~Zf+d-IoH91L;s+IS9PV9kITkgosQF1s-E- 
z(CBi;>n>kNV`3W|OL)L=N8hm$?li)rzTY6Xpbzd|8iN@rFTp%m4)0FpgFUQbIGs@g zh243e#)%DiXDXwiODkMiD~kNwzhLqG+u)G;iU~b`1$5QT!A5xn8WgobTgW{0T=4)F z^E`!hH{8J>F9!sdX2YXBS#ZCh3G`Gs`Id`waK213@QS?xp%cY0dPy1s+&_R}nlskL ze*^vO2rxeXma$Y{2<1!mz{ZDZ2zG=Sj3J|>fIFBmCsV{X}1NFQ>;N7w5?Q@H`$hR#FH z+6-`ν6W3ZnVco3M1G2n;zki}zu})Yv+>q7(ya+Wo-13qwb_1_<;XgoyLmkgr+{ z_lt!wEWZ*yKCwmqNk`!(cOyV9KRR9(!bXOL_b+?`|E-y@c8(w>ig01tp*lFtu7y+6 z(kx{)Q^6ri8jo`Fp;O*9Lz5pDerPX%@522s0NUX6^fT}ulR=^D9k4*F2rhoDgTRDb z*cG=Av`2(+ypspT)3q>mpA#g;$YLrdzv#d#kD)2TNbFaEGygn@6cxca2lQ}xw*wkC zwt=dp6PEWmprF1O&dAfm?zi)Bb-XN&TKM8~xp{clWC{8<2cea<5yq_aK>p{xXlM1A zDV4Xv*7HLkQt=bsS*=2w@MUOYvmHMs>_VkfAI#mp2Dj>;ffT6$C=`jpnjRbId>n)` z#Eg-3awcZT93{_ZJ%<|?b@5a7KJ4AR4;!2V@b<;8aP5gUis<)2j`I@iubhRxAH%RT zG8%uVDWG~@Fm9Jzgg)a_v3+zKl&s!@HOr^tYU4H7)T4?^?XB^gj~IUGXYjYG5T1GN zk1M`upj_=)aM`a{FKr7NCdu4cS8d$8QgW>1$S_Axf3IoVUd;%UOeXtitG8X{-psn zU30`D#YMPg!xBs$^h9}6bG*3sF|dcMaig>jCY{)aAyKkuw8;xMLNq=(%8w0c9=J#H z7cp9(h%%|m@kQfyth^$QX(zkj^^${(efD&$$zO>2Zrva^W{kfI{E_)4gi$-U!&2T$ z5Vb-JZ!VaI>Pg;stVsIfLnV_xnD4tG);9RTrc2)yc>o{s6>6 zccbi>BpnuWqar`g5FW84(h@d@?%j8h#D%P(;?aUMU_zIsg!xg8Me@{9;|7_PeVla4 zE77ji4Wwt?X(G6wiP(JPqDR~-iOqfiy4!sUEsD-3QlD8=cYO&tYY|GGwh7W!wGOsV zhXNJndq-Z)yhPT#^B{3=dFX=0_cqZ*jP^PKJf0LHVz2x?H3)<)SlVmTw!~TAKA$@j0 zm7MNwBUZb4sO?!J+S;!}T%s+DcfFFN7I)TAXR{baO=mW>Szb>*oOwhttJLW2-eMBH zQ<5sS-yqgE^y!Gw0_xZ3K;v3pkd7(lUk^)aZ>TiF7Em=}RF5J0JzL_Ku z$ASW)<#>fO9=XO8?^U2*{$;j`yok)x)_~3DSntSXO|3Hq@gZbp+Y3E=W?z%TjdZ%unQo2sd4Op3lVC20TpWR0%_?`#jf}(n^h{`Mn#CLnt_Q)>w_)R67T!3K z3TH#Va?bcV;n~p!*z-;gXP+s9(>10T|H8qNFZK(wE}ah>)68J5jU|*P4M0FpI7)cbERRWp!`b?H zsrx=#cdQ=j)h@xNu4l0J%4c|3KEcdWW#J38cbwQ|4%$uSVTQhJhBL{X%%ZscjMlYC z=7fwl@Iu# z8Bn!rC{ZToKrEu2nY~vA_%zPIwh#e;@eX*>!H3VET7j!)Iun@Fc zgk^+Ch=Ce5&AkIRJ$m56JWafQ!VHR)?l5i6aZs^88y5HwJiCf1oc)j-`uPq%jauUGaY0#<8Fe7 zupgZ6C-C7)A>`&jM0$;k)bvL4uRUK9Oay@X}Uk|9sx0f>8^glJ73WHt>! 
z!uqRVkUJf#xsrj#+v2MgFJO&h3IqnS@Rig|OdKwQgYR=;p413z&F7rmn-yAJ@Vfw6 z4MMo<4IdtS9)V@|BH-}uOyFLw0*{G41`pK09lKK4=GX$}b_>ABhX>CkYQyUr8BEaO zZJ?f21`&-zu=Q&%sJga-N)s0<+lyl9s}xu_7!C7`AHwF>eNgXo8ZJOw%iZr4O26;_H40Tv7wN)kEO4 zN)K;elHW%J<7Df5YW(eYc z1TwY{!MYdV`iCA!n%fIs0w14$V-EGw1!Vv$qi*PPWFWDH=Fe zi4&U~HpP$Y43T%*Bv?=&j< z6<4=8VGAd>m}avKS@VT4T-F6wMC8H!j%g^iWGm7RBV2i!L7OTAjBHwnZ6WKh_PGfj zXbMA{U(1l5T#Zp>ZjgWLJXkD>!*A`vID0OG;^*{mlKEwDIH`jsxdFJ_rT`KjE2Cg< z3XJG}0`2!4jSWXPQ*b{BKZOZk)dCBQaE(Q!jq72Ttu!Wh>*2|X<#>Ov0ebejq3+M8 zu&|0jP!`Aip>N^rNeb+3>d0*H!O9(rFzc=mj&_g1%?dfZ5Iq44?}*^BIa-j+ABKLT zpP+oh7F;4W2^B?8g0yEamKyV8(a}H@I?j)1Vu$8D(=lG|H(0)u1Cv01yzIOYuVEYb zritOjyn}eXARI^2reN#6c1YOd#XL~ihnEdK@%UI25)T$mHh02mK~0pHy$pZlT44K< zAMhCigv03sQYFb-STfyS}fL| z`tOyGS}dKvRzCiZB30wBrH_A;V9pf&!=OIZfB9eQKh0TF|C)ip-^7<&{Xzpd3e=t6 ze}uVSfxh!>C;#6G8P75|u`sr@oMmimG0W7-(<*{G+B;=br{};aKl4>Hm+G_g``U(PK6L3Agb-!Tq~F z_g_K((Fy-d@&7J$y8i*$|8o6*Mg2#oG5_C?oI4l2{{Z#>?#@db`T04$|5lWF^#ALQ w#rmUi|5N_^y60~`%3_WFS7T1Uzr6}K{Cm5<3%>dONS--;1XzFC|Gn-10+338UjP6A diff --git a/fme/core/models/conditional_sfno/testdata/test_hybridnet_prognostic_output.pt b/fme/core/models/conditional_sfno/testdata/test_hybridnet_prognostic_output.pt deleted file mode 100644 index ecb45d27fa330e2489ff9acf69d66af27e13a7ab..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7064 zcmbW62{@J8_s37hL#9NC(1ao(bKbp9l5U|?DotdJ%tN}VWGq7|C6qC#P*Nho+3Pe3 zX+%Xr8c30D8Z`d7_cz>of4}?dKL7PR>v_*P@7~|_J^NkndCvRUwpRQCEEb#1`j42# zGGPU~1qC~LY;g|sboF%$cJvSQ+u-XL6zu8Z=(j1@e^c;GSEpbnZGSIsOLGI(>iJ8la0^bQhr40LneT5-b|Pp0+eVBsf5HX~1Z0 zTPtDq>r=%Z$A^_i()KS`O^0Rn<5kmk_Y3rK3U&-}3k(`AY-w)%|69D_1=k(1R{pvz zYUIn0Bx#Zvwo*8Z``j!xd58+?7-e1k2`$Nvv^hWaB0U~|9=zTx5{`PT-fv($cU zzuoHKmwT!pA`GUi@7gyu5TqWUs9 z^-hXYeB2$*g-%AvcWq>Md>6OHIv>8s{laZ2G=xbHjOk==U2<7$9k$D7!Y{Xzh>znP z$g)i^OOs5aHs24^Q9>{1%#5qdzD5-g@|%o`8^2eJKNIAM+ummC=u%Rqa-8OPsKF$j z3huJ4gR5C5V9wXQBs@+ZhdTJ^-Xt$FvC#(ih-=}Q6o5+eFx)I3Mf;-IJc&4O@Da!X 
zjeN7&OF#M0P6;u3>yrjfSBvBBOb8*nYjdD-$pCXW)(vhZCD5*{GT8gJof|#b3-;Tx zVe5VakiL78D=9w%^0g)~`4{)&K^YG;PAEi^>JFIBBoW7~1em15f-o!rReo!j(Ncyg zCWWL%Oc>KU(?DSyMLN?Qi!SEElAUL0_VOH*s@qG0^iJT{bP3QBD#Af_BsdDzqS=%{ z{Q3GpWG|A3$ZgRevPvFwZ#9Csj}_?ey2*??wUK(R<4KbA+Gct5C%Fw2jAzSb&LFjq%fp0+_K`5vw#WlTOb? zu=WPx_%)A+jn8!ehdD49-9rn674XYiSK4u5J(hf21z{?mF`^(B#ojd0P@!GW`B;T4 zl3N7t%w=%fg&6WqdOfc9btXB&rj+f!he+27lY35@Xnmp(c6EhA$s|D-7##^!v4 zxP}YA$Ke9EyVUi@AVj?2P=QD3usn`Jb3r!vu#}LknShngL#b(CIiqx7Hkc0O!92}9 z_p4i<=Z1iTTpUxo#2H>$l#)8huhjpM5Kgx;1WxYTYL`nVQDY$+jnzdluYNo7>hDzL z9gD(^Eo0$q#CG_6sDP>$beP!}vvJeQbEte?4boPb(NCfg)Zxqouy4q~fz+Dn0qN(= z-3#NPc=JV6jte7+(tPND%g4{&BH02Ve%)CgSbzeg%5fx&kTufx#2(}v?BDb9t z(27qD90JcVQ_V|AyG$Sks6L`HE*xyK?`(4oUb&IS z-R)sTnSxp}CMc4G96m(f<}AXn+%*UbuW`d-Md{~lJMr#Miabt;VAb6*rf{T97xYt| z82RY==>K{aasSN1e*X>F;$VmS8jMhUUlk~YS;E?!SM_Vidneyaw6<~a*&EG7XY$sfW&enc+0G- zG0#R7OB2rnGq#Gn@9}1QUn!#D`U)_-?}GacU1>~>5=nC9=e2v?$4!@yVO2&E_6+ym zI&E*bIYR)#D-Gbrq^&UFnGIsE2%az+gVk&9&~(mwTvHW>DS9Q;#W)GfB2E(b1N}74 zDjWhhDkOG_1Vp5Sp|byTOu0OobBH+&6{Uf2{zexaotZ+XCO*Ln@!HVrmxt98pTe4p zCb+iu0Djq|M_9xJUf6C2ukKt}t&#y18RwXfcAn_IS%@GKy&}qKr>SXY2_s_^Pe*S{1+~ZX z!1(ShXv`{tnH5rCJWZB&f1wikSllGz!~*GUzf@d33dvFtBMjZN8=vX*Q2jA|p#SPR z4%zi^O=P|?vRn(|JA7aJT6^K^%1-dp@}=D_rBET(YnFVZ5UiW7!0rCAjNS?@aM$8M zU+^h>u5Qm|lsA%;f@0XNHxYd#7)o9v|_Z%o%9xJIJ*;9f7>_ zhGyaSKa**?tMHjzEcqhH!!!BM@#LjCDrU436&@?VMzQsTIK-f?W+S!u6vaF%jw9dB z)q;w7JWaPX1kVjMRM>-qk7RbzkLyBdOk6${Y_$UZldI5bxhQzwnLrn4Y=r2Z6s)L= zL-%4SdT^c^OPq^;!Hz^*w_u#h1n}|DB+Hh_@g{9o$F`wvIAnGKrTdieO3XSG zIiZOAKBZu{LW>@=4IMoZ<9=GExwXZnc7f!vJkoE;yAnJ2(+79ArU$ODdz-0KPUorR=3ix zw{=K^(Ou@DWG->WV(Mui%&|Lp4#H#BLAq)wE*bxZ^vBB3r~RqelgNiJ+{t29Pl=wW2%Cu950d-TKAg*-8pbbRv3gT|ct85ijqL3j|0iH-BY zV<8;ev#SH$N~d!t&szu|?irE8SLZ;e-CUSt-^K_k_~S3HpF@v94^tg@0K-bdiSOzz zd?zvnLN{#Wc4{=?p7;>7N(g|7wK3$ahAoMia|z0CmV#St8kEj6zkJ29w`JPCAFDihyugjT${NCmHo@eYec;ZN)#ia9u`qu+2;@axTnR zP!9Mg2j_av!~O+Z$y9YqjBG#3bZU2!pNdm(sznDnSPOAjFFAB_#0HSdsv=XWMq}rq 
zI?Q+&iVs@laC>PKI%zINjhZsBe5=86=j$MHWtrgq^B1!5b_=6jHv`{m%Q4p)OUUxD zMYvO=9Da7ng8i+Iu;FDD*UIiKOz@Euuiug1KlpX2ax_YpLtKWo(MmhI{25)L!@#b+cOmQij?@?^Gf8c2E+Q#h)QQSq@m>oP*!q97Ux7f4n6A z3B{X*;Oz}}Zm+x{&NemU^pfLLfA$vQCA1jq>Ktj;LlMyJIEHgq%5tP0uAm<0hlr5j zHYy&x4+O`)!qT_VC~2z5o!+*Ei5!=SrzhLs1?zC|l8ivMY6Gko{;$%Ms>op%R1iOn zU9exsold#Nqpyx%!`f3{==E9$sIJ|FYC|fh&@)IC?x{j{Z6?j@iZxptbqUAqkmb~H zci^5att3C)7siMr0PFn%x-2aO`R8WB43k0V-gSzZbo4OtFEj<4?j+jVUjfY)m&jLp z9-L8J4@(V|Az*hW5qrKGp2o4^3x6WBaaS33$j6a{@*1ezFU?VPT>_G^g`o4~47P6O zF&b`KAUZmfhIYl1+o^uc{WtTmc#0N`N>@R?vES*6$LVzNWhdjLtP5lOBXL2~Dd2O6 zqQcs9pxQ&6)URqq%l8GSty+u|qbzBKcNCots+2H&l96oN92A}y~@xjw<_%OH< z<0aX=#~XIQbjMAozhXU1eItcMGFh3)M zd3>;jp4ZREaj6Cn9vp~A+7=VPIjMA8RRqfGRHDS0XSnm+T9oHu4PDo?s)4fM`!<4HB@x7cX z-uyygkM(ivZ_6UKqrA9VKeXV*m<-|>ITO=*RWWCpJfNT66lCvm~d1u`QiE)iYb^3WQ z=@a63S?|L&1|4+Og*I{__%KF}ZzO$s%jnMnq}p)J3n-;5(7OKwtLl4EveX5RRt_*t zVR5)HxdQVOE8&~vSmb}!K@zW4k)J#tQ`K*aU|o?1o_}J>ox6vHPNp~TN@)RbOcXI~ z-Wjx1e}=Q~DZ{2@K2-cT2JO4T@Wfhq+@9_X&8>ZycFr4D>rY0$^9m5SeipeF<$;y1 zdvQa2IGxxq8H$}s$#n^Sj@epw3|}e4%lDlQKJ^K}lq>R9r18^NUrwO(n$7q*-w4Od zY9rfw*}&I57VEgqMAK~$KAZQ#amAH%xs?pmt2dLK)AZ>}Ee2d)7ogI|2b9g1!8O*m zhj0CvbZz)pUj7Sf*cI^-WOF0v?zvatm##;&aG5UbTOO<*NerHRD_!IHV^9Ed4bc*&XkKrwU zA48L(BQQlY6Yi)df?0_Om9gUE$)B4JLgDf7dbSlx3M1sY^+94v1|E1g1CAZaLdDuN zT>7;HRjc^8Jh*dUA-T3&kaOEe69o%Z z;iJ3;3`R+CPIs(E+n3R#B%%i-GNOpmf@pv$HO{rK8d$D-1SZmhSjs+2WM>B8vf*cD zA0;)oJuRF#EG_`M&%NZa^a03^E~JK63yCHE5!?#%HctrEQwyW!0bgH*1a50xYp;p~S(a(>%=lrcoC&>V-~+p8f^u9_Mzu0#KR2Sz;PF{4(p z7d2S|V38BUJ$fyPy8rZo$Rs|1teq`%_2pR1zLt#vJQrxD0T`M)2X47v$A%;>G>okP zgTj>{Fn&G6?|4j?rX^#Wi70-l3c{${33OM3HBqBB%n8$8SkjyacPD0J$)z+Z5|s>D zHf3BZQDHh&R)$CW-cl{*ZQFePNAnMPYQm7T9glXuvnYuykv-8Vpr))o-+rT}8P<(rv& z{atF3SwrrWNzksytN2{5kZzWf0DblxJh$Zwat>Vsl`=gD(al3=&9mTOpbhy2d0c~j z0giZr6UeMwPGwbxdBu%vH13|p%rtw7FUuBS{H6%pBHc%BE!>Jx+6z$a#sp65U=tHk zZGemGR5`|%_@QU91LT=3f`${y@aSMI)`fl`#YOk&nVL)trhAe7=o0ObEu5X)e~-J7 zPmcukYT%suJjy4?rrCN$B&VpE78KpYFRxgzOyqVZWh@QloGQADfJj%MZMP 
z4vq|qHIjc#E;U%we@rg_MU>Tkj4ppgsIx>z9Nfv;u>8mT)`+F>V+Y#5!qXtPK+j=< z9O^Uz)*U@vEzM>ByCGczBRykX6B7enU1I}%L;YDgdWJgsCdMN%%TU)u*U(5;U*Fhd z#AU#JWoeCJx;v7#R-;A?|7+*s_jG>R$caD1-`>agv+VE5y3QXY4ToX-4`u%Z`v0u{ zdu%@I59)XRmim9N`ac)=Jrke)hXPB7SMP`R|Hb0}to!?pRsVzTy}zUT8(;sk={5-yKaXL* m_9%;0F Date: Wed, 25 Mar 2026 17:08:09 +0000 Subject: [PATCH 20/31] Address PR review feedback: update defaults, remove duplicate docstring, add tests - Change LocalNetBuilder default block_types to 4 disco + 4 conv1x1 blocks - Remove duplicate __init__ docstring from NoiseConditionedModel (keep class-level) - Add tests for secondary config output_names property permutations - Add test for secondary residual with residual_prediction=True scenario Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/ace/registry/local_net.py | 11 +++- fme/ace/registry/stochastic_sfno.py | 16 ----- fme/core/step/test_step.py | 95 +++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 17 deletions(-) diff --git a/fme/ace/registry/local_net.py b/fme/ace/registry/local_net.py index 59481ad9d..bf37a100b 100644 --- a/fme/ace/registry/local_net.py +++ b/fme/ace/registry/local_net.py @@ -139,7 +139,16 @@ class LocalNetBuilder(ModuleConfig): noise_type: Literal["gaussian", "isotropic"] = "gaussian" context_pos_embed_dim: int = 0 block_types: list[BlockType] = dataclasses.field( - default_factory=lambda: ["disco"] * 12 + default_factory=lambda: [ + "disco", + "disco", + "disco", + "disco", + "conv1x1", + "conv1x1", + "conv1x1", + "conv1x1", + ] ) global_layer_norm: bool = False use_mlp: bool = True diff --git a/fme/ace/registry/stochastic_sfno.py b/fme/ace/registry/stochastic_sfno.py index 479d844fb..3dab556d0 100644 --- a/fme/ace/registry/stochastic_sfno.py +++ b/fme/ace/registry/stochastic_sfno.py @@ -74,22 +74,6 @@ def __init__( lmax: int = 0, mmax: int = 0, ): - """ - Args: - module: An nn.Module with forward signature (x, context: Context). - img_shape: Global spatial dimensions (lat, lon) of the input data. 
- embed_dim_noise: Dimension of noise channels. - embed_dim_pos: Dimension of learned positional embedding. - 0 disables. - embed_dim_labels: Dimension of label embeddings. 0 disables. - inverse_sht: Optional inverse spherical harmonic transform. When - provided, noise conditioning uses isotropic noise generated - via the inverse SHT. When None, gaussian noise is used. - lmax: Maximum spherical harmonic degree for isotropic noise. - Required when inverse_sht is provided. - mmax: Maximum spherical harmonic order for isotropic noise. - Required when inverse_sht is provided. - """ super().__init__() self.module = module self.embed_dim = embed_dim_noise diff --git a/fme/core/step/test_step.py b/fme/core/step/test_step.py index 90c0a7424..f2271eaa9 100644 --- a/fme/core/step/test_step.py +++ b/fme/core/step/test_step.py @@ -890,6 +890,101 @@ def test_secondary_network_state_round_trip(): torch.testing.assert_close(out1[name], out2[name]) +def test_secondary_output_names_full_field_only(): + """secondary_out_names appear in output_names.""" + normalization = get_network_and_loss_normalization_config( + names=["a", "b", "c"], + ) + config = SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=["c"], + ) + assert "c" in config.output_names + assert "b" in config.output_names + + +def test_secondary_output_names_residual_only(): + """secondary_residual_out_names that are in out_names appear in output_names.""" + normalization = get_network_and_loss_normalization_config( + names=["a", "b"], + ) + config = SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_residual_out_names=["b"], + ) + assert "b" in config.output_names + + +def 
test_secondary_output_names_residual_on_input_only(): + """secondary_residual_out_names on input-only name adds to output_names.""" + normalization = get_network_and_loss_normalization_config( + names=["a", "b"], + ) + config = SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a", "b"], + out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_residual_out_names=["a"], + ) + assert "a" in config.output_names + assert "b" in config.output_names + + +@pytest.mark.parallel +def test_secondary_residual_on_input_only_with_residual_prediction(): + """When residual_prediction=True and secondary_residual_out_name is in in_names + but not out_names, the input should not be added twice.""" + torch.manual_seed(0) + normalization = get_network_and_loss_normalization_config( + names=["forcing", "prog_a", "prog_b"], + ) + config = StepSelector( + type="single_module", + config=dataclasses.asdict( + SingleModuleStepConfig( + builder=ModuleSelector( + type="SphericalFourierNeuralOperatorNet", + config={ + "scale_factor": 1, + "embed_dim": 4, + "num_layers": 2, + }, + ), + in_names=["forcing", "prog_a", "prog_b"], + out_names=["prog_a", "prog_b"], + normalization=normalization, + residual_prediction=True, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_residual_out_names=["prog_a"], + ), + ), + ) + img_shape = DEFAULT_IMG_SHAPE + step = get_step(config, img_shape) + assert "prog_a" in step.output_names + input_data = get_tensor_dict(step.input_names, img_shape, n_samples=2) + next_step_input_data = get_tensor_dict( + step.next_step_input_names, img_shape, n_samples=2 + ) + output = step.step( + args=StepArgs( + input=input_data, next_step_input_data=next_step_input_data, labels=None + ), + ) + assert output["prog_a"].shape == (2, *img_shape) + assert output["prog_b"].shape == (2, *img_shape) + + def test_step_with_prescribed_prognostic_overwrites_output(): 
normalization = get_network_and_loss_normalization_config( names=["forcing_shared", "forcing_rad", "diagnostic_main", "diagnostic_rad"], From 0009d7e63d43604886d88e7848435cf1f263b47a Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 25 Mar 2026 17:50:52 +0000 Subject: [PATCH 21/31] Fix backwards compatibility: restore conditional_model attribute name The rename from NoiseConditionedSFNO to NoiseConditionedModel also changed the wrapped module attribute from conditional_model to module, breaking frozen checkpoint loading. Restore the original attribute name. Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/ace/registry/stochastic_sfno.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fme/ace/registry/stochastic_sfno.py b/fme/ace/registry/stochastic_sfno.py index 3dab556d0..c0df4483a 100644 --- a/fme/ace/registry/stochastic_sfno.py +++ b/fme/ace/registry/stochastic_sfno.py @@ -75,7 +75,7 @@ def __init__( mmax: int = 0, ): super().__init__() - self.module = module + self.conditional_model = module self.embed_dim = embed_dim_noise self.img_shape = img_shape self._inverse_sht = inverse_sht @@ -138,7 +138,7 @@ def forward( else: embedding_pos = None - return self.module( + return self.conditional_model( x, Context( embedding_scalar=None, From 607932ff90cea9be043641371221650e2e56705b Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 25 Mar 2026 21:47:36 +0000 Subject: [PATCH 22/31] Add kernel_shape, basis_type config to LocalNet and isotropic Morlet basis Expose kernel_shape and basis_type as configurable options on LocalNetConfig and LocalNetBuilder, replacing the hardcoded (3, 3) morlet defaults in DISCO convolution layers. Add an isotropic Morlet filter basis (radial-only modes) via monkey-patch of torch-harmonics get_filter_basis, available as the "isotropic morlet" basis type. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/ace/registry/local_net.py | 14 +++ .../conditional_sfno/isotropic_filter.py | 100 ++++++++++++++++++ fme/core/models/conditional_sfno/localnet.py | 32 ++++-- fme/core/models/conditional_sfno/sfnonet.py | 1 + .../models/conditional_sfno/test_localnet.py | 41 +++++++ 5 files changed, 182 insertions(+), 6 deletions(-) create mode 100644 fme/core/models/conditional_sfno/isotropic_filter.py diff --git a/fme/ace/registry/local_net.py b/fme/ace/registry/local_net.py index bf37a100b..1f4753d59 100644 --- a/fme/ace/registry/local_net.py +++ b/fme/ace/registry/local_net.py @@ -14,6 +14,7 @@ ) from fme.core.models.conditional_sfno.layers import Context, ContextConfig from fme.core.models.conditional_sfno.localnet import ( + BasisType, BlockType, LocalNetConfig, get_lat_lon_localnet, @@ -109,6 +110,15 @@ class LocalNetBuilder(ModuleConfig): Attributes: embed_dim: Dimension of the embeddings. + kernel_shape: Shape of the DISCO convolution filter basis, passed + to the filter basis constructor. For the "piecewise linear" and + "morlet" basis types this is a two-element tuple + (n_radial_modes, n_azimuthal_modes). When n_azimuthal_modes is + 1, the "piecewise linear" basis produces isotropic (radially + symmetric) filters. Only affects 'disco' blocks. + basis_type: Type of filter basis for the DISCO convolution + ('morlet', 'piecewise linear', or 'zernike'). Only affects + 'disco' blocks. noise_embed_dim: Dimension of the noise conditioning channels. noise_type: Type of noise for conditioning ('gaussian' or 'isotropic'). 
Isotropic noise is generated via inverse spherical harmonic @@ -135,6 +145,8 @@ class LocalNetBuilder(ModuleConfig): """ embed_dim: int = 256 + kernel_shape: tuple[int, int] = (3, 3) + basis_type: BasisType = "morlet" noise_embed_dim: int = 256 noise_type: Literal["gaussian", "isotropic"] = "gaussian" context_pos_embed_dim: int = 0 @@ -170,6 +182,8 @@ def build( ) -> nn.Module: params = LocalNetConfig( embed_dim=self.embed_dim, + kernel_shape=self.kernel_shape, + basis_type=self.basis_type, block_types=self.block_types, global_layer_norm=self.global_layer_norm, use_mlp=self.use_mlp, diff --git a/fme/core/models/conditional_sfno/isotropic_filter.py b/fme/core/models/conditional_sfno/isotropic_filter.py new file mode 100644 index 000000000..df9e2c5ec --- /dev/null +++ b/fme/core/models/conditional_sfno/isotropic_filter.py @@ -0,0 +1,100 @@ +"""Isotropic Morlet filter basis and monkey-patch for torch-harmonics. + +This module defines an isotropic variant of the Morlet filter basis where +all basis functions depend only on radial distance, not on azimuthal angle. +It monkey-patches ``torch_harmonics.filter_basis.get_filter_basis`` (and the +copy already imported into ``torch_harmonics.disco.convolution``) so that +the basis type ``"isotropic morlet"`` is available everywhere DISCO +convolutions are constructed. + +This is a temporary measure until the change is contributed upstream or the +dependency is forked. +""" + +import math + +import torch +from torch_harmonics import filter_basis as _fb +from torch_harmonics.disco import convolution as _conv + + +class IsotropicMorletFilterBasis(_fb.FilterBasis): + """Morlet-style filter basis using only radial modes. + + Each basis function is a product of a Hann radial window and a 1-D + Fourier harmonic in the normalised radial coordinate ``r / r_cutoff``. + Because none of the basis functions depend on the azimuthal angle + ``phi``, any learned linear combination is guaranteed to be isotropic + (radially symmetric). 
+ + ``kernel_shape`` is a single integer giving the number of radial modes. + If a tuple is provided, only the first element is used. + """ + + def __init__( + self, + kernel_shape: int | tuple[int] | tuple[int, int], + ): + if isinstance(kernel_shape, list | tuple): + kernel_shape = kernel_shape[0] + if not isinstance(kernel_shape, int): + raise ValueError( + f"expected kernel_shape to be an integer but got " + f"{kernel_shape} instead." + ) + super().__init__(kernel_shape=kernel_shape) + + @property + def kernel_size(self) -> int: + return self.kernel_shape + + def compute_support_vals( + self, + r: torch.Tensor, + phi: torch.Tensor, + r_cutoff: float, + width: float = 1.0, + ): + ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) + + iidx = torch.argwhere( + (r <= r_cutoff) + & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device) + ) + + r_norm = r[iidx[:, 1], iidx[:, 2]] / r_cutoff + n = ikernel[iidx[:, 0], 0, 0] + + # Radial Fourier modes: cos/sin pattern identical to Morlet but in r. 
+ harmonic = torch.where( + n % 2 == 1, + torch.sin(torch.ceil(n / 2) * math.pi * r_norm / width), + torch.cos(torch.ceil(n / 2) * math.pi * r_norm / width), + ) + + # Hann radial envelope + window = torch.cos(0.5 * torch.pi * r_norm / width) ** 2 + vals = window * harmonic + + return iidx, vals + + +# --------------------------------------------------------------------------- +# Monkey-patch get_filter_basis +# --------------------------------------------------------------------------- + +_original_get_filter_basis = _fb.get_filter_basis + + +@_fb.lru_cache(typed=True, copy=False) +def _patched_get_filter_basis( + kernel_shape: int | tuple[int] | tuple[int, int], + basis_type: str, +) -> _fb.FilterBasis: + if basis_type == "isotropic morlet": + return IsotropicMorletFilterBasis(kernel_shape=kernel_shape) + return _original_get_filter_basis(kernel_shape, basis_type) + + +_fb.get_filter_basis = _patched_get_filter_basis +_conv.get_filter_basis = _patched_get_filter_basis diff --git a/fme/core/models/conditional_sfno/localnet.py b/fme/core/models/conditional_sfno/localnet.py index 82347707c..e0806a5a8 100644 --- a/fme/core/models/conditional_sfno/localnet.py +++ b/fme/core/models/conditional_sfno/localnet.py @@ -9,11 +9,15 @@ from fme.core.distributed import Distributed from .initialization import trunc_normal_ +from .isotropic_filter import ( + IsotropicMorletFilterBasis as _IsotropicMorletFilterBasis, # noqa: F401 (monkey-patches torch-harmonics on import) +) from .layers import MLP, ConditionalLayerNorm, Context, ContextConfig from .lora import LoRAConv2d from .sfnonet import DiscreteContinuousConvS2, NoLayerNorm, _compute_cutoff_radius BlockType = Literal["disco", "conv1x1"] +BasisType = Literal["morlet", "isotropic morlet", "piecewise linear", "zernike"] @dataclasses.dataclass @@ -23,6 +27,11 @@ class LocalNetConfig: Attributes: type: Discriminator tag for union deserialization. embed_dim: Dimension of the embeddings. 
+ kernel_shape: Shape of the DISCO convolution filter basis, passed + to the filter basis constructor. Only affects 'disco' blocks. + basis_type: Type of filter basis for the DISCO convolution + ('morlet', 'isotropic morlet', 'piecewise linear', or 'zernike'). + Only affects 'disco' blocks. block_types: List of filter types for each block ('disco', 'conv1x1'). The length determines the number of blocks. global_layer_norm: Whether to reduce along the spatial domain when @@ -43,6 +52,8 @@ class LocalNetConfig: type: Literal["localnet"] = "localnet" embed_dim: int = 256 + kernel_shape: tuple[int, int] = (3, 3) + basis_type: BasisType = "morlet" block_types: list[BlockType] = dataclasses.field( default_factory=lambda: ["disco"] * 12 ) @@ -63,8 +74,7 @@ def __post_init__(self): for i, bt in enumerate(self.block_types): if bt not in valid: raise ValueError( - f"Invalid block type {bt!r} at index {i}, " - f"must be one of {valid}" + f"Invalid block type {bt!r} at index {i}, must be one of {valid}" ) @@ -94,6 +104,8 @@ def __init__( embed_dim, img_shape: tuple[int, int], filter_type="disco", + kernel_shape: tuple[int, int] = (3, 3), + basis_type: BasisType = "morlet", data_grid="equiangular", lora_rank: int = 0, lora_alpha: float | None = None, @@ -104,16 +116,16 @@ def __init__( nlat, nlon = img_shape theta_cutoff = 2 * _compute_cutoff_radius( nlat=nlat, - kernel_shape=(3, 3), - basis_type="morlet", + kernel_shape=kernel_shape, + basis_type=basis_type, ) self.filter = DiscreteContinuousConvS2( embed_dim, embed_dim, in_shape=img_shape, out_shape=img_shape, - kernel_shape=(3, 3), - basis_type="morlet", + kernel_shape=kernel_shape, + basis_type=basis_type, basis_norm_mode="mean", groups=1, grid_in=data_grid, @@ -143,6 +155,8 @@ def __init__( img_shape: tuple[int, int], context_config: ContextConfig, filter_type="disco", + kernel_shape: tuple[int, int] = (3, 3), + basis_type: BasisType = "morlet", data_grid="equiangular", global_layer_norm: bool = False, mlp_ratio=2.0, @@ 
-173,6 +187,8 @@ def __init__( embed_dim, img_shape=img_shape, filter_type=filter_type, + kernel_shape=kernel_shape, + basis_type=basis_type, data_grid=data_grid, lora_rank=lora_rank, lora_alpha=lora_alpha, @@ -343,6 +359,8 @@ def __init__( self.in_chans = in_chans self.out_chans = out_chans self.embed_dim = params.embed_dim + self.kernel_shape = params.kernel_shape + self.basis_type = params.basis_type self.num_layers = len(params.block_types) self.use_mlp = params.use_mlp self.encoder_layers = params.encoder_layers @@ -401,6 +419,8 @@ def __init__( img_shape=self.img_shape, context_config=context_config, filter_type=block_type, + kernel_shape=self.kernel_shape, + basis_type=self.basis_type, data_grid=self.data_grid, global_layer_norm=self.global_layer_norm, mlp_ratio=self.mlp_ratio, diff --git a/fme/core/models/conditional_sfno/sfnonet.py b/fme/core/models/conditional_sfno/sfnonet.py index 1dbc5e7a8..36f02a1e6 100644 --- a/fme/core/models/conditional_sfno/sfnonet.py +++ b/fme/core/models/conditional_sfno/sfnonet.py @@ -119,6 +119,7 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type): theta_cutoff_factor = { "piecewise linear": 0.5, "morlet": 0.5, + "isotropic morlet": 0.5, "zernike": math.sqrt(2.0), } diff --git a/fme/core/models/conditional_sfno/test_localnet.py b/fme/core/models/conditional_sfno/test_localnet.py index 307b6cad2..94cc48f89 100644 --- a/fme/core/models/conditional_sfno/test_localnet.py +++ b/fme/core/models/conditional_sfno/test_localnet.py @@ -226,6 +226,47 @@ def test_no_big_skip(): assert output.shape == (n_samples, output_channels, *img_shape) +@pytest.mark.parametrize( + "kernel_shape, basis_type", + [ + ((1, 1), "morlet"), + ((5, 1), "morlet"), + ((3, 5), "morlet"), + ((3, 3), "piecewise linear"), + ((3, 1), "piecewise linear"), + ((3, 3), "isotropic morlet"), + ((5, 3), "isotropic morlet"), + ], +) +def test_can_call_localnet_with_kernel_shape(kernel_shape, basis_type): + input_channels = 2 + output_channels = 3 + img_shape = (9, 
18) + n_samples = 4 + device = get_device() + params = LocalNetConfig( + embed_dim=16, + kernel_shape=kernel_shape, + basis_type=basis_type, + block_types=["disco", "disco"], + ) + model = get_lat_lon_localnet( + params=params, + img_shape=img_shape, + in_chans=input_channels, + out_chans=output_channels, + ).to(device) + x = torch.randn(n_samples, input_channels, *img_shape, device=device) + context = Context( + embedding_scalar=torch.randn(n_samples, 0, device=device), + labels=torch.randn(n_samples, 0, device=device), + noise=torch.randn(n_samples, 0, *img_shape, device=device), + embedding_pos=torch.randn(n_samples, 0, *img_shape, device=device), + ) + output = model(x, context) + assert output.shape == (n_samples, output_channels, *img_shape) + + def test_unknown_filter_type_raises(): with pytest.raises(ValueError, match="Invalid block type"): LocalNetConfig(block_types=["spectral"]) # type: ignore[list-item] From 4c98ea37f6906e9bae1a1eb1c1589cc9908c21e3 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 25 Mar 2026 21:56:09 +0000 Subject: [PATCH 23/31] Add test that isotropic DISCO conv commutes with latitude reflection Verifies that the isotropic Morlet basis produces truly isotropic filters by checking that flip(conv(x)) == conv(flip(x)) under latitude reflection, an isometry that preserves geodesic distance but reverses local azimuthal orientation. Also verifies the standard Morlet basis does NOT commute, confirming the test discriminates isotropic from anisotropic filters. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .../models/conditional_sfno/test_localnet.py | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/fme/core/models/conditional_sfno/test_localnet.py b/fme/core/models/conditional_sfno/test_localnet.py index 94cc48f89..f6480fad2 100644 --- a/fme/core/models/conditional_sfno/test_localnet.py +++ b/fme/core/models/conditional_sfno/test_localnet.py @@ -9,6 +9,7 @@ from .layers import Context, ContextConfig from .localnet import LocalNetConfig, get_lat_lon_localnet +from .sfnonet import DiscreteContinuousConvS2, _compute_cutoff_radius DIR = os.path.dirname(os.path.abspath(__file__)) @@ -267,6 +268,71 @@ def test_can_call_localnet_with_kernel_shape(kernel_shape, basis_type): assert output.shape == (n_samples, output_channels, *img_shape) +def test_isotropic_disco_conv_commutes_with_latitude_reflection(): + """An isotropic DISCO conv must commute with latitude reflection. + + Latitude reflection (θ → π−θ) is an isometry of the sphere that maps + equiangular grid points exactly onto other grid points. For an + isotropic filter the convolution kernel depends only on geodesic + distance, which is preserved by the reflection, so + ``flip(conv(x)) == conv(flip(x))``. An anisotropic filter also + depends on the local azimuthal angle, which the reflection reverses, + so the equality generally fails. 
+ """ + torch.manual_seed(0) + img_shape = (16, 32) + embed_dim = 4 + kernel_shape = (3, 3) + data_grid = "equiangular" + lat_dim = 2 # dimension index for latitude in (batch, channel, lat, lon) + + def make_disco_conv(basis_type): + theta_cutoff = 2 * _compute_cutoff_radius( + nlat=img_shape[0], + kernel_shape=kernel_shape, + basis_type=basis_type, + ) + return DiscreteContinuousConvS2( + embed_dim, + embed_dim, + in_shape=img_shape, + out_shape=img_shape, + kernel_shape=kernel_shape, + basis_type=basis_type, + basis_norm_mode="mean", + groups=1, + grid_in=data_grid, + grid_out=data_grid, + bias=False, + theta_cutoff=theta_cutoff, + ) + + x = torch.randn(1, embed_dim, *img_shape) + x_flipped = torch.flip(x, dims=[lat_dim]) + + # Isotropic basis: convolution must commute with the reflection. + iso_conv = make_disco_conv("isotropic morlet") + with torch.no_grad(): + out, _ = iso_conv(x) + out_from_flipped, _ = iso_conv(x_flipped) + torch.testing.assert_close( + torch.flip(out, dims=[lat_dim]), + out_from_flipped, + ) + + # Anisotropic basis: convolution should NOT commute (random weights + # activate the azimuthal modes that break the symmetry). 
+ aniso_conv = make_disco_conv("morlet") + with torch.no_grad(): + out, _ = aniso_conv(x) + out_from_flipped, _ = aniso_conv(x_flipped) + assert not torch.allclose( + torch.flip(out, dims=[lat_dim]), + out_from_flipped, + atol=1e-5, + ) + + def test_unknown_filter_type_raises(): with pytest.raises(ValueError, match="Invalid block type"): LocalNetConfig(block_types=["spectral"]) # type: ignore[list-item] From c0a911dc5da39462f8bde97bacbe6effbb5b4748 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 25 Mar 2026 22:10:45 +0000 Subject: [PATCH 24/31] Fork DISCO convolution with FFT-based contraction into fme/core/disco MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fork DiscreteContinuousConvS2 from torch-harmonics (fft-disco-contraction branch) into fme/core/disco/, using FFT-based contraction as the sole implementation path. This replaces the O(nlon) loop of sparse BMMs with a single FFT → banded pointwise multiply → inverse FFT pipeline, giving a significant speedup for LocalNet. The non-distributed and data-parallel backends now use the forked class; the spatial-parallel backend (model_torch_distributed) continues to use the torch-harmonics distributed version unchanged. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/core/disco/__init__.py | 3 + fme/core/disco/_convolution.py | 369 ++++++++++++++++++++++ fme/core/disco/_disco_utils.py | 136 ++++++++ fme/core/disco/test_disco.py | 121 +++++++ fme/core/distributed/non_distributed.py | 4 +- fme/core/distributed/torch_distributed.py | 4 +- 6 files changed, 635 insertions(+), 2 deletions(-) create mode 100644 fme/core/disco/__init__.py create mode 100644 fme/core/disco/_convolution.py create mode 100644 fme/core/disco/_disco_utils.py create mode 100644 fme/core/disco/test_disco.py diff --git a/fme/core/disco/__init__.py b/fme/core/disco/__init__.py new file mode 100644 index 000000000..9de56cc9e --- /dev/null +++ b/fme/core/disco/__init__.py @@ -0,0 +1,3 @@ +from ._convolution import DiscreteContinuousConvS2 + +__all__ = ["DiscreteContinuousConvS2"] diff --git a/fme/core/disco/_convolution.py b/fme/core/disco/_convolution.py new file mode 100644 index 000000000..ef7a6d540 --- /dev/null +++ b/fme/core/disco/_convolution.py @@ -0,0 +1,369 @@ +# Forked from torch-harmonics (BSD-3-Clause) +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. +# SPDX-License-Identifier: BSD-3-Clause +# +# Modifications: stripped down to DiscreteContinuousConvS2 (forward only) +# using the FFT-based contraction as the sole implementation path. CUDA +# kernel and torch sparse-matrix fallbacks are removed. 
+ +import abc +import math + +import torch +import torch.nn as nn +from torch_harmonics import filter_basis as _filter_basis_module +from torch_harmonics.cache import lru_cache +from torch_harmonics.filter_basis import FilterBasis +from torch_harmonics.quadrature import precompute_latitudes, precompute_longitudes + +from ._disco_utils import _disco_s2_contraction_fft, _get_psi, _precompute_psi_banded + + +def _normalize_convolution_tensor_s2( + psi_idx, + psi_vals, + in_shape, + out_shape, + kernel_size, + quad_weights, + transpose_normalization=False, + basis_norm_mode="mean", + merge_quadrature=False, + eps=1e-9, +): + """Normalizes convolution tensor values based on specified normalization + mode. + """ + if basis_norm_mode == "none": + return psi_vals + + idx = torch.stack( + [ + psi_idx[0], + psi_idx[1], + psi_idx[2] // in_shape[1], + psi_idx[2] % in_shape[1], + ], + dim=0, + ) + + ikernel = idx[0] + + if transpose_normalization: + ilat_out = idx[2] + ilat_in = idx[1] + nlat_out = in_shape[0] + correction_factor = out_shape[1] / in_shape[1] + else: + ilat_out = idx[1] + ilat_in = idx[2] + nlat_out = out_shape[0] + + q = quad_weights[ilat_in].reshape(-1) + + vnorm = torch.zeros(kernel_size, nlat_out, device=psi_vals.device) + support = torch.zeros(kernel_size, nlat_out, device=psi_vals.device) + + for ik in range(kernel_size): + for ilat in range(nlat_out): + iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat)) + vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx]) + support[ik, ilat] = torch.sum(q[iidx]) + + for ik in range(kernel_size): + for ilat in range(nlat_out): + iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat)) + + if basis_norm_mode == "individual": + val = vnorm[ik, ilat] + elif basis_norm_mode == "mean": + val = vnorm[ik, :].mean() + elif basis_norm_mode == "support": + val = support[ik, ilat] + elif basis_norm_mode == "none": + val = 1.0 + else: + raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.") + + 
psi_vals[iidx] = psi_vals[iidx] / (val + eps) + + if merge_quadrature: + psi_vals[iidx] = psi_vals[iidx] * q[iidx] + + if transpose_normalization and merge_quadrature: + psi_vals = psi_vals / correction_factor + + return psi_vals + + +@lru_cache(typed=True, copy=True) +def _precompute_convolution_tensor_s2( + in_shape: tuple[int], + out_shape: tuple[int], + filter_basis: FilterBasis, + grid_in: str | None = "equiangular", + grid_out: str | None = "equiangular", + theta_cutoff: float | None = 0.01 * math.pi, + theta_eps: float | None = 1e-3, + transpose_normalization: bool | None = False, + basis_norm_mode: str | None = "mean", + merge_quadrature: bool | None = False, +): + r"""Precomputes the rotated filters at positions using YZY Euler angles.""" + assert len(in_shape) == 2 + assert len(out_shape) == 2 + + kernel_size = filter_basis.kernel_size + + nlat_in, nlon_in = in_shape + nlat_out, nlon_out = out_shape + + lats_in, win = precompute_latitudes(nlat_in, grid=grid_in) + lats_out, wout = precompute_latitudes(nlat_out, grid=grid_out) + lons_in = precompute_longitudes(nlon_in) + + if transpose_normalization: + quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0 + else: + quad_weights = win.reshape(-1, 1) / nlon_in / 2.0 + + theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff + + out_idx = [] + out_vals = [] + + beta = lons_in + gamma = lats_in.reshape(-1, 1) + + cbeta = torch.cos(beta) + sbeta = torch.sin(beta) + cgamma = torch.cos(gamma) + sgamma = torch.sin(gamma) + + out_roff = torch.zeros(nlat_out + 1, dtype=torch.int64, device=lons_in.device) + out_roff[0] = 0 + for t in range(nlat_out): + alpha = -lats_out[t] + + x = torch.cos(alpha) * cbeta * sgamma + cgamma * torch.sin(alpha) + y = sbeta * sgamma + z = -cbeta * torch.sin(alpha) * sgamma + torch.cos(alpha) * cgamma + + norm = torch.sqrt(x * x + y * y + z * z) + x = x / norm + y = y / norm + z = z / norm + + theta = torch.arccos(z) + phi = torch.arctan2(y, x) + phi = torch.where(phi < 0.0, phi + 2 * torch.pi, 
phi) + + iidx, vals = filter_basis.compute_support_vals( + theta, phi, r_cutoff=theta_cutoff_eff + ) + + idx = torch.stack( + [ + iidx[:, 0], + t * torch.ones_like(iidx[:, 0]), + iidx[:, 1] * nlon_in + iidx[:, 2], + ], + dim=0, + ) + + out_idx.append(idx) + out_vals.append(vals) + out_roff[t + 1] = out_roff[t] + iidx.shape[0] + + out_idx = torch.cat(out_idx, dim=-1) + out_vals = torch.cat(out_vals, dim=-1) + + out_vals = _normalize_convolution_tensor_s2( + out_idx, + out_vals, + in_shape, + out_shape, + kernel_size, + quad_weights, + transpose_normalization=transpose_normalization, + basis_norm_mode=basis_norm_mode, + merge_quadrature=merge_quadrature, + ) + + out_idx = out_idx.contiguous() + out_vals = out_vals.to(dtype=torch.float32).contiguous() + + return out_idx, out_vals, out_roff + + +class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta): + """Abstract base class for discrete-continuous convolutions.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_shape: int | tuple[int, ...], + basis_type: str = "piecewise linear", + groups: int = 1, + bias: bool = True, + ): + super().__init__() + + self.kernel_shape = kernel_shape + self.filter_basis = _filter_basis_module.get_filter_basis( + kernel_shape=kernel_shape, basis_type=basis_type + ) + + self.groups = groups + + if in_channels % self.groups != 0: + raise ValueError( + "Error, the number of input channels has to be an integer " + "multiple of the group size" + ) + if out_channels % self.groups != 0: + raise ValueError( + "Error, the number of output channels has to be an integer " + "multiple of the group size" + ) + self.groupsize = in_channels // self.groups + scale = math.sqrt(1.0 / self.groupsize / self.kernel_size) + self.weight = nn.Parameter( + scale * torch.randn(out_channels, self.groupsize, self.kernel_size) + ) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + else: + self.bias = None + + @property + def kernel_size(self): + return 
self.filter_basis.kernel_size + + @abc.abstractmethod + def forward(self, x: torch.Tensor): + raise NotImplementedError + + +class DiscreteContinuousConvS2(DiscreteContinuousConv): + """Discrete-continuous (DISCO) convolution on the 2-Sphere. + + Uses an FFT-based contraction for efficient computation. Forked from + torch-harmonics; see Ocampo, Price, McEwen, *Scalable and equivariant + spherical CNNs by discrete-continuous (DISCO) convolutions*, ICLR 2023. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + in_shape: tuple[int, int], + out_shape: tuple[int, int], + kernel_shape: int | tuple[int, ...], + basis_type: str = "piecewise linear", + basis_norm_mode: str = "mean", + groups: int = 1, + grid_in: str = "equiangular", + grid_out: str = "equiangular", + bias: bool = True, + theta_cutoff: float | None = None, + ): + super().__init__( + in_channels, out_channels, kernel_shape, basis_type, groups, bias + ) + + self.nlat_in, self.nlon_in = in_shape + self.nlat_out, self.nlon_out = out_shape + + assert self.nlon_in % self.nlon_out == 0 + + if theta_cutoff is None: + self.theta_cutoff = torch.pi / float(self.nlat_out - 1) + else: + self.theta_cutoff = theta_cutoff + + if self.theta_cutoff <= 0.0: + raise ValueError("Error, theta_cutoff has to be positive.") + + idx, vals, _ = _precompute_convolution_tensor_s2( + in_shape, + out_shape, + self.filter_basis, + grid_in=grid_in, + grid_out=grid_out, + theta_cutoff=self.theta_cutoff, + transpose_normalization=False, + basis_norm_mode=basis_norm_mode, + merge_quadrature=True, + ) + + ker_idx = idx[0, ...].contiguous() + row_idx = idx[1, ...].contiguous() + col_idx = idx[2, ...].contiguous() + vals = vals.contiguous() + + self.register_buffer("psi_ker_idx", ker_idx, persistent=False) + self.register_buffer("psi_row_idx", row_idx, persistent=False) + self.register_buffer("psi_col_idx", col_idx, persistent=False) + self.register_buffer("psi_vals", vals, persistent=False) + + # Precompute banded FFT of 
psi for FFT-based contraction + psi_sparse = _get_psi( + self.kernel_size, + self.psi_idx, + self.psi_vals, + self.nlat_in, + self.nlon_in, + self.nlat_out, + self.nlon_out, + ) + psi_fft_conj, gather_idx = _precompute_psi_banded( + psi_sparse, self.nlat_in, self.nlon_in + ) + self.register_buffer("psi_fft_conj", psi_fft_conj, persistent=False) + self.register_buffer("psi_gather_idx", gather_idx, persistent=False) + + def extra_repr(self): + return ( + f"in_shape={(self.nlat_in, self.nlon_in)}, " + f"out_shape={(self.nlat_out, self.nlon_out)}, " + f"in_chans={self.groupsize * self.groups}, " + f"out_chans={self.weight.shape[0]}, " + f"filter_basis={self.filter_basis}, " + f"kernel_shape={self.kernel_shape}, " + f"theta_cutoff={self.theta_cutoff}, " + f"groups={self.groups}" + ) + + @property + def psi_idx(self): + return torch.stack( + [self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0 + ).contiguous() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = _disco_s2_contraction_fft( + x, + self.psi_fft_conj.to(x.device), + self.psi_gather_idx.to(x.device), + self.nlon_out, + ) + + B, C, K, H, W = x.shape + x = x.reshape(B, self.groups, self.groupsize, K, H, W) + + out = torch.einsum( + "bgckxy,gock->bgoxy", + x, + self.weight.reshape( + self.groups, -1, self.weight.shape[1], self.weight.shape[2] + ), + ).contiguous() + out = out.reshape(B, -1, H, W) + + if self.bias is not None: + out = out + self.bias.reshape(1, -1, 1, 1) + + return out diff --git a/fme/core/disco/_disco_utils.py b/fme/core/disco/_disco_utils.py new file mode 100644 index 000000000..8f35ca7a9 --- /dev/null +++ b/fme/core/disco/_disco_utils.py @@ -0,0 +1,136 @@ +# Forked from torch-harmonics (BSD-3-Clause) +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. +# SPDX-License-Identifier: BSD-3-Clause +# +# Modifications: stripped down to FFT-based contraction only, removing +# CUDA kernel wrappers and torch sparse-matrix fallbacks. 
+ + +import torch +from torch_harmonics.fft import irfft, rfft + + +def _get_psi( + kernel_size: int, + psi_idx: torch.Tensor, + psi_vals: torch.Tensor, + nlat_in: int, + nlon_in: int, + nlat_out: int, + nlon_out: int, + nlat_in_local: int | None = None, + nlat_out_local: int | None = None, + semi_transposed: bool | None = False, +): + """Creates a sparse tensor for spherical harmonic convolution operations.""" + nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in + nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out + + if semi_transposed: + tout = psi_idx[2] // nlon_out + pout = psi_idx[2] % nlon_out + pout = nlon_out - 1 - pout + tin = psi_idx[1] + idx = torch.stack([psi_idx[0], tout, tin * nlon_out + pout], dim=0) + psi = torch.sparse_coo_tensor( + idx, psi_vals, size=(kernel_size, nlat_out_local, nlat_in_local * nlon_out) + ).coalesce() + else: + psi = torch.sparse_coo_tensor( + psi_idx, + psi_vals, + size=(kernel_size, nlat_out_local, nlat_in_local * nlon_in), + ).coalesce() + return psi + + +def _precompute_psi_banded(psi_sparse: torch.Tensor, nlat_in: int, nlon: int): + """Build a banded dense representation of psi directly from sparse COO data. + + Instead of densifying the full (K, nlat_out, nlat_in, nlon) tensor, this + only stores the contiguous band of input latitudes that have nonzero entries + for each output latitude, reducing memory by ~nlat_in/band_width. 
+ + Returns: + ------- + psi_banded_fft_conj : (K, nlat_out, max_bw, nfreq) complex tensor + gather_idx : (nlat_out, max_bw) long tensor of input latitude indices + """ + K, nlat_out, _ = psi_sparse.shape + psi = psi_sparse.coalesce() + indices = psi.indices() # (3, nnz) + values = psi.values() # (nnz,) + + ker_idx = indices[0] + row_idx = indices[1] # output lat + col_idx = indices[2] # input_lat * nlon + lon + input_lat = col_idx // nlon + input_lon = col_idx % nlon + + # Find min/max input lat per output lat (across all kernel indices) + lat_min = torch.full((nlat_out,), nlat_in, dtype=torch.long) + lat_max = torch.full((nlat_out,), 0, dtype=torch.long) + lat_min.scatter_reduce_(0, row_idx, input_lat, reduce="amin") + lat_max.scatter_reduce_(0, row_idx, input_lat, reduce="amax") + + # Handle empty rows + empty = lat_min >= nlat_in + lat_min[empty] = 0 + lat_max[empty] = 0 + + max_bw = (lat_max - lat_min + 1).max().item() + + # Build banded tensor from sparse entries (no full densification) + psi_banded = torch.zeros(K, nlat_out, max_bw, nlon, dtype=values.dtype) + banded_lat = input_lat - lat_min[row_idx] + psi_banded[ker_idx, row_idx, banded_lat, input_lon] = values + + # Precompute FFT and gather index + psi_banded_fft_conj = rfft(psi_banded, dim=-1).conj() + gather_idx = lat_min.unsqueeze(1) + torch.arange(max_bw).unsqueeze(0) + gather_idx = gather_idx.clamp(max=nlat_in - 1) + + return psi_banded_fft_conj, gather_idx + + +def _disco_s2_contraction_fft( + x: torch.Tensor, + psi_fft_conj: torch.Tensor, + gather_idx: torch.Tensor, + nlon_out: int, +): + """FFT-based DISCO S2 contraction using banded psi representation. 
+ + Parameters + ---------- + x : (B, C, nlat_in, nlon_in) + psi_fft_conj : (K, nlat_out, bw, nfreq) + gather_idx : (nlat_out, bw) + nlon_out : int + + Returns: + ------- + (B, C, K, nlat_out, nlon_out) + """ + batch_size, n_chans, nlat_in, nlon_in = x.shape + kernel_size, nlat_out, bw, nfreq = psi_fft_conj.shape + pscale = nlon_in // nlon_out + + # FFT of input along longitude + X_f = rfft(x.to(torch.float32), dim=-1) # (B, C, nlat_in, nfreq) + X_f = X_f.reshape(batch_size * n_chans, nlat_in, nfreq) + + # Gather relevant input lats for each output lat + X_f_gathered = X_f[:, gather_idx, :] # (B*C, nlat_out, bw, nfreq) + + # Cross-correlate: einsum over band width and frequency, then irfft + Y_f = torch.einsum("kowf,bowf->bkof", psi_fft_conj, X_f_gathered) + + # Inverse FFT + y = irfft(Y_f, n=nlon_in, dim=-1) # (B*C, K, nlat_out, nlon_in) + + # Subsample for stride + y = y[..., ::pscale] # (B*C, K, nlat_out, nlon_out) + + y = y.reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out).contiguous() + return y.to(x.dtype) diff --git a/fme/core/disco/test_disco.py b/fme/core/disco/test_disco.py new file mode 100644 index 000000000..b190f3a5f --- /dev/null +++ b/fme/core/disco/test_disco.py @@ -0,0 +1,121 @@ +import torch +import torch_harmonics as th + +from fme.core.disco import DiscreteContinuousConvS2 + + +def test_forward_shape(): + """Basic smoke test: output has the expected shape.""" + img_shape = (16, 32) + in_channels = 4 + out_channels = 8 + conv = DiscreteContinuousConvS2( + in_channels, + out_channels, + in_shape=img_shape, + out_shape=img_shape, + kernel_shape=(3, 3), + basis_type="morlet", + basis_norm_mode="mean", + grid_in="equiangular", + grid_out="equiangular", + bias=True, + theta_cutoff=0.5, + ) + x = torch.randn(2, in_channels, *img_shape) + with torch.no_grad(): + y = conv(x) + assert y.shape == (2, out_channels, *img_shape) + + +def test_matches_torch_harmonics_reference(): + """Output matches the torch-harmonics sparse-matrix 
implementation.""" + torch.manual_seed(0) + img_shape = (16, 32) + in_channels = 4 + out_channels = 4 + conv_kwargs = dict( + in_shape=img_shape, + out_shape=img_shape, + kernel_shape=(3, 3), + basis_type="morlet", + basis_norm_mode="mean", + groups=1, + grid_in="equiangular", + grid_out="equiangular", + bias=False, + theta_cutoff=0.5, + ) + + fft_conv = DiscreteContinuousConvS2( + in_channels, + out_channels, + **conv_kwargs, # type: ignore[arg-type] + ) + ref_conv = th.DiscreteContinuousConvS2( + in_channels, + out_channels, + **conv_kwargs, # type: ignore[arg-type] + use_fft_contraction=False, + ) + + # Copy weights from fft_conv to ref_conv so they use the same parameters + ref_conv.weight.data.copy_(fft_conv.weight.data) + + x = torch.randn(2, in_channels, *img_shape) + with torch.no_grad(): + y_fft = fft_conv(x) + y_ref = ref_conv(x) + + torch.testing.assert_close(y_fft, y_ref, atol=1e-4, rtol=1e-4) + + +def test_grouped_convolution(): + """Test that grouped convolutions work.""" + img_shape = (16, 32) + in_channels = 8 + out_channels = 8 + groups = 4 + conv = DiscreteContinuousConvS2( + in_channels, + out_channels, + in_shape=img_shape, + out_shape=img_shape, + kernel_shape=(3, 3), + basis_type="morlet", + basis_norm_mode="mean", + groups=groups, + grid_in="equiangular", + grid_out="equiangular", + bias=False, + theta_cutoff=0.5, + ) + x = torch.randn(2, in_channels, *img_shape) + with torch.no_grad(): + y = conv(x) + assert y.shape == (2, out_channels, *img_shape) + + +def test_backward_pass(): + """Gradients flow through the FFT-based convolution.""" + img_shape = (16, 32) + in_channels = 4 + out_channels = 4 + conv = DiscreteContinuousConvS2( + in_channels, + out_channels, + in_shape=img_shape, + out_shape=img_shape, + kernel_shape=(3, 3), + basis_type="morlet", + basis_norm_mode="mean", + grid_in="equiangular", + grid_out="equiangular", + bias=True, + theta_cutoff=0.5, + ) + x = torch.randn(2, in_channels, *img_shape) + y = conv(x) + y.sum().backward() + 
assert conv.weight.grad is not None + assert conv.bias.grad is not None diff --git a/fme/core/distributed/non_distributed.py b/fme/core/distributed/non_distributed.py index 4814efb68..01a44d60f 100644 --- a/fme/core/distributed/non_distributed.py +++ b/fme/core/distributed/non_distributed.py @@ -115,7 +115,9 @@ def get_isht( return th.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid).float() def get_disco_conv_s2(self, *args, **kwargs) -> nn.Module: - return th.DiscreteContinuousConvS2(*args, **kwargs).float() + from fme.core.disco import DiscreteContinuousConvS2 + + return DiscreteContinuousConvS2(*args, **kwargs).float() def spatial_reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor: return tensor diff --git a/fme/core/distributed/torch_distributed.py b/fme/core/distributed/torch_distributed.py index e5f924edd..b802c307d 100644 --- a/fme/core/distributed/torch_distributed.py +++ b/fme/core/distributed/torch_distributed.py @@ -207,7 +207,9 @@ def get_isht( return th.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid).float() def get_disco_conv_s2(self, *args, **kwargs) -> nn.Module: - return th.DiscreteContinuousConvS2(*args, **kwargs).float() + from fme.core.disco import DiscreteContinuousConvS2 + + return DiscreteContinuousConvS2(*args, **kwargs).float() def spatial_reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor: return tensor From 4d0dfe27e8d91995fd514c3abf2b19fa5908f1a7 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 25 Mar 2026 22:23:18 +0000 Subject: [PATCH 25/31] Remove torch-harmonics dependency from fme/core/disco Fork the remaining torch-harmonics utilities (filter_basis, quadrature, cache, fft) into fme/core/disco/ so the DISCO convolution module is fully self-contained. torch-harmonics is still used in tests for reference comparison. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/core/disco/_cache.py | 24 +++ fme/core/disco/_convolution.py | 8 +- fme/core/disco/_disco_utils.py | 3 +- fme/core/disco/_fft.py | 51 ++++++ fme/core/disco/_filter_basis.py | 301 ++++++++++++++++++++++++++++++++ fme/core/disco/_quadrature.py | 101 +++++++++++ 6 files changed, 483 insertions(+), 5 deletions(-) create mode 100644 fme/core/disco/_cache.py create mode 100644 fme/core/disco/_fft.py create mode 100644 fme/core/disco/_filter_basis.py create mode 100644 fme/core/disco/_quadrature.py diff --git a/fme/core/disco/_cache.py b/fme/core/disco/_cache.py new file mode 100644 index 000000000..26f663b47 --- /dev/null +++ b/fme/core/disco/_cache.py @@ -0,0 +1,24 @@ +# Forked from torch-harmonics (BSD-3-Clause) +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. +# SPDX-License-Identifier: BSD-3-Clause + +import functools +from copy import deepcopy + + +def lru_cache(maxsize=20, typed=False, copy=False): + """LRU cache decorator with optional deep copying of cached results.""" + + def decorator(f): + cached_func = functools.lru_cache(maxsize=maxsize, typed=typed)(f) + + def wrapper(*args, **kwargs): + res = cached_func(*args, **kwargs) + if copy: + return deepcopy(res) + else: + return res + + return wrapper + + return decorator diff --git a/fme/core/disco/_convolution.py b/fme/core/disco/_convolution.py index ef7a6d540..a10824736 100644 --- a/fme/core/disco/_convolution.py +++ b/fme/core/disco/_convolution.py @@ -11,12 +11,12 @@ import torch import torch.nn as nn -from torch_harmonics import filter_basis as _filter_basis_module -from torch_harmonics.cache import lru_cache -from torch_harmonics.filter_basis import FilterBasis -from torch_harmonics.quadrature import precompute_latitudes, precompute_longitudes +from . 
import _filter_basis as _filter_basis_module +from ._cache import lru_cache from ._disco_utils import _disco_s2_contraction_fft, _get_psi, _precompute_psi_banded +from ._filter_basis import FilterBasis +from ._quadrature import precompute_latitudes, precompute_longitudes def _normalize_convolution_tensor_s2( diff --git a/fme/core/disco/_disco_utils.py b/fme/core/disco/_disco_utils.py index 8f35ca7a9..381a25d54 100644 --- a/fme/core/disco/_disco_utils.py +++ b/fme/core/disco/_disco_utils.py @@ -7,7 +7,8 @@ import torch -from torch_harmonics.fft import irfft, rfft + +from ._fft import irfft, rfft def _get_psi( diff --git a/fme/core/disco/_fft.py b/fme/core/disco/_fft.py new file mode 100644 index 000000000..5632908bb --- /dev/null +++ b/fme/core/disco/_fft.py @@ -0,0 +1,51 @@ +# Forked from torch-harmonics (BSD-3-Clause) +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. +# SPDX-License-Identifier: BSD-3-Clause + +import torch +import torch.fft as fft +import torch.nn as nn + + +def _pad_dim_right( + x: torch.Tensor, dim: int, target_size: int, value: float = 0.0 +) -> torch.Tensor: + """Pad tensor along a single dimension to target_size (right-side only).""" + ndim = x.ndim + dim = dim if dim >= 0 else ndim + dim + pad_amount = target_size - x.shape[dim] + pad_spec = [0] * (2 * ndim) + pad_spec[(ndim - 1 - dim) * 2 + 1] = pad_amount + return nn.functional.pad(x, tuple(pad_spec), value=value) + + +def rfft( + x: torch.Tensor, nmodes: int | None = None, dim: int = -1, **kwargs +) -> torch.Tensor: + """Real FFT with correct padding/truncation of modes.""" + if "n" in kwargs: + raise ValueError("The 'n' argument is not allowed. 
Use 'nmodes' instead.") + + x = fft.rfft(x, dim=dim, **kwargs) + + if nmodes is not None and nmodes > x.shape[dim]: + x = _pad_dim_right(x, dim, nmodes, value=0.0) + elif nmodes is not None and nmodes < x.shape[dim]: + x = x.narrow(dim, 0, nmodes) + + return x + + +def irfft( + x: torch.Tensor, n: int | None = None, dim: int = -1, **kwargs +) -> torch.Tensor: + """Inverse real FFT with Hermitian symmetry enforcement.""" + if n is None: + n = 2 * (x.size(dim) - 1) + + x[..., 0].imag = 0.0 + if (n % 2 == 0) and (n // 2 < x.size(dim)): + x[..., n // 2].imag = 0.0 + + x = fft.irfft(x, n=n, dim=dim, **kwargs) + return x diff --git a/fme/core/disco/_filter_basis.py b/fme/core/disco/_filter_basis.py new file mode 100644 index 000000000..299eba9b0 --- /dev/null +++ b/fme/core/disco/_filter_basis.py @@ -0,0 +1,301 @@ +# Forked from torch-harmonics (BSD-3-Clause) +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. +# SPDX-License-Identifier: BSD-3-Clause + +import abc +import math + +import torch + +from ._cache import lru_cache + + +def _circle_dist(x1: torch.Tensor, x2: torch.Tensor): + return torch.minimum( + torch.abs(x1 - x2), torch.abs(2 * math.pi - torch.abs(x1 - x2)) + ) + + +def _log_factorial(x: torch.Tensor): + return torch.lgamma(x + 1) + + +def _factorial(x: torch.Tensor): + return torch.exp(_log_factorial(x)) + + +class FilterBasis(metaclass=abc.ABCMeta): + """Abstract base class for a filter basis.""" + + kernel_shape: int | list[int] | tuple[int, ...] 
+ + def __init__(self, kernel_shape: int | list[int] | tuple[int, ...]): + self.kernel_shape = kernel_shape + + def __repr__(self): + class_name = self.__class__.__name__ + if hasattr(self, "extra_repr"): + return f"{class_name}({self.extra_repr()})" + else: + return f"{class_name}()" + + def extra_repr(self): + return f"kernel_shape={self.kernel_shape}" + + @property + @abc.abstractmethod + def kernel_size(self): + raise NotImplementedError + + @abc.abstractmethod + def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float): + raise NotImplementedError + + +@lru_cache(typed=True, copy=False) +def get_filter_basis( + kernel_shape: int | tuple[int, ...], basis_type: str +) -> FilterBasis: + """Factory function to generate the appropriate filter basis.""" + if basis_type == "piecewise linear": + return PiecewiseLinearFilterBasis(kernel_shape=kernel_shape) + elif basis_type == "morlet": + return MorletFilterBasis(kernel_shape=kernel_shape) + elif basis_type == "zernike": + return ZernikeFilterBasis(kernel_shape=kernel_shape) + else: + raise ValueError(f"Unknown basis_type {basis_type}") + + +class PiecewiseLinearFilterBasis(FilterBasis): + """Tensor-product basis on a disk from piecewise linear basis functions.""" + + kernel_shape: list[int] + + def __init__(self, kernel_shape: int | list[int] | tuple[int, ...]): + if isinstance(kernel_shape, int): + kernel_shape = [kernel_shape] + if len(kernel_shape) == 1: + kernel_shape = [kernel_shape[0], 1] + elif len(kernel_shape) != 2: + raise ValueError( + f"expected kernel_shape to be of length 1 or 2 " + f"but got {kernel_shape} instead." 
+ ) + super().__init__(kernel_shape=kernel_shape) + + @property + def kernel_size(self): + return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[ + 0 + ] % 2 + + def _compute_support_vals_isotropic( + self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float + ): + ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) + nr = self.kernel_shape[0] + dr = 2 * r_cutoff / (nr + 1) + + if nr % 2 == 1: + ir = ikernel * dr + else: + ir = (ikernel + 0.5) * dr + + iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff)) + vals = 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr + return iidx, vals + + def _compute_support_vals_anisotropic( + self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float + ): + ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) + nr = self.kernel_shape[0] + nphi = self.kernel_shape[1] + dr = 2 * r_cutoff / (nr + 1) + dphi = 2.0 * math.pi / nphi + + if nr % 2 == 1: + ir = ((ikernel - 1) // nphi + 1) * dr + iphi = ((ikernel - 1) % nphi) * dphi - math.pi + else: + ir = (ikernel // nphi + 0.5) * dr + iphi = (ikernel % nphi) * dphi - math.pi + + if nr % 2 == 1: + cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff) + cond_phi = (ikernel == 0) | (_circle_dist(phi, iphi).abs() <= dphi) + iidx = torch.argwhere(cond_r & cond_phi) + dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() + dist_phi = _circle_dist(phi[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0]) + vals = 1 - dist_r / dr + vals *= torch.where((iidx[:, 0] > 0), (1 - dist_phi / dphi), 1.0) + else: + rn = -r + phin = torch.where(phi + math.pi >= math.pi, phi - math.pi, phi + math.pi) + cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff) + cond_phi = _circle_dist(phi, iphi).abs() <= dphi + cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff) + cond_phin = _circle_dist(phin, iphi) <= dphi + iidx = torch.argwhere((cond_r & cond_phi) | (cond_rn & cond_phin)) + + dist_r = (r[iidx[:, 1], iidx[:, 2]] 
- ir[iidx[:, 0], 0, 0]).abs() + dist_phi = _circle_dist(phi[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0]) + dist_rn = (rn[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() + dist_phin = _circle_dist( + phin[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0] + ) + vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_r / dr) + vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phi / dphi) + valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr) + valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * ( + 1 - dist_phin / dphi + ) + vals += valsn + + return iidx, vals + + def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float): + if self.kernel_shape[1] > 1: + return self._compute_support_vals_anisotropic(r, phi, r_cutoff=r_cutoff) + else: + return self._compute_support_vals_isotropic(r, phi, r_cutoff=r_cutoff) + + +class MorletFilterBasis(FilterBasis): + """Morlet-style filter basis on the disk.""" + + kernel_shape: list[int] + + def __init__(self, kernel_shape: int | list[int] | tuple[int, ...]): + if isinstance(kernel_shape, int): + kernel_shape = [kernel_shape, kernel_shape] + if len(kernel_shape) != 2: + raise ValueError( + f"expected kernel_shape to be of length 2 " + f"but got {kernel_shape} instead." 
+ ) + super().__init__(kernel_shape=kernel_shape) + + @property + def kernel_size(self): + return self.kernel_shape[0] * self.kernel_shape[1] + + def hann_window(self, r: torch.Tensor, width: float = 1.0): + return torch.cos(0.5 * torch.pi * r / width) ** 2 + + def compute_support_vals( + self, + r: torch.Tensor, + phi: torch.Tensor, + r_cutoff: float, + width: float = 1.0, + ): + ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) + nkernel = ikernel % self.kernel_shape[1] + mkernel = ikernel // self.kernel_shape[1] + + iidx = torch.argwhere( + (r <= r_cutoff) + & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device) + ) + + r = r[iidx[:, 1], iidx[:, 2]] / r_cutoff + phi = phi[iidx[:, 1], iidx[:, 2]] + x = r * torch.sin(phi) + y = r * torch.cos(phi) + n = nkernel[iidx[:, 0], 0, 0] + m = mkernel[iidx[:, 0], 0, 0] + + harmonic = torch.where( + n % 2 == 1, + torch.sin(torch.ceil(n / 2) * math.pi * x / width), + torch.cos(torch.ceil(n / 2) * math.pi * x / width), + ) + harmonic *= torch.where( + m % 2 == 1, + torch.sin(torch.ceil(m / 2) * math.pi * y / width), + torch.cos(torch.ceil(m / 2) * math.pi * y / width), + ) + + vals = self.hann_window(r, width=width) * harmonic + return iidx, vals + + +class ZernikeFilterBasis(FilterBasis): + """Zernike polynomial basis defined on the disk.""" + + kernel_shape: int + + def __init__(self, kernel_shape: int | list[int] | tuple[int, ...]): + if isinstance(kernel_shape, list | tuple): + kernel_shape = kernel_shape[0] + if not isinstance(kernel_shape, int): + raise ValueError( + f"expected kernel_shape to be an integer " + f"but got {kernel_shape} instead." 
+ ) + super().__init__(kernel_shape=kernel_shape) + + @property + def kernel_size(self): + return (self.kernel_shape * (self.kernel_shape + 1)) // 2 + + def zernikeradial(self, r: torch.Tensor, n: torch.Tensor, m: torch.Tensor): + out = torch.zeros_like(r) + bound = (n - m) // 2 + 1 + max_bound = bound.max().item() + + for k in range(max_bound): + inc = ( + (-1) ** k + * _factorial(n - k) + * r ** (n - 2 * k) + / ( + math.factorial(k) + * _factorial((n + m) // 2 - k) + * _factorial((n - m) // 2 - k) + ) + ) + out += torch.where(k < bound, inc, 0.0) + return out + + def zernikepoly( + self, + r: torch.Tensor, + phi: torch.Tensor, + n: torch.Tensor, + l: torch.Tensor, # noqa: E741 + ): + m = 2 * l - n + return torch.where( + m < 0, + self.zernikeradial(r, n, -m) * torch.sin(m * phi), + self.zernikeradial(r, n, m) * torch.cos(m * phi), + ) + + def compute_support_vals( + self, + r: torch.Tensor, + phi: torch.Tensor, + r_cutoff: float, + width: float = 0.25, + ): + ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) + iidx = torch.argwhere( + (r <= r_cutoff) + & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device) + ) + + nshifts = torch.arange(self.kernel_shape, device=r.device) + nshifts = (nshifts + 1) * nshifts // 2 + nkernel = torch.searchsorted(nshifts, ikernel, right=True) - 1 + lkernel = ikernel - nshifts[nkernel] + + r = r[iidx[:, 1], iidx[:, 2]] / r_cutoff + phi = phi[iidx[:, 1], iidx[:, 2]] + n = nkernel[iidx[:, 0], 0, 0] + l = lkernel[iidx[:, 0], 0, 0] # noqa: E741 + + vals = self.zernikepoly(r, phi, n, l) + return iidx, vals diff --git a/fme/core/disco/_quadrature.py b/fme/core/disco/_quadrature.py new file mode 100644 index 000000000..dac51540d --- /dev/null +++ b/fme/core/disco/_quadrature.py @@ -0,0 +1,101 @@ +# Forked from torch-harmonics (BSD-3-Clause) +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Subset: only the functions needed by DISCO convolution precomputation. + +import math + +import numpy as np +import torch + +from ._cache import lru_cache + + +def _trapezoidal_weights( + n: int, a: float = -1.0, b: float = 1.0 +) -> tuple[torch.Tensor, torch.Tensor]: + xlg = torch.as_tensor(np.linspace(a, b, n, endpoint=True)) + wlg = (b - a) / n * torch.ones(n, requires_grad=False) + wlg[0] *= 0.5 + wlg[-1] *= 0.5 + return xlg, wlg + + +def _legendre_gauss_weights( + n: int, a: float = -1.0, b: float = 1.0 +) -> tuple[torch.Tensor, torch.Tensor]: + xlg, wlg = np.polynomial.legendre.leggauss(n) + xlg = torch.as_tensor(xlg).clone() + wlg = torch.as_tensor(wlg).clone() + xlg = (b - a) * 0.5 * xlg + (b + a) * 0.5 + wlg = wlg * (b - a) * 0.5 + return xlg, wlg + + +def _clenshaw_curtiss_weights( + n: int, a: float = -1.0, b: float = 1.0 +) -> tuple[torch.Tensor, torch.Tensor]: + assert n > 1 + + tcc = torch.cos( + torch.linspace(math.pi, 0, n, dtype=torch.float64, requires_grad=False) + ) + + if n == 2: + wcc = torch.as_tensor([1.0, 1.0], dtype=torch.float64) + else: + n1 = n - 1 + N = torch.arange(1, n1, 2, dtype=torch.float64) + ll = len(N) + m = n1 - ll + + v = torch.cat( + [ + 2 / N / (N - 2), + 1 / N[-1:], + torch.zeros(m, dtype=torch.float64, requires_grad=False), + ] + ) + v = 0 - v[:-1] - torch.flip(v[1:], dims=(0,)) + + g0 = -torch.ones(n1, dtype=torch.float64, requires_grad=False) + g0[ll] = g0[ll] + n1 + g0[m] = g0[m] + n1 + g = g0 / (n1**2 - 1 + (n1 % 2)) + wcc = torch.fft.ifft(v + g).real + wcc = torch.cat((wcc, wcc[:1])) + + tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5 + wcc = wcc * (b - a) * 0.5 + return tcc, wcc + + +def _precompute_quadrature_weights( + n: int, grid: str = "equiangular", a: float = -1.0, b: float = 1.0 +) -> tuple[torch.Tensor, torch.Tensor]: + if grid == "equidistant": + return _trapezoidal_weights(n, a=a, b=b) + elif grid == "legendre-gauss": + return _legendre_gauss_weights(n, a=a, b=b) + 
elif grid == "equiangular": + return _clenshaw_curtiss_weights(n, a=a, b=b) + else: + raise ValueError(f"Unknown grid type {grid}") + + +@lru_cache(typed=True, copy=True) +def precompute_longitudes(nlon: int) -> torch.Tensor: + return torch.linspace( + 0, 2 * math.pi, nlon + 1, dtype=torch.float64, requires_grad=False + )[:-1] + + +@lru_cache(typed=True, copy=True) +def precompute_latitudes( + nlat: int, grid: str = "equiangular" +) -> tuple[torch.Tensor, torch.Tensor]: + xlg, wlg = _precompute_quadrature_weights(nlat, grid=grid, a=-1.0, b=1.0) + lats = torch.flip(torch.arccos(xlg), dims=(0,)).clone() + wlg = torch.flip(wlg, dims=(0,)).clone() + return lats, wlg From f84aae39ecde8ab4278f57d0da2799dc6b72b959 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 25 Mar 2026 22:26:43 +0000 Subject: [PATCH 26/31] Replace isotropic Morlet monkey-patch with native filter basis Move IsotropicMorletFilterBasis from the torch-harmonics monkey-patch module into fme/core/disco/_filter_basis.py and register it directly in get_filter_basis. Remove the now-unnecessary isotropic_filter.py and its import from localnet.py. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/core/disco/_filter_basis.py | 62 +++++++++++ .../conditional_sfno/isotropic_filter.py | 100 ------------------ fme/core/models/conditional_sfno/localnet.py | 3 - 3 files changed, 62 insertions(+), 103 deletions(-) delete mode 100644 fme/core/models/conditional_sfno/isotropic_filter.py diff --git a/fme/core/disco/_filter_basis.py b/fme/core/disco/_filter_basis.py index 299eba9b0..372c9ca1a 100644 --- a/fme/core/disco/_filter_basis.py +++ b/fme/core/disco/_filter_basis.py @@ -61,6 +61,8 @@ def get_filter_basis( return PiecewiseLinearFilterBasis(kernel_shape=kernel_shape) elif basis_type == "morlet": return MorletFilterBasis(kernel_shape=kernel_shape) + elif basis_type == "isotropic morlet": + return IsotropicMorletFilterBasis(kernel_shape=kernel_shape) elif basis_type == "zernike": return ZernikeFilterBasis(kernel_shape=kernel_shape) else: @@ -222,6 +224,66 @@ def compute_support_vals( return iidx, vals +class IsotropicMorletFilterBasis(FilterBasis): + """Morlet-style filter basis using only radial modes. + + Each basis function is a product of a Hann radial window and a 1-D + Fourier harmonic in the normalised radial coordinate ``r / r_cutoff``. + Because none of the basis functions depend on the azimuthal angle + ``phi``, any learned linear combination is guaranteed to be isotropic + (radially symmetric). + + ``kernel_shape`` is a single integer giving the number of radial modes. + If a tuple is provided, only the first element is used. + """ + + kernel_shape: int + + def __init__(self, kernel_shape: int | list[int] | tuple[int, ...]): + if isinstance(kernel_shape, list | tuple): + kernel_shape = kernel_shape[0] + if not isinstance(kernel_shape, int): + raise ValueError( + f"expected kernel_shape to be an integer " + f"but got {kernel_shape} instead." 
+ ) + super().__init__(kernel_shape=kernel_shape) + + @property + def kernel_size(self) -> int: + return self.kernel_shape + + def compute_support_vals( + self, + r: torch.Tensor, + phi: torch.Tensor, + r_cutoff: float, + width: float = 1.0, + ): + ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) + + iidx = torch.argwhere( + (r <= r_cutoff) + & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device) + ) + + r_norm = r[iidx[:, 1], iidx[:, 2]] / r_cutoff + n = ikernel[iidx[:, 0], 0, 0] + + # Radial Fourier modes: cos/sin pattern identical to Morlet but in r. + harmonic = torch.where( + n % 2 == 1, + torch.sin(torch.ceil(n / 2) * math.pi * r_norm / width), + torch.cos(torch.ceil(n / 2) * math.pi * r_norm / width), + ) + + # Hann radial envelope + window = torch.cos(0.5 * torch.pi * r_norm / width) ** 2 + vals = window * harmonic + + return iidx, vals + + class ZernikeFilterBasis(FilterBasis): """Zernike polynomial basis defined on the disk.""" diff --git a/fme/core/models/conditional_sfno/isotropic_filter.py b/fme/core/models/conditional_sfno/isotropic_filter.py deleted file mode 100644 index df9e2c5ec..000000000 --- a/fme/core/models/conditional_sfno/isotropic_filter.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Isotropic Morlet filter basis and monkey-patch for torch-harmonics. - -This module defines an isotropic variant of the Morlet filter basis where -all basis functions depend only on radial distance, not on azimuthal angle. -It monkey-patches ``torch_harmonics.filter_basis.get_filter_basis`` (and the -copy already imported into ``torch_harmonics.disco.convolution``) so that -the basis type ``"isotropic morlet"`` is available everywhere DISCO -convolutions are constructed. - -This is a temporary measure until the change is contributed upstream or the -dependency is forked. 
-""" - -import math - -import torch -from torch_harmonics import filter_basis as _fb -from torch_harmonics.disco import convolution as _conv - - -class IsotropicMorletFilterBasis(_fb.FilterBasis): - """Morlet-style filter basis using only radial modes. - - Each basis function is a product of a Hann radial window and a 1-D - Fourier harmonic in the normalised radial coordinate ``r / r_cutoff``. - Because none of the basis functions depend on the azimuthal angle - ``phi``, any learned linear combination is guaranteed to be isotropic - (radially symmetric). - - ``kernel_shape`` is a single integer giving the number of radial modes. - If a tuple is provided, only the first element is used. - """ - - def __init__( - self, - kernel_shape: int | tuple[int] | tuple[int, int], - ): - if isinstance(kernel_shape, list | tuple): - kernel_shape = kernel_shape[0] - if not isinstance(kernel_shape, int): - raise ValueError( - f"expected kernel_shape to be an integer but got " - f"{kernel_shape} instead." - ) - super().__init__(kernel_shape=kernel_shape) - - @property - def kernel_size(self) -> int: - return self.kernel_shape - - def compute_support_vals( - self, - r: torch.Tensor, - phi: torch.Tensor, - r_cutoff: float, - width: float = 1.0, - ): - ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) - - iidx = torch.argwhere( - (r <= r_cutoff) - & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device) - ) - - r_norm = r[iidx[:, 1], iidx[:, 2]] / r_cutoff - n = ikernel[iidx[:, 0], 0, 0] - - # Radial Fourier modes: cos/sin pattern identical to Morlet but in r. 
- harmonic = torch.where( - n % 2 == 1, - torch.sin(torch.ceil(n / 2) * math.pi * r_norm / width), - torch.cos(torch.ceil(n / 2) * math.pi * r_norm / width), - ) - - # Hann radial envelope - window = torch.cos(0.5 * torch.pi * r_norm / width) ** 2 - vals = window * harmonic - - return iidx, vals - - -# --------------------------------------------------------------------------- -# Monkey-patch get_filter_basis -# --------------------------------------------------------------------------- - -_original_get_filter_basis = _fb.get_filter_basis - - -@_fb.lru_cache(typed=True, copy=False) -def _patched_get_filter_basis( - kernel_shape: int | tuple[int] | tuple[int, int], - basis_type: str, -) -> _fb.FilterBasis: - if basis_type == "isotropic morlet": - return IsotropicMorletFilterBasis(kernel_shape=kernel_shape) - return _original_get_filter_basis(kernel_shape, basis_type) - - -_fb.get_filter_basis = _patched_get_filter_basis -_conv.get_filter_basis = _patched_get_filter_basis diff --git a/fme/core/models/conditional_sfno/localnet.py b/fme/core/models/conditional_sfno/localnet.py index e0806a5a8..96463a478 100644 --- a/fme/core/models/conditional_sfno/localnet.py +++ b/fme/core/models/conditional_sfno/localnet.py @@ -9,9 +9,6 @@ from fme.core.distributed import Distributed from .initialization import trunc_normal_ -from .isotropic_filter import ( - IsotropicMorletFilterBasis as _IsotropicMorletFilterBasis, # noqa: F401 (monkey-patches torch-harmonics on import) -) from .layers import MLP, ConditionalLayerNorm, Context, ContextConfig from .lora import LoRAConv2d from .sfnonet import DiscreteContinuousConvS2, NoLayerNorm, _compute_cutoff_radius From 138c72d5d2eb945f4bc2d023cd92fb29e0e065dc Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Thu, 26 Mar 2026 14:52:15 +0000 Subject: [PATCH 27/31] Fix DISCO test compatibility with torch-harmonics 0.8.0 Remove use_fft_contraction kwarg that only exists in torch-harmonics 0.8.1+, while CI pins 0.8.0. 
Omitting it gives identical behavior on both versions (sparse-matrix path is the default). Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/core/disco/test_disco.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fme/core/disco/test_disco.py b/fme/core/disco/test_disco.py index b190f3a5f..5a1729560 100644 --- a/fme/core/disco/test_disco.py +++ b/fme/core/disco/test_disco.py @@ -56,7 +56,6 @@ def test_matches_torch_harmonics_reference(): in_channels, out_channels, **conv_kwargs, # type: ignore[arg-type] - use_fft_contraction=False, ) # Copy weights from fft_conv to ref_conv so they use the same parameters From a0243401d2fccc64c762cb834838821421416c9b Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Thu, 26 Mar 2026 16:19:51 +0000 Subject: [PATCH 28/31] Add arxiv reference to AnkurLocalNet documentation Reference: https://arxiv.org/html/2602.16090v1 Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/ace/registry/local_net.py | 2 ++ fme/core/models/conditional_sfno/ankur.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/fme/ace/registry/local_net.py b/fme/ace/registry/local_net.py index 1f4753d59..2e5b8fb51 100644 --- a/fme/ace/registry/local_net.py +++ b/fme/ace/registry/local_net.py @@ -52,6 +52,8 @@ class AnkurLocalNetBuilder(ModuleConfig): A simple 3-hidden-layer MLP that optionally uses a DISCO convolution for the first layer and a learned positional embedding. + Reference: https://arxiv.org/html/2602.16090v1 + Attributes: embed_dim: Dimension of the hidden layers. use_disco_encoder: Whether to use a DISCO convolution for the first diff --git a/fme/core/models/conditional_sfno/ankur.py b/fme/core/models/conditional_sfno/ankur.py index c90b7e69e..5cf94203f 100644 --- a/fme/core/models/conditional_sfno/ankur.py +++ b/fme/core/models/conditional_sfno/ankur.py @@ -19,6 +19,8 @@ class AnkurLocalNetConfig: Replicates the diagnostic MLP architecture from Ankur's ColumnDiagnosticSphericalFourierNeuralOperatorNet. 
+ Reference: https://arxiv.org/html/2602.16090v1 + Attributes: type: Discriminator tag for union deserialization. embed_dim: Dimension of the hidden layers. @@ -111,6 +113,8 @@ class AnkurLocalNet(nn.Module): DISCO convolution for the first layer and a learned positional embedding. This is a drop-in replacement for LocalNet with the same forward signature. + Reference: https://arxiv.org/html/2602.16090v1 + Args: params: Model configuration. See ``AnkurLocalNetConfig`` for details. img_shape: Spatial dimensions (lat, lon) of the input data. From e2b75592f6f3ef55ecdeb555443f54c0f08977ee Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Thu, 26 Mar 2026 17:29:15 +0000 Subject: [PATCH 29/31] Fix DDP crash on integer DISCO index buffers DDP's _broadcast_coalesced cannot handle non-float dtypes. Store integer index tensors (psi_ker_idx, psi_row_idx, psi_col_idx, psi_gather_idx) as plain attributes instead of registered buffers, with an _apply override to ensure they follow device transfers. These tensors are deterministic across ranks so sync is unnecessary. Co-Authored-By: Claude Opus 4.6 (1M context) --- fme/core/disco/_convolution.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/fme/core/disco/_convolution.py b/fme/core/disco/_convolution.py index a10824736..a1a522825 100644 --- a/fme/core/disco/_convolution.py +++ b/fme/core/disco/_convolution.py @@ -304,9 +304,14 @@ def __init__( col_idx = idx[2, ...].contiguous() vals = vals.contiguous() - self.register_buffer("psi_ker_idx", ker_idx, persistent=False) - self.register_buffer("psi_row_idx", row_idx, persistent=False) - self.register_buffer("psi_col_idx", col_idx, persistent=False) + # Store index tensors as plain attributes instead of buffers so that + # DDP's _broadcast_coalesced (which cannot handle integer dtypes) does + # not attempt to synchronise them. They are deterministic across ranks + # so synchronisation is unnecessary. 
Device tracking is handled by the + # _apply override below. + self.psi_ker_idx = ker_idx + self.psi_row_idx = row_idx + self.psi_col_idx = col_idx self.register_buffer("psi_vals", vals, persistent=False) # Precompute banded FFT of psi for FFT-based contraction @@ -323,7 +328,18 @@ def __init__( psi_sparse, self.nlat_in, self.nlon_in ) self.register_buffer("psi_fft_conj", psi_fft_conj, persistent=False) - self.register_buffer("psi_gather_idx", gather_idx, persistent=False) + self.psi_gather_idx = gather_idx + + def _apply(self, fn, recurse=True): + # Required because integer index tensors are stored as plain attributes + # (not buffers) to avoid DDP's _broadcast_coalesced failing on non-float + # dtypes. This ensures they follow .to() / .cuda() device transfers. + super()._apply(fn, recurse=recurse) + self.psi_ker_idx = fn(self.psi_ker_idx) + self.psi_row_idx = fn(self.psi_row_idx) + self.psi_col_idx = fn(self.psi_col_idx) + self.psi_gather_idx = fn(self.psi_gather_idx) + return self def extra_repr(self): return ( From 119c53173f427e2b6124477eaf7d867c01585822 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Thu, 26 Mar 2026 17:37:30 +0000 Subject: [PATCH 30/31] Add initial draft of vector filter basis notes Co-Authored-By: Claude Opus 4.6 (1M context) --- vector_filter_basis.md | 237 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 237 insertions(+) create mode 100644 vector_filter_basis.md diff --git a/vector_filter_basis.md b/vector_filter_basis.md new file mode 100644 index 000000000..01fce5452 --- /dev/null +++ b/vector_filter_basis.md @@ -0,0 +1,237 @@ +# DISCO Convolution with Vector-Typed Hidden Features + +## Overview + +This document specifies an approach for handling directional (vector-valued) data in DISCO convolution on the sphere. 
Rather than constructing a continuous directional filter basis (as in the cube-gauged approach), this design uses **isotropic radial filters** combined with a **frame-rotation matrix** that correctly transforms vector inputs from each input point's local meridian frame into the output point's meridian frame. + +The key insight is that the meridian frame's polar singularity does not pose a problem: the physics is mostly rotationally invariant, so it is acceptable for the reference orientation to be discontinuous at the poles. What matters is that vector-valued inputs are correctly handled — that is, the convolution accounts for the fact that neighboring points have different definitions of "north" and "east." + +The frame-rotation angle between two points depends only on `(lat_out, lat_in, lon_in - lon_out)`, making it invariant to translations along longitude. This means the rotation can be baked into the filter tensor without breaking the FFT-based cross-convolution optimization. + +## Hidden Representation + +The network maintains three types of hidden feature channels at each grid point: + +- **Scalar channels** `s(lat, lon)`: Frame-independent quantities (temperature, pressure, geopotential, etc.). No special treatment needed in convolution. +- **u channels** `u(lat, lon)`: East component of a vector in the local meridian frame (eastward wind, etc.). +- **v channels** `v(lat, lon)`: North component of a vector in the local meridian frame (northward wind, etc.). + +Each `(u, v)` pair represents a tangent vector at the grid point, expressed in the local geographic frame where "east" = direction of increasing longitude and "north" = direction of decreasing colatitude (toward the geographic north pole). + +Initially, `u` and `v` channels are the physical eastward and northward components of the input data. 
Through the network, new vector channels can be created via pointwise interactions (e.g., scalar times vector), and the convolution propagates them while correctly handling the frame geometry. + +## The Meridian Frame and Its Discontinuity + +The meridian frame is defined everywhere except at the geographic poles, where all meridians converge and "north" is undefined. At the poles, the frame has a rotational discontinuity: approaching the pole from different longitudes gives different limiting orientations. + +This discontinuity is acceptable because: + +1. **The physics is approximately rotationally invariant.** The dominant dynamics (advection, pressure gradients, wave propagation) do not have a preferred horizontal direction. The Coriolis force breaks this symmetry, but it varies smoothly with latitude, not with the frame convention. +2. **The discontinuity is at a measure-zero set of points.** On a discrete grid, the poles are either absent (Gaussian grid) or have a single row of points. The network can learn to handle these points through the local structure of the filter. +3. **Vector channels are correctly rotated within the filter support.** The frame rotation baked into the filter tensor ensures that each input vector is seen in the output point's frame, regardless of how different their meridian frames are. + +## Frame Rotation in the Convolution + +### The Problem + +When a filter centered at an output point looks at a neighboring input point, the input's `(u, v)` components are measured in the input point's meridian frame. If the two points have different "north" directions (which they always do, unless they share a meridian), the raw `u` and `v` values cannot be directly combined — they would be mixing vectors expressed in different coordinate systems. + +### The Solution + +Before the filter integrates over input points, each input vector is rotated from the input point's meridian frame into the output point's meridian frame. 
This rotation is characterized by a single angle `γ(lat_out, lat_in, Δlon)` — the angle between the two points' north directions, as measured via parallel transport along the connecting geodesic. + +The rotation acts as: + +``` +[u_rotated] [cos γ −sin γ] [u_in] +[v_rotated] = [sin γ cos γ] [v_in] +``` + +where `(u_in, v_in)` are the input vector components in the input's meridian frame, and `(u_rotated, v_rotated)` are the same vector expressed in the output's meridian frame. + +### Longitude Invariance + +The rotation angle `γ` depends on `(lat_out, lat_in, Δlon)` where `Δlon = lon_in − lon_out`. It does **not** depend on absolute longitude. This is because the relative geometry of any two points — their geodesic distance, bearing angles, and frame rotation — is invariant under rotations about the polar axis. + +This means `ψ_k(r) · cos(γ)` and `ψ_k(r) · sin(γ)` are both functions of the longitude difference, just like the scalar filter `ψ_k(r)`. The FFT-based cross-convolution optimization applies to all of them. + +## Computing the Frame Rotation Angle + +### Setup + +The DISCO precomputation uses a YZY Euler rotation (rotation about the y-axis by `α = −θ_out`, where `θ_out` is the output colatitude) to bring the output point to the north pole. For each input point at colatitude `γ` and longitude `λ`, the rotated Cartesian position is: + +``` +x' = cos(α) cos(λ) sin(γ) + sin(α) cos(γ) +y' = sin(λ) sin(γ) +z' = −sin(α) cos(λ) sin(γ) + cos(α) cos(γ) +``` + +The code already computes `θ = arccos(z')` and `φ = atan2(y', x')` from this. + +### Computing γ + +The frame rotation angle `γ = φ − β`, where `β` is the angle that the input point's geographic north makes with the local "toward-pole" direction in the Euler-rotated frame. 
The components of `β` are: + +``` +cos(β) = ê_N' · (−θ̂) +sin(β) = ê_N' · φ̂ +``` + +where: + +- `ê_N'` is the input point's geographic north direction, Euler-rotated to the new frame: + ``` + ê_N'_x = −cos(α) cos(γ) cos(λ) + sin(α) sin(γ) + ê_N'_y = −cos(γ) sin(λ) + ê_N'_z = sin(α) cos(γ) cos(λ) + cos(α) sin(γ) + ``` +- `−θ̂` and `φ̂` are the local basis vectors at the rotated position `(θ, φ)`: + ``` + −θ̂ = (−cos θ cos φ, −cos θ sin φ, sin θ) + φ̂ = (−sin φ, cos φ, 0 ) + ``` + +Then: + +``` +cos(γ) = cos(φ − β) = cos φ · cos β + sin φ · sin β +sin(γ) = sin(φ − β) = sin φ · cos β − cos φ · sin β +``` + +These quantities can be computed in the existing precomputation loop alongside `θ` and `φ`, adding only a few vector dot products per support point. + +### Geometric Meaning + +The angle `γ` is the total rotation from the input's meridian frame to the output's meridian frame, accounting for: + +1. The input's geographic north direction relative to the Euler-frame local basis (the angle `β`) +2. The azimuthal position of the input point relative to the output point's meridian (the angle `φ`) + +At the output point's location (where `θ → 0`), the Euler frame's `φ = 0` direction corresponds to the output point's geographic south, and `φ = π/2` corresponds to geographic east. The parallel transport of the input's north vector to the output point arrives at Cartesian angle `π + γ` from the x-axis in the output's tangent plane (since north is the `−x` direction). + +## Convolution Operation + +### Filter Tensor for Scalar Channels + +Unchanged from the current implementation. The filter basis functions `ψ_k(r)` depend only on geodesic distance. The precomputed banded FFT tensor is: + +``` +psi_scalar_fft: shape (K, nlat_out, max_bw, nfreq) +``` + +The contraction gives: + +``` +conv_scalar[b, c, k, lat, lon] = Σ_in ψ_k(r) · s_c(lat_in, lon_in) +``` + +### Filter Tensor for Vector Channels + +For vector channels, the filter tensor incorporates the frame rotation. 
Two banded FFT tensors are precomputed: + +``` +psi_cos_fft: shape (K, nlat_out, max_bw, nfreq) — FFT of ψ_k(r) · cos(γ) +psi_sin_fft: shape (K, nlat_out, max_bw, nfreq) — FFT of ψ_k(r) · sin(γ) +``` + +These are built from the same radial filter basis `ψ_k(r)`, multiplied by `cos(γ)` and `sin(γ)` at each support point before banding and FFT. + +### Vector Convolution + +For each vector input pair `(u_c, v_c)` with radial basis `k`, the frame-rotated convolution gives the vector in the output point's frame: + +``` +conv_u[b, c, k, lat, lon] = contraction(psi_cos, u_c) − contraction(psi_sin, v_c) +conv_v[b, c, k, lat, lon] = contraction(psi_sin, u_c) + contraction(psi_cos, v_c) +``` + +where `contraction(psi, x)` denotes the standard FFT-based DISCO contraction (`_disco_s2_contraction_fft`). + +This requires two contraction calls for vector channels (one with `psi_cos`, one with `psi_sin`), each processing all `u` and `v` channels together. The reassembly into `(conv_u, conv_v)` is a cheap pointwise operation. + +### Weight Contraction + +After convolution, the intermediate features are: + +- From scalar inputs: `conv_scalar` with shape `(B, N_scalar, K, H, W)` +- From vector inputs: `conv_u` and `conv_v`, each with shape `(B, N_vec, K, H, W)` + +All intermediate features are in the output point's meridian frame, so they can be freely mixed by the learned weight tensor. The weight produces output scalar, u, and v channels: + +- **Scalar output from scalar input:** Standard — weight contracts over `(c_in, k)`. +- **Scalar output from vector input:** Weight contracts over `(c_in, k)` separately for `conv_u` and `conv_v`, producing a scalar from vector components (analogous to divergence or a directional projection). +- **Vector output from vector input:** Weight contracts over `(c_in, k)` for `conv_u` and `conv_v` to produce output `u` and `v` components. The weight can independently scale and mix the u/v components. 
+- **Vector output from scalar input:** The weight assigns a scalar intermediate feature to a vector output channel. This implicitly creates a preferred direction in the meridian frame, and is therefore not rotationally equivariant. But if `u_out` and `v_out` channels are created from the same scalar feature with appropriate weights, the network can learn latitude-dependent directional structures (e.g., patterns aligned with Coriolis deflection). + +In practice, the simplest implementation treats all intermediate features as a flat channel dimension of size `N_scalar + 2 · N_vec` (with `conv_u` and `conv_v` occupying separate channel slots), and uses a single weight tensor of shape `(C_out, C_in_effective, K)` that mixes everything. The type labels (scalar, u, v) are bookkeeping for the next layer's convolution, not a constraint on the weight. + +## Nonlinearities and Vector-Scalar Interactions + +Between convolution layers, pointwise nonlinearities must respect the typed channel structure. All operations below are pointwise (per grid point) and therefore frame-consistent, since all channels at a given grid point share the same meridian frame. + +### Safe Nonlinearities + +- **Scalar channels:** Any standard nonlinearity (ReLU, GELU, etc.) applied independently. +- **Vector norm → scalar:** `n = sqrt(u² + v²)` is frame-invariant and produces a true scalar. +- **Norm-gated vectors:** `σ(n) · (u, v) / n` applies a nonlinearity to the vector magnitude while preserving direction. +- **Scalar × vector → vector:** `s · (u, v)` multiplies a scalar channel with a vector pair to produce a new vector pair. This is how the network creates new vector features from scalar-vector interactions. +- **Vector dot product → scalar:** `u₁·u₂ + v₁·v₂` produces a frame-invariant scalar from two vector pairs. +- **2D cross product → scalar:** `u₁·v₂ − u₂·v₁` produces a pseudo-scalar (changes sign under reflection but is frame-rotation invariant). 
+ +### Unsafe Nonlinearities + +- **Independent nonlinearities on u and v:** Applying `ReLU(u)`, `ReLU(v)` independently breaks rotational structure — the result depends on the frame orientation. +- **Adding scalar to vector component:** `u + s` breaks vector transformation rules. + +## Performance Characteristics + +### Compute Cost + +For a layer with `N_s` scalar input channels and `N_v` vector input pairs: + +| Operation | Filter calls | Relative to scalar-only | +|---|---|---| +| Scalar convolution | 1 call with `psi_scalar` over `N_s` channels | baseline | +| Vector convolution | 2 calls (`psi_cos`, `psi_sin`) over `2·N_v` channels each | 4× per vector pair vs 1× per scalar | + +Each "filter call" is one invocation of `_disco_s2_contraction_fft`. The cost per call scales with `(channels × K × nlat × nlon × log(nlon))`. + +The total cost is `(N_s + 4·N_v) · K` filter applications, compared to `(N_s + 2·N_v) · K` if vector channels were treated as independent scalars (ignoring frame rotation). The overhead is a factor of 2× for the vector channels only. + +### Memory Cost + +Two additional banded FFT tensors (`psi_cos_fft`, `psi_sin_fft`) of the same shape as `psi_scalar_fft` are stored per layer. Total filter storage is 3× the scalar-only case. The extra memory is typically small relative to activations and weights. + +### Comparison with Cube-Gauged Directional Filters + +The cube-gauged approach (documented in `cube_filter_basis.md`) adds angular modes `sin(mφ)`, `cos(mφ)` to the filter basis, giving `1 + 2M` basis functions per radial scale. The post-convolution rotation to the cube frame is cheap (pointwise), and the cube construction provides a globally smooth (up to Z₄ vertex singularities) directional filter. + +This approach is different in character: + +- **No angular filter modes.** The filters remain isotropic (radial only). Directional information enters through the vector channels, not through the filter shape. 
+- **No cube geometry.** No cube partition, blending zones, transition maps, or equivariant nonlinearity constraints. +- **Frame rotation is per-input-point.** The rotation matrix varies across the filter support (different input points have different frame mismatches), so it must be baked into the filter tensor, doubling the vector channel cost. In contrast, the cube-frame rotation is at the output point only (constant across the sum, negligible cost). +- **Simpler but less expressive filters.** Isotropic filters cannot detect angular structure in scalar fields. Directional sensitivity comes only from vector inputs. This is sufficient if the physically important directional information is carried by vector quantities (wind, currents, etc.), which is typically the case in atmospheric modeling. + +## Design Parameters + +- **K (radial basis size):** Number of radial filter basis functions per convolution layer. Same as the current `kernel_shape` for `IsotropicMorletFilterBasis`. Controls the radial resolution of the filter. +- **N_s, N_v (channel counts):** Number of scalar and vector channel pairs in the hidden representation. The network architecture determines these. A reasonable starting point is to match the current total channel count, reserving some for vector pairs. +- **Nonlinearity design:** The choice of how to combine scalar and vector channels between layers (norm-gating, scalar-vector products, etc.) is an architectural decision. Simple norm-gating plus scalar×vector products provide the essential interactions. + +## Summary of Implementation Steps + +1. **Extend the precomputation loop.** In `_precompute_convolution_tensor_s2`, after computing `(θ, φ)` for each support point, additionally compute `cos(γ)` and `sin(γ)` from the input point's Euler-rotated north direction (as described in "Computing the Frame Rotation Angle"). Store three sets of sparse filter values: `ψ_k(r)` (scalar), `ψ_k(r)·cos(γ)` (vector cosine), `ψ_k(r)·sin(γ)` (vector sine). 
+ +2. **Build banded FFT tensors.** Run `_precompute_psi_banded` three times to produce `psi_scalar_fft`, `psi_cos_fft`, and `psi_sin_fft`. All three share the same sparsity pattern (same support points), so the banding and gather indices are identical. + +3. **Modify the forward pass.** The convolution layer: + - Applies `psi_scalar_fft` to scalar input channels → scalar intermediate features. + - Applies `psi_cos_fft` and `psi_sin_fft` to vector input channels (u and v concatenated) → four intermediate tensors. + - Reassembles the four vector intermediates into `(conv_u, conv_v)` via the rotation formula. + - Concatenates scalar and vector intermediates and contracts with the weight tensor. + +4. **Type-aware channel management.** Track which output channels are scalar vs. vector (u, v). This metadata is used by the next convolution layer to determine which channels need frame rotation and by the nonlinearity to apply type-appropriate operations. + +5. **Implement equivariant nonlinearities.** Between convolution layers, apply norm-gated activations to vector pairs and standard activations to scalar channels. Optionally include scalar×vector product layers to enable cross-type feature creation. From 7dae1a22641248ca568b82a4d3408406d902e014 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Thu, 26 Mar 2026 18:20:55 +0000 Subject: [PATCH 31/31] Revert "Add initial draft of vector filter basis notes" This reverts commit 119c53173f427e2b6124477eaf7d867c01585822. 
--- vector_filter_basis.md | 237 ----------------------------------------- 1 file changed, 237 deletions(-) delete mode 100644 vector_filter_basis.md diff --git a/vector_filter_basis.md b/vector_filter_basis.md deleted file mode 100644 index 01fce5452..000000000 --- a/vector_filter_basis.md +++ /dev/null @@ -1,237 +0,0 @@ -# DISCO Convolution with Vector-Typed Hidden Features - -## Overview - -This document specifies an approach for handling directional (vector-valued) data in DISCO convolution on the sphere. Rather than constructing a continuous directional filter basis (as in the cube-gauged approach), this design uses **isotropic radial filters** combined with a **frame-rotation matrix** that correctly transforms vector inputs from each input point's local meridian frame into the output point's meridian frame. - -The key insight is that the meridian frame's polar singularity does not pose a problem: the physics is mostly rotationally invariant, so it is acceptable for the reference orientation to be discontinuous at the poles. What matters is that vector-valued inputs are correctly handled — that is, the convolution accounts for the fact that neighboring points have different definitions of "north" and "east." - -The frame-rotation angle between two points depends only on `(lat_out, lat_in, lon_in - lon_out)`, making it invariant to translations along longitude. This means the rotation can be baked into the filter tensor without breaking the FFT-based cross-convolution optimization. - -## Hidden Representation - -The network maintains three types of hidden feature channels at each grid point: - -- **Scalar channels** `s(lat, lon)`: Frame-independent quantities (temperature, pressure, geopotential, etc.). No special treatment needed in convolution. -- **u channels** `u(lat, lon)`: East component of a vector in the local meridian frame (eastward wind, etc.). 
-- **v channels** `v(lat, lon)`: North component of a vector in the local meridian frame (northward wind, etc.). - -Each `(u, v)` pair represents a tangent vector at the grid point, expressed in the local geographic frame where "east" = direction of increasing longitude and "north" = direction of decreasing colatitude (toward the geographic north pole). - -Initially, `u` and `v` channels are the physical eastward and northward components of the input data. Through the network, new vector channels can be created via pointwise interactions (e.g., scalar times vector), and the convolution propagates them while correctly handling the frame geometry. - -## The Meridian Frame and Its Discontinuity - -The meridian frame is defined everywhere except at the geographic poles, where all meridians converge and "north" is undefined. At the poles, the frame has a rotational discontinuity: approaching the pole from different longitudes gives different limiting orientations. - -This discontinuity is acceptable because: - -1. **The physics is approximately rotationally invariant.** The dominant dynamics (advection, pressure gradients, wave propagation) do not have a preferred horizontal direction. The Coriolis force breaks this symmetry, but it varies smoothly with latitude, not with the frame convention. -2. **The discontinuity is at a measure-zero set of points.** On a discrete grid, the poles are either absent (Gaussian grid) or have a single row of points. The network can learn to handle these points through the local structure of the filter. -3. **Vector channels are correctly rotated within the filter support.** The frame rotation baked into the filter tensor ensures that each input vector is seen in the output point's frame, regardless of how different their meridian frames are. 
- -## Frame Rotation in the Convolution - -### The Problem - -When a filter centered at an output point looks at a neighboring input point, the input's `(u, v)` components are measured in the input point's meridian frame. If the two points have different "north" directions (which they always do, unless they share a meridian), the raw `u` and `v` values cannot be directly combined — they would be mixing vectors expressed in different coordinate systems. - -### The Solution - -Before the filter integrates over input points, each input vector is rotated from the input point's meridian frame into the output point's meridian frame. This rotation is characterized by a single angle `γ(lat_out, lat_in, Δlon)` — the angle between the two points' north directions, as measured via parallel transport along the connecting geodesic. - -The rotation acts as: - -``` -[u_rotated] [cos γ −sin γ] [u_in] -[v_rotated] = [sin γ cos γ] [v_in] -``` - -where `(u_in, v_in)` are the input vector components in the input's meridian frame, and `(u_rotated, v_rotated)` are the same vector expressed in the output's meridian frame. - -### Longitude Invariance - -The rotation angle `γ` depends on `(lat_out, lat_in, Δlon)` where `Δlon = lon_in − lon_out`. It does **not** depend on absolute longitude. This is because the relative geometry of any two points — their geodesic distance, bearing angles, and frame rotation — is invariant under rotations about the polar axis. - -This means `ψ_k(r) · cos(γ)` and `ψ_k(r) · sin(γ)` are both functions of the longitude difference, just like the scalar filter `ψ_k(r)`. The FFT-based cross-convolution optimization applies to all of them. - -## Computing the Frame Rotation Angle - -### Setup - -The DISCO precomputation uses a YZY Euler rotation (rotation about the y-axis by `α = −θ_out`, where `θ_out` is the output colatitude) to bring the output point to the north pole. 
For each input point at colatitude `γ` and longitude `λ`, the rotated Cartesian position is: - -``` -x' = cos(α) cos(λ) sin(γ) + sin(α) cos(γ) -y' = sin(λ) sin(γ) -z' = −sin(α) cos(λ) sin(γ) + cos(α) cos(γ) -``` - -The code already computes `θ = arccos(z')` and `φ = atan2(y', x')` from this. - -### Computing γ - -The frame rotation angle `γ = φ − β`, where `β` is the angle that the input point's geographic north makes with the local "toward-pole" direction in the Euler-rotated frame. The components of `β` are: - -``` -cos(β) = ê_N' · (−θ̂) -sin(β) = ê_N' · φ̂ -``` - -where: - -- `ê_N'` is the input point's geographic north direction, Euler-rotated to the new frame: - ``` - ê_N'_x = −cos(α) cos(γ) cos(λ) + sin(α) sin(γ) - ê_N'_y = −cos(γ) sin(λ) - ê_N'_z = sin(α) cos(γ) cos(λ) + cos(α) sin(γ) - ``` -- `−θ̂` and `φ̂` are the local basis vectors at the rotated position `(θ, φ)`: - ``` - −θ̂ = (−cos θ cos φ, −cos θ sin φ, sin θ) - φ̂ = (−sin φ, cos φ, 0 ) - ``` - -Then: - -``` -cos(γ) = cos(φ − β) = cos φ · cos β + sin φ · sin β -sin(γ) = sin(φ − β) = sin φ · cos β − cos φ · sin β -``` - -These quantities can be computed in the existing precomputation loop alongside `θ` and `φ`, adding only a few vector dot products per support point. - -### Geometric Meaning - -The angle `γ` is the total rotation from the input's meridian frame to the output's meridian frame, accounting for: - -1. The input's geographic north direction relative to the Euler-frame local basis (the angle `β`) -2. The azimuthal position of the input point relative to the output point's meridian (the angle `φ`) - -At the output point's location (where `θ → 0`), the Euler frame's `φ = 0` direction corresponds to the output point's geographic south, and `φ = π/2` corresponds to geographic east. The parallel transport of the input's north vector to the output point arrives at Cartesian angle `π + γ` from the x-axis in the output's tangent plane (since north is the `−x` direction). 
- -## Convolution Operation - -### Filter Tensor for Scalar Channels - -Unchanged from the current implementation. The filter basis functions `ψ_k(r)` depend only on geodesic distance. The precomputed banded FFT tensor is: - -``` -psi_scalar_fft: shape (K, nlat_out, max_bw, nfreq) -``` - -The contraction gives: - -``` -conv_scalar[b, c, k, lat, lon] = Σ_in ψ_k(r) · s_c(lat_in, lon_in) -``` - -### Filter Tensor for Vector Channels - -For vector channels, the filter tensor incorporates the frame rotation. Two banded FFT tensors are precomputed: - -``` -psi_cos_fft: shape (K, nlat_out, max_bw, nfreq) — FFT of ψ_k(r) · cos(γ) -psi_sin_fft: shape (K, nlat_out, max_bw, nfreq) — FFT of ψ_k(r) · sin(γ) -``` - -These are built from the same radial filter basis `ψ_k(r)`, multiplied by `cos(γ)` and `sin(γ)` at each support point before banding and FFT. - -### Vector Convolution - -For each vector input pair `(u_c, v_c)` with radial basis `k`, the frame-rotated convolution gives the vector in the output point's frame: - -``` -conv_u[b, c, k, lat, lon] = contraction(psi_cos, u_c) − contraction(psi_sin, v_c) -conv_v[b, c, k, lat, lon] = contraction(psi_sin, u_c) + contraction(psi_cos, v_c) -``` - -where `contraction(psi, x)` denotes the standard FFT-based DISCO contraction (`_disco_s2_contraction_fft`). - -This requires two contraction calls for vector channels (one with `psi_cos`, one with `psi_sin`), each processing all `u` and `v` channels together. The reassembly into `(conv_u, conv_v)` is a cheap pointwise operation. - -### Weight Contraction - -After convolution, the intermediate features are: - -- From scalar inputs: `conv_scalar` with shape `(B, N_scalar, K, H, W)` -- From vector inputs: `conv_u` and `conv_v`, each with shape `(B, N_vec, K, H, W)` - -All intermediate features are in the output point's meridian frame, so they can be freely mixed by the learned weight tensor. 
The weight produces output scalar, u, and v channels:

- **Scalar output from scalar input:** Standard — weight contracts over `(c_in, k)`.
- **Scalar output from vector input:** Weight contracts over `(c_in, k)` separately for `conv_u` and `conv_v`, producing a scalar from vector components (analogous to divergence or a directional projection).
- **Vector output from vector input:** Weight contracts over `(c_in, k)` for `conv_u` and `conv_v` to produce output `u` and `v` components. The weight can independently scale and mix the u/v components.
- **Vector output from scalar input:** The weight assigns a scalar intermediate feature to a vector output channel. This implicitly creates a preferred direction in the meridian frame, and is therefore not rotationally equivariant. But if `u_out` and `v_out` channels are created from the same scalar feature with appropriate weights, the network can learn latitude-dependent directional structures (e.g., patterns aligned with Coriolis deflection).

In practice, the simplest implementation treats all intermediate features as a flat channel dimension of size `N_scalar + 2 · N_vec` (with `conv_u` and `conv_v` occupying separate channel slots), and uses a single weight tensor of shape `(C_out, C_in_effective, K)` that mixes everything. The type labels (scalar, u, v) are bookkeeping for the next layer's convolution, not a constraint on the weight.

## Nonlinearities and Vector-Scalar Interactions

Between convolution layers, pointwise nonlinearities must respect the typed channel structure. All operations below are pointwise (per grid point) and therefore frame-consistent, since all channels at a given grid point share the same meridian frame.

### Safe Nonlinearities

- **Scalar channels:** Any standard nonlinearity (ReLU, GELU, etc.) applied independently.
- **Vector norm → scalar:** `n = sqrt(u² + v²)` is frame-invariant and produces a true scalar.
- **Norm-gated vectors:** `σ(n) · (u, v) / n` applies a nonlinearity to the vector magnitude while preserving direction (in practice divide by `n + ε` to avoid division by zero where the vector vanishes).
- **Scalar × vector → vector:** `s · (u, v)` multiplies a scalar channel with a vector pair to produce a new vector pair. This is how the network creates new vector features from scalar-vector interactions.
- **Vector dot product → scalar:** `u₁·u₂ + v₁·v₂` produces a frame-invariant scalar from two vector pairs.
- **2D cross product → scalar:** `u₁·v₂ − u₂·v₁` produces a pseudo-scalar (changes sign under reflection but is frame-rotation invariant).

### Unsafe Nonlinearities

- **Independent nonlinearities on u and v:** Applying `ReLU(u)`, `ReLU(v)` independently breaks rotational structure — the result depends on the frame orientation.
- **Adding scalar to vector component:** `u + s` breaks vector transformation rules.

## Performance Characteristics

### Compute Cost

For a layer with `N_s` scalar input channels and `N_v` vector input pairs:

| Operation | Filter calls | Relative to scalar-only |
|---|---|---|
| Scalar convolution | 1 call with `psi_scalar` over `N_s` channels | baseline |
| Vector convolution | 2 calls (`psi_cos`, `psi_sin`) over `2·N_v` channels each | 4× per vector pair vs 1× per scalar |

Each "filter call" is one invocation of `_disco_s2_contraction_fft`. The cost per call scales with `(channels × K × nlat × nlon × log(nlon))`.

The total cost is `(N_s + 4·N_v) · K` filter applications, compared to `(N_s + 2·N_v) · K` if vector channels were treated as independent scalars (ignoring frame rotation). The overhead is a factor of 2× for the vector channels only.

### Memory Cost

Two additional banded FFT tensors (`psi_cos_fft`, `psi_sin_fft`) of the same shape as `psi_scalar_fft` are stored per layer. Total filter storage is 3× the scalar-only case. The extra memory is typically small relative to activations and weights.
### Comparison with Cube-Gauged Directional Filters

The cube-gauged approach (documented in `cube_filter_basis.md`) adds angular modes `sin(mφ)`, `cos(mφ)` to the filter basis, giving `1 + 2M` basis functions per radial scale. The post-convolution rotation to the cube frame is cheap (pointwise), and the cube construction provides a globally smooth (up to Z₄ vertex singularities) directional filter.

This approach is different in character:

- **No angular filter modes.** The filters remain isotropic (radial only). Directional information enters through the vector channels, not through the filter shape.
- **No cube geometry.** No cube partition, blending zones, transition maps, or equivariant nonlinearity constraints.
- **Frame rotation is per-input-point.** The rotation matrix varies across the filter support (different input points have different frame mismatches), so it must be baked into the filter tensor, doubling the vector channel cost. In contrast, the cube-frame rotation is at the output point only (constant across the sum, negligible cost).
- **Simpler but less expressive filters.** Isotropic filters cannot detect angular structure in scalar fields. Directional sensitivity comes only from vector inputs. This is sufficient if the physically important directional information is carried by vector quantities (wind, currents, etc.), which is typically the case in atmospheric modeling.

## Design Parameters

- **K (radial basis size):** Number of radial filter basis functions per convolution layer. Same as the current `kernel_shape` for `IsotropicMorletFilterBasis`. Controls the radial resolution of the filter.
- **N_s, N_v (channel counts):** Number of scalar and vector channel pairs in the hidden representation. The network architecture determines these. A reasonable starting point is to match the current total channel count, reserving some for vector pairs.
- **Nonlinearity design:** The choice of how to combine scalar and vector channels between layers (norm-gating, scalar-vector products, etc.) is an architectural decision. Simple norm-gating plus scalar×vector products provide the essential interactions.

## Summary of Implementation Steps

1. **Extend the precomputation loop.** In `_precompute_convolution_tensor_s2`, after computing `(θ, φ)` for each support point, additionally compute `cos(γ)` and `sin(γ)` from the input point's Euler-rotated north direction (as described in "Computing the Frame Rotation Angle"). Store three sets of sparse filter values: `ψ_k(r)` (scalar), `ψ_k(r)·cos(γ)` (vector cosine), `ψ_k(r)·sin(γ)` (vector sine).

2. **Build banded FFT tensors.** Run `_precompute_psi_banded` three times to produce `psi_scalar_fft`, `psi_cos_fft`, and `psi_sin_fft`. All three share the same sparsity pattern (same support points), so the banding and gather indices are identical.

3. **Modify the forward pass.** The convolution layer:
   - Applies `psi_scalar_fft` to scalar input channels → scalar intermediate features.
   - Applies `psi_cos_fft` and `psi_sin_fft` to vector input channels (u and v concatenated) → four intermediate tensors.
   - Reassembles the four vector intermediates into `(conv_u, conv_v)` via the rotation formula.
   - Concatenates scalar and vector intermediates and contracts with the weight tensor.

4. **Type-aware channel management.** Track which output channels are scalar vs. vector (u, v). This metadata is used by the next convolution layer to determine which channels need frame rotation and by the nonlinearity to apply type-appropriate operations.

5. **Implement equivariant nonlinearities.** Between convolution layers, apply norm-gated activations to vector pairs and standard activations to scalar channels. Optionally include scalar×vector product layers to enable cross-type feature creation.