Add secondary network support to SingleModuleStep and register local network types #1002
Status: Open. mcgibbon wants to merge 40 commits into main from feature/secondary_stepping.
Commits (40, all authored by mcgibbon):

012b92a  Add LocalNet module using only local (DISCO/conv1x1) filters
4de511d  Simplify LocalNet config: remove dropout/checkpointing, use block_typ…
7fadfee  Add AnkurLocalNet replicating Ankur's diagnostic MLP architecture
b1567ff  Move data_grid from config dataclasses to factory function arguments
c8950fa  Add HybridNet combining SFNO backbone with local diagnostic network
d2afaef  Merge branch 'main' into feature/separate_input_types
e42d56a  Remove superfluous HybridNet tests for sub-network internals
1c5b252  Add optional secondary network to SingleModuleStep and remove HybridNet
473a53d  Add AnkurLocalNet and LocalNet to the module registry
703f138  Use training step instead of parameter surgery in noise stochasticity…
48c7c5b  Extract shared NoiseConditionedModule base class
34abe3a  Use composition instead of inheritance for noise conditioning
b5e451c  Add regression tests for AnkurLocalNet and LocalNet
97e70a7  Split secondary network outputs into full-field and residual lists
708d634  Add isotropic noise support for LocalNet
a525a64  Consolidate noise conditioning into NoiseConditionedModel in stochast…
7ff1be1  Remove data_grid config from local network builders, hard-code to leg…
b806b9a  Add __init__ docstring to NoiseConditionedModel
9e23ed8  Merge branch 'main' into feature/separate_input_types
47c420c  Add regression tests for AnkurLocalNet and LocalNet
66e9e66  Remove HybridNet and its tests
f62ce1b  Merge branch 'feature/separate_input_types' into feature/secondary_st…
25c70ee  Merge branch 'main' into feature/secondary_stepping
d514f3c  Address PR review feedback: update defaults, remove duplicate docstri…
da2f020  Merge branch 'feature/secondary_stepping' of github.com:ai2cm/ace int…
0009d7e  Fix backwards compatibility: restore conditional_model attribute name
607932f  Add kernel_shape, basis_type config to LocalNet and isotropic Morlet …
ad163cf  Merge branch 'main' into feature/secondary_stepping
4c98ea3  Add test that isotropic DISCO conv commutes with latitude reflection
c0a911d  Fork DISCO convolution with FFT-based contraction into fme/core/disco
4d0dfe2  Remove torch-harmonics dependency from fme/core/disco
0a0711a  Merge branch 'feature/faster_disco' into feature/secondary_stepping
f84aae3  Replace isotropic Morlet monkey-patch with native filter basis
138c72d  Fix DISCO test compatibility with torch-harmonics 0.8.0
cb4ca83  Merge branch 'main' into feature/secondary_stepping
4a43eb5  Merge branch 'main' into feature/secondary_stepping
a024340  Add arxiv reference to AnkurLocalNet documentation
e2b7559  Fix DDP crash on integer DISCO index buffers
119c531  Add initial draft of vector filter basis notes
7dae1a2  Revert "Add initial draft of vector filter basis notes"
New file (+236 lines):

```python
import dataclasses
from typing import Literal

import torch
from torch import nn

from fme.ace.registry.registry import ModuleConfig, ModuleSelector
from fme.ace.registry.stochastic_sfno import NoiseConditionedModel
from fme.core.dataset_info import DatasetInfo
from fme.core.distributed import Distributed
from fme.core.models.conditional_sfno.ankur import (
    AnkurLocalNetConfig,
    get_lat_lon_ankur_localnet,
)
from fme.core.models.conditional_sfno.layers import Context, ContextConfig
from fme.core.models.conditional_sfno.localnet import (
    BasisType,
    BlockType,
    LocalNetConfig,
    get_lat_lon_localnet,
)


class _ContextWrappedModule(nn.Module):
    """Wraps a module that takes (x, context: Context) to accept (x, labels=None).

    This adapts the conditional_sfno forward signature to the interface
    expected by the Module registry wrapper.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(
        self, x: torch.Tensor, labels: torch.Tensor | None = None
    ) -> torch.Tensor:
        context = Context(
            embedding_scalar=None,
            embedding_pos=None,
            labels=labels,
            noise=None,
        )
        return self.module(x, context)


@ModuleSelector.register("AnkurLocalNet")
@dataclasses.dataclass
class AnkurLocalNetBuilder(ModuleConfig):
    """Configuration for the AnkurLocalNet architecture.

    A simple 3-hidden-layer MLP that optionally uses a DISCO convolution
    for the first layer and a learned positional embedding.

    Reference: https://arxiv.org/html/2602.16090v1

    Attributes:
        embed_dim: Dimension of the hidden layers.
        use_disco_encoder: Whether to use a DISCO convolution for the first
            layer instead of a 1x1 convolution.
        disco_kernel_size: Kernel size for the DISCO convolution if used.
        pos_embed: Whether to add a learned positional embedding after the
            first layer.
        activation_function: Activation function name ('relu', 'gelu', 'silu').
    """

    embed_dim: int = 256
    use_disco_encoder: bool = False
    disco_kernel_size: int = 3
    pos_embed: bool = False
    activation_function: str = "gelu"

    def build(
        self,
        n_in_channels: int,
        n_out_channels: int,
        dataset_info: DatasetInfo,
    ) -> nn.Module:
        params = AnkurLocalNetConfig(
            embed_dim=self.embed_dim,
            use_disco_encoder=self.use_disco_encoder,
            disco_kernel_size=self.disco_kernel_size,
            pos_embed=self.pos_embed,
            activation_function=self.activation_function,
        )
        context_config = ContextConfig(
            embed_dim_scalar=0,
            embed_dim_noise=0,
            embed_dim_labels=len(dataset_info.all_labels),
            embed_dim_pos=0,
        )
        net = get_lat_lon_ankur_localnet(
            params=params,
            in_chans=n_in_channels,
            out_chans=n_out_channels,
            img_shape=dataset_info.img_shape,
            data_grid="legendre-gauss",
            context_config=context_config,
        )
        return _ContextWrappedModule(net)


@ModuleSelector.register("LocalNet")
@dataclasses.dataclass
class LocalNetBuilder(ModuleConfig):
    """Configuration for the LocalNet architecture.

    A noise-conditioned local neural operator network using DISCO convolutions
    and/or 1x1 convolutions, with encoder/decoder structure and optional skip
    connections. Supports label conditioning when used with conditional=True
    on the ModuleSelector.

    Attributes:
        embed_dim: Dimension of the embeddings.
        kernel_shape: Shape of the DISCO convolution filter basis, passed
            to the filter basis constructor. For the "piecewise linear" and
            "morlet" basis types this is a two-element tuple
            (n_radial_modes, n_azimuthal_modes). When n_azimuthal_modes is
            1, the "piecewise linear" basis produces isotropic (radially
            symmetric) filters. Only affects 'disco' blocks.
        basis_type: Type of filter basis for the DISCO convolution
            ('morlet', 'piecewise linear', or 'zernike'). Only affects
            'disco' blocks.
        noise_embed_dim: Dimension of the noise conditioning channels.
        noise_type: Type of noise for conditioning ('gaussian' or 'isotropic').
            Isotropic noise is generated via inverse spherical harmonic
            transform.
        context_pos_embed_dim: Dimension of the learned positional embedding
            used for conditioning. 0 disables.
        block_types: List of filter types for each block ('disco', 'conv1x1').
            The length determines the number of blocks.
        global_layer_norm: Whether to reduce along the spatial domain when
            applying layer normalization.
        use_mlp: Whether to use an MLP in each block.
        mlp_ratio: Ratio of MLP hidden dimension to the embedding dimension.
        activation_function: Activation function name ('relu', 'gelu', 'silu').
        encoder_layers: Number of convolutional layers in the encoder/decoder.
        pos_embed: Whether to use a learned positional embedding inside the
            LocalNet (distinct from context_pos_embed_dim).
        big_skip: Whether to use a big skip connection from input to decoder.
        normalize_big_skip: Whether to normalize the big skip connection.
        affine_norms: Whether to use element-wise affine parameters in the
            normalization layers.
        lora_rank: Rank of LoRA adaptations. 0 disables LoRA.
        lora_alpha: Strength of LoRA adaptations. Defaults to lora_rank
            if None.
    """

    embed_dim: int = 256
    kernel_shape: tuple[int, int] = (3, 3)
    basis_type: BasisType = "morlet"
    noise_embed_dim: int = 256
    noise_type: Literal["gaussian", "isotropic"] = "gaussian"
    context_pos_embed_dim: int = 0
    block_types: list[BlockType] = dataclasses.field(
        default_factory=lambda: [
            "disco",
            "disco",
            "disco",
            "disco",
            "conv1x1",
            "conv1x1",
            "conv1x1",
            "conv1x1",
        ]
    )
    global_layer_norm: bool = False
    use_mlp: bool = True
    mlp_ratio: float = 2.0
    activation_function: str = "gelu"
    encoder_layers: int = 1
    pos_embed: bool = True
    big_skip: bool = True
    normalize_big_skip: bool = False
    affine_norms: bool = False
    lora_rank: int = 0
    lora_alpha: float | None = None

    def build(
        self,
        n_in_channels: int,
        n_out_channels: int,
        dataset_info: DatasetInfo,
    ) -> nn.Module:
        params = LocalNetConfig(
            embed_dim=self.embed_dim,
            kernel_shape=self.kernel_shape,
            basis_type=self.basis_type,
            block_types=self.block_types,
            global_layer_norm=self.global_layer_norm,
            use_mlp=self.use_mlp,
            mlp_ratio=self.mlp_ratio,
            activation_function=self.activation_function,
            encoder_layers=self.encoder_layers,
            pos_embed=self.pos_embed,
            big_skip=self.big_skip,
            normalize_big_skip=self.normalize_big_skip,
            affine_norms=self.affine_norms,
            lora_rank=self.lora_rank,
            lora_alpha=self.lora_alpha,
        )
        embed_dim_labels = len(dataset_info.all_labels)
        context_config = ContextConfig(
            embed_dim_scalar=0,
            embed_dim_noise=self.noise_embed_dim,
            embed_dim_labels=embed_dim_labels,
            embed_dim_pos=self.context_pos_embed_dim,
        )
        net = get_lat_lon_localnet(
            params=params,
            in_chans=n_in_channels,
            out_chans=n_out_channels,
            img_shape=dataset_info.img_shape,
            data_grid="legendre-gauss",
            context_config=context_config,
        )
        img_shape = dataset_info.img_shape
        if self.noise_type == "isotropic":
            dist = Distributed.get_instance()
            inverse_sht = dist.get_isht(*img_shape, grid="legendre-gauss")
            lmax = inverse_sht.lmax
            mmax = inverse_sht.mmax
        else:
            inverse_sht = None
            lmax = 0
            mmax = 0
        return NoiseConditionedModel(
            net,
            img_shape=img_shape,
            embed_dim_noise=self.noise_embed_dim,
            embed_dim_pos=self.context_pos_embed_dim,
            embed_dim_labels=embed_dim_labels,
            inverse_sht=inverse_sht,
            lmax=lmax,
            mmax=mmax,
        )
```
Review comment: Is this what we want as the default? It seems a little odd to default to a rather deep (12-layer) local model.
Reply: Agreed. I think we will want to tune this, and I have no idea what it will look like in the end, but for now I want to start at 4 disco blocks followed by 4 conv1x1 blocks. This gives similar depth to what we have in the SFNO case, while biasing it towards more local analysis. I do think we should also test with as few as 2 blocks (and of course Ankur's case, which only uses 1 DISCO layer) and compare. For what it's worth, this is analogous to the SFNO class, where the default config uses 12 blocks but we always override it to 8.
Unfortunately, this will be resolution-dependent to some degree as well, so we may have to tune on both 4-degree and 1-degree data.
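The depth discussion above boils down to overriding `block_types` in configuration: the default is 4 'disco' blocks followed by 4 'conv1x1' blocks, and shallower variants are selected simply by passing a shorter list. A small sketch using a hypothetical stand-in dataclass (not the real `LocalNetBuilder`):

```python
import dataclasses


@dataclasses.dataclass
class LocalNetDepthSketch:
    """Hypothetical stand-in for the depth-related field of LocalNetBuilder."""
    block_types: list[str] = dataclasses.field(
        default_factory=lambda: ["disco"] * 4 + ["conv1x1"] * 4
    )


# Default: 4 DISCO blocks followed by 4 conv1x1 blocks (8 total).
default = LocalNetDepthSketch()

# A shallower 2-block variant, as suggested for comparison in the review.
shallow = LocalNetDepthSketch(block_types=["disco", "conv1x1"])

# Ankur-style: a single DISCO layer.
ankur_like = LocalNetDepthSketch(block_types=["disco"])

print(len(default.block_types), len(shallow.block_types), len(ankur_like.block_types))
# prints "8 2 1"
```

Because `default_factory` builds a fresh list per instance, overriding one config never mutates the default seen by other experiments.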