diff --git a/fme/ace/registry/stochastic_sfno.py b/fme/ace/registry/stochastic_sfno.py index e0d9ae374..edfe79c7c 100644 --- a/fme/ace/registry/stochastic_sfno.py +++ b/fme/ace/registry/stochastic_sfno.py @@ -76,7 +76,7 @@ def __init__( mmax: int = 0, ): super().__init__() - self.conditional_model = conditional_model + self.conditional_model = module self.embed_dim = embed_dim_noise self.img_shape = img_shape self._inverse_sht = inverse_sht diff --git a/fme/ace/registry/test_stochastic_sfno.py b/fme/ace/registry/test_stochastic_sfno.py index a835544ac..e8d23743a 100644 --- a/fme/ace/registry/test_stochastic_sfno.py +++ b/fme/ace/registry/test_stochastic_sfno.py @@ -36,7 +36,7 @@ def test_noise_conditioned_sfno_conditioning(): n_pos = 8 n_labels = 4 model = NoiseConditionedSFNO( - conditional_model=mock_sfno, + module=mock_sfno, img_shape=img_shape, embed_dim_noise=n_noise, embed_dim_pos=n_pos, diff --git a/fme/core/disco/__init__.py b/fme/core/disco/__init__.py new file mode 100644 index 000000000..9de56cc9e --- /dev/null +++ b/fme/core/disco/__init__.py @@ -0,0 +1,3 @@ +from ._convolution import DiscreteContinuousConvS2 + +__all__ = ["DiscreteContinuousConvS2"] diff --git a/fme/core/disco/_cache.py b/fme/core/disco/_cache.py new file mode 100644 index 000000000..26f663b47 --- /dev/null +++ b/fme/core/disco/_cache.py @@ -0,0 +1,24 @@ +# Forked from torch-harmonics (BSD-3-Clause) +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. +# SPDX-License-Identifier: BSD-3-Clause + +import functools +from copy import deepcopy + + +def lru_cache(maxsize=20, typed=False, copy=False): + """LRU cache decorator with optional deep copying of cached results.""" + + def decorator(f): + cached_func = functools.lru_cache(maxsize=maxsize, typed=typed)(f) + + def wrapper(*args, **kwargs): + res = cached_func(*args, **kwargs) + if copy: + return deepcopy(res) + else: + return res + + return wrapper + + return decorator diff --git a/fme/core/disco/_convolution.py b/fme/core/disco/_convolution.py new file mode 100644 index 000000000..a1a522825 --- /dev/null +++ b/fme/core/disco/_convolution.py @@ -0,0 +1,385 @@ +# Forked from torch-harmonics (BSD-3-Clause) +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. +# SPDX-License-Identifier: BSD-3-Clause +# +# Modifications: stripped down to DiscreteContinuousConvS2 (forward only) +# using the FFT-based contraction as the sole implementation path. CUDA +# kernel and torch sparse-matrix fallbacks are removed. + +import abc +import math + +import torch +import torch.nn as nn + +from . import _filter_basis as _filter_basis_module +from ._cache import lru_cache +from ._disco_utils import _disco_s2_contraction_fft, _get_psi, _precompute_psi_banded +from ._filter_basis import FilterBasis +from ._quadrature import precompute_latitudes, precompute_longitudes + + +def _normalize_convolution_tensor_s2( + psi_idx, + psi_vals, + in_shape, + out_shape, + kernel_size, + quad_weights, + transpose_normalization=False, + basis_norm_mode="mean", + merge_quadrature=False, + eps=1e-9, +): + """Normalizes convolution tensor values based on specified normalization + mode. + """ + if basis_norm_mode == "none": + return psi_vals + + idx = torch.stack( + [ + psi_idx[0], + psi_idx[1], + psi_idx[2] // in_shape[1], + psi_idx[2] % in_shape[1], + ], + dim=0, + ) + + ikernel = idx[0] + + if transpose_normalization: + ilat_out = idx[2] + ilat_in = idx[1] + nlat_out = in_shape[0] + correction_factor = out_shape[1] / in_shape[1] + else: + ilat_out = idx[1] + ilat_in = idx[2] + nlat_out = out_shape[0] + + q = quad_weights[ilat_in].reshape(-1) + + vnorm = torch.zeros(kernel_size, nlat_out, device=psi_vals.device) + support = torch.zeros(kernel_size, nlat_out, device=psi_vals.device) + + for ik in range(kernel_size): + for ilat in range(nlat_out): + iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat)) + vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx]) + support[ik, ilat] = torch.sum(q[iidx]) + + for ik in range(kernel_size): + for ilat in range(nlat_out): + iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat)) + + if basis_norm_mode == "individual": + val = vnorm[ik, ilat] + elif basis_norm_mode == "mean": + val = vnorm[ik, :].mean() + elif basis_norm_mode == "support": + val = support[ik, ilat] + elif basis_norm_mode == "none": + val = 1.0 + else: + raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.") + + psi_vals[iidx] = psi_vals[iidx] / (val + eps) + + if merge_quadrature: + psi_vals[iidx] = psi_vals[iidx] * q[iidx] + + if transpose_normalization and merge_quadrature: + psi_vals = psi_vals / correction_factor + + return psi_vals + + +@lru_cache(typed=True, copy=True) +def _precompute_convolution_tensor_s2( + in_shape: tuple[int], + out_shape: tuple[int], + filter_basis: FilterBasis, + grid_in: str | None = "equiangular", + grid_out: str | None = "equiangular", + theta_cutoff: float | None = 0.01 * math.pi, + theta_eps: float | None = 1e-3, + transpose_normalization: bool | None = False, + basis_norm_mode: str | None = "mean", + merge_quadrature: bool | None = False, +): + r"""Precomputes the rotated filters at positions using YZY Euler angles.""" + assert len(in_shape) == 2 + assert len(out_shape) == 2 + + kernel_size = filter_basis.kernel_size + + nlat_in, nlon_in = in_shape + nlat_out, nlon_out = out_shape + + lats_in, win = precompute_latitudes(nlat_in, grid=grid_in) + lats_out, wout = precompute_latitudes(nlat_out, grid=grid_out) + lons_in = precompute_longitudes(nlon_in) + + if transpose_normalization: + quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0 + else: + quad_weights = win.reshape(-1, 1) / nlon_in / 2.0 + + theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff + + out_idx = [] + out_vals = [] + + beta = lons_in + gamma = lats_in.reshape(-1, 1) + + cbeta = torch.cos(beta) + sbeta = torch.sin(beta) + cgamma = torch.cos(gamma) + sgamma = torch.sin(gamma) + + out_roff = torch.zeros(nlat_out + 1, dtype=torch.int64, device=lons_in.device) + out_roff[0] = 0 + for t in range(nlat_out): + alpha = -lats_out[t] + + x = torch.cos(alpha) * cbeta * sgamma + cgamma * torch.sin(alpha) + y = sbeta * sgamma + z = -cbeta * torch.sin(alpha) * sgamma + torch.cos(alpha) * cgamma + + norm = torch.sqrt(x * x + y * y + z * z) + x = x / norm + y = y / norm + z = z / norm + + theta = torch.arccos(z) + phi = torch.arctan2(y, x) + phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi) + + iidx, vals = filter_basis.compute_support_vals( + theta, phi, r_cutoff=theta_cutoff_eff + ) + + idx = torch.stack( + [ + iidx[:, 0], + t * torch.ones_like(iidx[:, 0]), + iidx[:, 1] * nlon_in + iidx[:, 2], + ], + dim=0, + ) + + out_idx.append(idx) + out_vals.append(vals) + out_roff[t + 1] = out_roff[t] + iidx.shape[0] + + out_idx = torch.cat(out_idx, dim=-1) + out_vals = torch.cat(out_vals, dim=-1) + + out_vals = _normalize_convolution_tensor_s2( + out_idx, + out_vals, + in_shape, + out_shape, + kernel_size, + quad_weights, + transpose_normalization=transpose_normalization, + basis_norm_mode=basis_norm_mode, + merge_quadrature=merge_quadrature, + ) + + out_idx = out_idx.contiguous() + out_vals = out_vals.to(dtype=torch.float32).contiguous() + + return out_idx, out_vals, out_roff + + +class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta): + """Abstract base class for discrete-continuous convolutions.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_shape: int | tuple[int, ...], + basis_type: str = "piecewise linear", + groups: int = 1, + bias: bool = True, + ): + super().__init__() + + self.kernel_shape = kernel_shape + self.filter_basis = _filter_basis_module.get_filter_basis( + kernel_shape=kernel_shape, basis_type=basis_type + ) + + self.groups = groups + + if in_channels % self.groups != 0: + raise ValueError( + "Error, the number of input channels has to be an integer " + "multiple of the group size" + ) + if out_channels % self.groups != 0: + raise ValueError( + "Error, the number of output channels has to be an integer " + "multiple of the group size" + ) + self.groupsize = in_channels // self.groups + scale = math.sqrt(1.0 / self.groupsize / self.kernel_size) + self.weight = nn.Parameter( + scale * torch.randn(out_channels, self.groupsize, self.kernel_size) + ) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + else: + self.bias = None + + @property + def kernel_size(self): + return self.filter_basis.kernel_size + + @abc.abstractmethod + def forward(self, x: torch.Tensor): + raise NotImplementedError + + +class DiscreteContinuousConvS2(DiscreteContinuousConv): + """Discrete-continuous (DISCO) convolution on the 2-Sphere. + + Uses an FFT-based contraction for efficient computation. Forked from + torch-harmonics; see Ocampo, Price, McEwen, *Scalable and equivariant + spherical CNNs by discrete-continuous (DISCO) convolutions*, ICLR 2023. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + in_shape: tuple[int, int], + out_shape: tuple[int, int], + kernel_shape: int | tuple[int, ...], + basis_type: str = "piecewise linear", + basis_norm_mode: str = "mean", + groups: int = 1, + grid_in: str = "equiangular", + grid_out: str = "equiangular", + bias: bool = True, + theta_cutoff: float | None = None, + ): + super().__init__( + in_channels, out_channels, kernel_shape, basis_type, groups, bias + ) + + self.nlat_in, self.nlon_in = in_shape + self.nlat_out, self.nlon_out = out_shape + + assert self.nlon_in % self.nlon_out == 0 + + if theta_cutoff is None: + self.theta_cutoff = torch.pi / float(self.nlat_out - 1) + else: + self.theta_cutoff = theta_cutoff + + if self.theta_cutoff <= 0.0: + raise ValueError("Error, theta_cutoff has to be positive.") + + idx, vals, _ = _precompute_convolution_tensor_s2( + in_shape, + out_shape, + self.filter_basis, + grid_in=grid_in, + grid_out=grid_out, + theta_cutoff=self.theta_cutoff, + transpose_normalization=False, + basis_norm_mode=basis_norm_mode, + merge_quadrature=True, + ) + + ker_idx = idx[0, ...].contiguous() + row_idx = idx[1, ...].contiguous() + col_idx = idx[2, ...].contiguous() + vals = vals.contiguous() + + # Store index tensors as plain attributes instead of buffers so that + # DDP's _broadcast_coalesced (which cannot handle integer dtypes) does + # not attempt to synchronise them. They are deterministic across ranks + # so synchronisation is unnecessary. Device tracking is handled by the + # _apply override below. + self.psi_ker_idx = ker_idx + self.psi_row_idx = row_idx + self.psi_col_idx = col_idx + self.register_buffer("psi_vals", vals, persistent=False) + + # Precompute banded FFT of psi for FFT-based contraction + psi_sparse = _get_psi( + self.kernel_size, + self.psi_idx, + self.psi_vals, + self.nlat_in, + self.nlon_in, + self.nlat_out, + self.nlon_out, + ) + psi_fft_conj, gather_idx = _precompute_psi_banded( + psi_sparse, self.nlat_in, self.nlon_in + ) + self.register_buffer("psi_fft_conj", psi_fft_conj, persistent=False) + self.psi_gather_idx = gather_idx + + def _apply(self, fn, recurse=True): + # Required because integer index tensors are stored as plain attributes + # (not buffers) to avoid DDP's _broadcast_coalesced failing on non-float + # dtypes. This ensures they follow .to() / .cuda() device transfers. + super()._apply(fn, recurse=recurse) + self.psi_ker_idx = fn(self.psi_ker_idx) + self.psi_row_idx = fn(self.psi_row_idx) + self.psi_col_idx = fn(self.psi_col_idx) + self.psi_gather_idx = fn(self.psi_gather_idx) + return self + + def extra_repr(self): + return ( + f"in_shape={(self.nlat_in, self.nlon_in)}, " + f"out_shape={(self.nlat_out, self.nlon_out)}, " + f"in_chans={self.groupsize * self.groups}, " + f"out_chans={self.weight.shape[0]}, " + f"filter_basis={self.filter_basis}, " + f"kernel_shape={self.kernel_shape}, " + f"theta_cutoff={self.theta_cutoff}, " + f"groups={self.groups}" + ) + + @property + def psi_idx(self): + return torch.stack( + [self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0 + ).contiguous() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = _disco_s2_contraction_fft( + x, + self.psi_fft_conj.to(x.device), + self.psi_gather_idx.to(x.device), + self.nlon_out, + ) + + B, C, K, H, W = x.shape + x = x.reshape(B, self.groups, self.groupsize, K, H, W) + + out = torch.einsum( + "bgckxy,gock->bgoxy", + x, + self.weight.reshape( + self.groups, -1, self.weight.shape[1], self.weight.shape[2] + ), + ).contiguous() + out = out.reshape(B, -1, H, W) + + if self.bias is not None: + out = out + self.bias.reshape(1, -1, 1, 1) + + return out diff --git a/fme/core/disco/_disco_utils.py b/fme/core/disco/_disco_utils.py new file mode 100644 index 000000000..381a25d54 --- /dev/null +++ b/fme/core/disco/_disco_utils.py @@ -0,0 +1,137 @@ +# Forked from torch-harmonics (BSD-3-Clause) +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. +# SPDX-License-Identifier: BSD-3-Clause +# +# Modifications: stripped down to FFT-based contraction only, removing +# CUDA kernel wrappers and torch sparse-matrix fallbacks. + + +import torch + +from ._fft import irfft, rfft + + +def _get_psi( + kernel_size: int, + psi_idx: torch.Tensor, + psi_vals: torch.Tensor, + nlat_in: int, + nlon_in: int, + nlat_out: int, + nlon_out: int, + nlat_in_local: int | None = None, + nlat_out_local: int | None = None, + semi_transposed: bool | None = False, +): + """Creates a sparse tensor for spherical harmonic convolution operations.""" + nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in + nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out + + if semi_transposed: + tout = psi_idx[2] // nlon_out + pout = psi_idx[2] % nlon_out + pout = nlon_out - 1 - pout + tin = psi_idx[1] + idx = torch.stack([psi_idx[0], tout, tin * nlon_out + pout], dim=0) + psi = torch.sparse_coo_tensor( + idx, psi_vals, size=(kernel_size, nlat_out_local, nlat_in_local * nlon_out) + ).coalesce() + else: + psi = torch.sparse_coo_tensor( + psi_idx, + psi_vals, + size=(kernel_size, nlat_out_local, nlat_in_local * nlon_in), + ).coalesce() + return psi + + +def _precompute_psi_banded(psi_sparse: torch.Tensor, nlat_in: int, nlon: int): + """Build a banded dense representation of psi directly from sparse COO data. + + Instead of densifying the full (K, nlat_out, nlat_in, nlon) tensor, this + only stores the contiguous band of input latitudes that have nonzero entries + for each output latitude, reducing memory by ~nlat_in/band_width. + + Returns: + ------- + psi_banded_fft_conj : (K, nlat_out, max_bw, nfreq) complex tensor + gather_idx : (nlat_out, max_bw) long tensor of input latitude indices + """ + K, nlat_out, _ = psi_sparse.shape + psi = psi_sparse.coalesce() + indices = psi.indices() # (3, nnz) + values = psi.values() # (nnz,) + + ker_idx = indices[0] + row_idx = indices[1] # output lat + col_idx = indices[2] # input_lat * nlon + lon + input_lat = col_idx // nlon + input_lon = col_idx % nlon + + # Find min/max input lat per output lat (across all kernel indices) + lat_min = torch.full((nlat_out,), nlat_in, dtype=torch.long) + lat_max = torch.full((nlat_out,), 0, dtype=torch.long) + lat_min.scatter_reduce_(0, row_idx, input_lat, reduce="amin") + lat_max.scatter_reduce_(0, row_idx, input_lat, reduce="amax") + + # Handle empty rows + empty = lat_min >= nlat_in + lat_min[empty] = 0 + lat_max[empty] = 0 + + max_bw = (lat_max - lat_min + 1).max().item() + + # Build banded tensor from sparse entries (no full densification) + psi_banded = torch.zeros(K, nlat_out, max_bw, nlon, dtype=values.dtype) + banded_lat = input_lat - lat_min[row_idx] + psi_banded[ker_idx, row_idx, banded_lat, input_lon] = values + + # Precompute FFT and gather index + psi_banded_fft_conj = rfft(psi_banded, dim=-1).conj() + gather_idx = lat_min.unsqueeze(1) + torch.arange(max_bw).unsqueeze(0) + gather_idx = gather_idx.clamp(max=nlat_in - 1) + + return psi_banded_fft_conj, gather_idx + + +def _disco_s2_contraction_fft( + x: torch.Tensor, + psi_fft_conj: torch.Tensor, + gather_idx: torch.Tensor, + nlon_out: int, +): + """FFT-based DISCO S2 contraction using banded psi representation. + + Parameters + ---------- + x : (B, C, nlat_in, nlon_in) + psi_fft_conj : (K, nlat_out, bw, nfreq) + gather_idx : (nlat_out, bw) + nlon_out : int + + Returns: + ------- + (B, C, K, nlat_out, nlon_out) + """ + batch_size, n_chans, nlat_in, nlon_in = x.shape + kernel_size, nlat_out, bw, nfreq = psi_fft_conj.shape + pscale = nlon_in // nlon_out + + # FFT of input along longitude + X_f = rfft(x.to(torch.float32), dim=-1) # (B, C, nlat_in, nfreq) + X_f = X_f.reshape(batch_size * n_chans, nlat_in, nfreq) + + # Gather relevant input lats for each output lat + X_f_gathered = X_f[:, gather_idx, :] # (B*C, nlat_out, bw, nfreq) + + # Cross-correlate: einsum over band width and frequency, then irfft + Y_f = torch.einsum("kowf,bowf->bkof", psi_fft_conj, X_f_gathered) + + # Inverse FFT + y = irfft(Y_f, n=nlon_in, dim=-1) # (B*C, K, nlat_out, nlon_in) + + # Subsample for stride + y = y[..., ::pscale] # (B*C, K, nlat_out, nlon_out) + + y = y.reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out).contiguous() + return y.to(x.dtype) diff --git a/fme/core/disco/_fft.py b/fme/core/disco/_fft.py new file mode 100644 index 000000000..5632908bb --- /dev/null +++ b/fme/core/disco/_fft.py @@ -0,0 +1,51 @@ +# Forked from torch-harmonics (BSD-3-Clause) +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. +# SPDX-License-Identifier: BSD-3-Clause + +import torch +import torch.fft as fft +import torch.nn as nn + + +def _pad_dim_right( + x: torch.Tensor, dim: int, target_size: int, value: float = 0.0 +) -> torch.Tensor: + """Pad tensor along a single dimension to target_size (right-side only).""" + ndim = x.ndim + dim = dim if dim >= 0 else ndim + dim + pad_amount = target_size - x.shape[dim] + pad_spec = [0] * (2 * ndim) + pad_spec[(ndim - 1 - dim) * 2 + 1] = pad_amount + return nn.functional.pad(x, tuple(pad_spec), value=value) + + +def rfft( + x: torch.Tensor, nmodes: int | None = None, dim: int = -1, **kwargs +) -> torch.Tensor: + """Real FFT with correct padding/truncation of modes.""" + if "n" in kwargs: + raise ValueError("The 'n' argument is not allowed. Use 'nmodes' instead.") + + x = fft.rfft(x, dim=dim, **kwargs) + + if nmodes is not None and nmodes > x.shape[dim]: + x = _pad_dim_right(x, dim, nmodes, value=0.0) + elif nmodes is not None and nmodes < x.shape[dim]: + x = x.narrow(dim, 0, nmodes) + + return x + + +def irfft( + x: torch.Tensor, n: int | None = None, dim: int = -1, **kwargs +) -> torch.Tensor: + """Inverse real FFT with Hermitian symmetry enforcement.""" + if n is None: + n = 2 * (x.size(dim) - 1) + + x[..., 0].imag = 0.0 + if (n % 2 == 0) and (n // 2 < x.size(dim)): + x[..., n // 2].imag = 0.0 + + x = fft.irfft(x, n=n, dim=dim, **kwargs) + return x diff --git a/fme/core/disco/_filter_basis.py b/fme/core/disco/_filter_basis.py new file mode 100644 index 000000000..372c9ca1a --- /dev/null +++ b/fme/core/disco/_filter_basis.py @@ -0,0 +1,363 @@ +# Forked from torch-harmonics (BSD-3-Clause) +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. +# SPDX-License-Identifier: BSD-3-Clause + +import abc +import math + +import torch + +from ._cache import lru_cache + + +def _circle_dist(x1: torch.Tensor, x2: torch.Tensor): + return torch.minimum( + torch.abs(x1 - x2), torch.abs(2 * math.pi - torch.abs(x1 - x2)) + ) + + +def _log_factorial(x: torch.Tensor): + return torch.lgamma(x + 1) + + +def _factorial(x: torch.Tensor): + return torch.exp(_log_factorial(x)) + + +class FilterBasis(metaclass=abc.ABCMeta): + """Abstract base class for a filter basis.""" + + kernel_shape: int | list[int] | tuple[int, ...] + + def __init__(self, kernel_shape: int | list[int] | tuple[int, ...]): + self.kernel_shape = kernel_shape + + def __repr__(self): + class_name = self.__class__.__name__ + if hasattr(self, "extra_repr"): + return f"{class_name}({self.extra_repr()})" + else: + return f"{class_name}()" + + def extra_repr(self): + return f"kernel_shape={self.kernel_shape}" + + @property + @abc.abstractmethod + def kernel_size(self): + raise NotImplementedError + + @abc.abstractmethod + def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float): + raise NotImplementedError + + +@lru_cache(typed=True, copy=False) +def get_filter_basis( + kernel_shape: int | tuple[int, ...], basis_type: str +) -> FilterBasis: + """Factory function to generate the appropriate filter basis.""" + if basis_type == "piecewise linear": + return PiecewiseLinearFilterBasis(kernel_shape=kernel_shape) + elif basis_type == "morlet": + return MorletFilterBasis(kernel_shape=kernel_shape) + elif basis_type == "isotropic morlet": + return IsotropicMorletFilterBasis(kernel_shape=kernel_shape) + elif basis_type == "zernike": + return ZernikeFilterBasis(kernel_shape=kernel_shape) + else: + raise ValueError(f"Unknown basis_type {basis_type}") + + +class PiecewiseLinearFilterBasis(FilterBasis): + """Tensor-product basis on a disk from piecewise linear basis functions.""" + + kernel_shape: list[int] + + def __init__(self, kernel_shape: int | list[int] | tuple[int, ...]): + if isinstance(kernel_shape, int): + kernel_shape = [kernel_shape] + if len(kernel_shape) == 1: + kernel_shape = [kernel_shape[0], 1] + elif len(kernel_shape) != 2: + raise ValueError( + f"expected kernel_shape to be of length 1 or 2 " + f"but got {kernel_shape} instead." + ) + super().__init__(kernel_shape=kernel_shape) + + @property + def kernel_size(self): + return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[ + 0 + ] % 2 + + def _compute_support_vals_isotropic( + self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float + ): + ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) + nr = self.kernel_shape[0] + dr = 2 * r_cutoff / (nr + 1) + + if nr % 2 == 1: + ir = ikernel * dr + else: + ir = (ikernel + 0.5) * dr + + iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff)) + vals = 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr + return iidx, vals + + def _compute_support_vals_anisotropic( + self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float + ): + ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) + nr = self.kernel_shape[0] + nphi = self.kernel_shape[1] + dr = 2 * r_cutoff / (nr + 1) + dphi = 2.0 * math.pi / nphi + + if nr % 2 == 1: + ir = ((ikernel - 1) // nphi + 1) * dr + iphi = ((ikernel - 1) % nphi) * dphi - math.pi + else: + ir = (ikernel // nphi + 0.5) * dr + iphi = (ikernel % nphi) * dphi - math.pi + + if nr % 2 == 1: + cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff) + cond_phi = (ikernel == 0) | (_circle_dist(phi, iphi).abs() <= dphi) + iidx = torch.argwhere(cond_r & cond_phi) + dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() + dist_phi = _circle_dist(phi[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0]) + vals = 1 - dist_r / dr + vals *= torch.where((iidx[:, 0] > 0), (1 - dist_phi / dphi), 1.0) + else: + rn = -r + phin = torch.where(phi + math.pi >= math.pi, phi - math.pi, phi + math.pi) + cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff) + cond_phi = _circle_dist(phi, iphi).abs() <= dphi + cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff) + cond_phin = _circle_dist(phin, iphi) <= dphi + iidx = torch.argwhere((cond_r & cond_phi) | (cond_rn & cond_phin)) + + dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() + dist_phi = _circle_dist(phi[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0]) + dist_rn = (rn[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() + dist_phin = _circle_dist( + phin[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0] + ) + vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_r / dr) + vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phi / dphi) + valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr) + valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * ( + 1 - dist_phin / dphi + ) + vals += valsn + + return iidx, vals + + def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float): + if self.kernel_shape[1] > 1: + return self._compute_support_vals_anisotropic(r, phi, r_cutoff=r_cutoff) + else: + return self._compute_support_vals_isotropic(r, phi, r_cutoff=r_cutoff) + + +class MorletFilterBasis(FilterBasis): + """Morlet-style filter basis on the disk.""" + + kernel_shape: list[int] + + def __init__(self, kernel_shape: int | list[int] | tuple[int, ...]): + if isinstance(kernel_shape, int): + kernel_shape = [kernel_shape, kernel_shape] + if len(kernel_shape) != 2: + raise ValueError( + f"expected kernel_shape to be of length 2 " + f"but got {kernel_shape} instead." + ) + super().__init__(kernel_shape=kernel_shape) + + @property + def kernel_size(self): + return self.kernel_shape[0] * self.kernel_shape[1] + + def hann_window(self, r: torch.Tensor, width: float = 1.0): + return torch.cos(0.5 * torch.pi * r / width) ** 2 + + def compute_support_vals( + self, + r: torch.Tensor, + phi: torch.Tensor, + r_cutoff: float, + width: float = 1.0, + ): + ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) + nkernel = ikernel % self.kernel_shape[1] + mkernel = ikernel // self.kernel_shape[1] + + iidx = torch.argwhere( + (r <= r_cutoff) + & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device) + ) + + r = r[iidx[:, 1], iidx[:, 2]] / r_cutoff + phi = phi[iidx[:, 1], iidx[:, 2]] + x = r * torch.sin(phi) + y = r * torch.cos(phi) + n = nkernel[iidx[:, 0], 0, 0] + m = mkernel[iidx[:, 0], 0, 0] + + harmonic = torch.where( + n % 2 == 1, + torch.sin(torch.ceil(n / 2) * math.pi * x / width), + torch.cos(torch.ceil(n / 2) * math.pi * x / width), + ) + harmonic *= torch.where( + m % 2 == 1, + torch.sin(torch.ceil(m / 2) * math.pi * y / width), + torch.cos(torch.ceil(m / 2) * math.pi * y / width), + ) + + vals = self.hann_window(r, width=width) * harmonic + return iidx, vals + + +class IsotropicMorletFilterBasis(FilterBasis): + """Morlet-style filter basis using only radial modes. + + Each basis function is a product of a Hann radial window and a 1-D + Fourier harmonic in the normalised radial coordinate ``r / r_cutoff``. + Because none of the basis functions depend on the azimuthal angle + ``phi``, any learned linear combination is guaranteed to be isotropic + (radially symmetric). + + ``kernel_shape`` is a single integer giving the number of radial modes. + If a tuple is provided, only the first element is used. + """ + + kernel_shape: int + + def __init__(self, kernel_shape: int | list[int] | tuple[int, ...]): + if isinstance(kernel_shape, list | tuple): + kernel_shape = kernel_shape[0] + if not isinstance(kernel_shape, int): + raise ValueError( + f"expected kernel_shape to be an integer " + f"but got {kernel_shape} instead." + ) + super().__init__(kernel_shape=kernel_shape) + + @property + def kernel_size(self) -> int: + return self.kernel_shape + + def compute_support_vals( + self, + r: torch.Tensor, + phi: torch.Tensor, + r_cutoff: float, + width: float = 1.0, + ): + ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) + + iidx = torch.argwhere( + (r <= r_cutoff) + & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device) + ) + + r_norm = r[iidx[:, 1], iidx[:, 2]] / r_cutoff + n = ikernel[iidx[:, 0], 0, 0] + + # Radial Fourier modes: cos/sin pattern identical to Morlet but in r. + harmonic = torch.where( + n % 2 == 1, + torch.sin(torch.ceil(n / 2) * math.pi * r_norm / width), + torch.cos(torch.ceil(n / 2) * math.pi * r_norm / width), + ) + + # Hann radial envelope + window = torch.cos(0.5 * torch.pi * r_norm / width) ** 2 + vals = window * harmonic + + return iidx, vals + + +class ZernikeFilterBasis(FilterBasis): + """Zernike polynomial basis defined on the disk.""" + + kernel_shape: int + + def __init__(self, kernel_shape: int | list[int] | tuple[int, ...]): + if isinstance(kernel_shape, list | tuple): + kernel_shape = kernel_shape[0] + if not isinstance(kernel_shape, int): + raise ValueError( + f"expected kernel_shape to be an integer " + f"but got {kernel_shape} instead." + ) + super().__init__(kernel_shape=kernel_shape) + + @property + def kernel_size(self): + return (self.kernel_shape * (self.kernel_shape + 1)) // 2 + + def zernikeradial(self, r: torch.Tensor, n: torch.Tensor, m: torch.Tensor): + out = torch.zeros_like(r) + bound = (n - m) // 2 + 1 + max_bound = bound.max().item() + + for k in range(max_bound): + inc = ( + (-1) ** k + * _factorial(n - k) + * r ** (n - 2 * k) + / ( + math.factorial(k) + * _factorial((n + m) // 2 - k) + * _factorial((n - m) // 2 - k) + ) + ) + out += torch.where(k < bound, inc, 0.0) + return out + + def zernikepoly( + self, + r: torch.Tensor, + phi: torch.Tensor, + n: torch.Tensor, + l: torch.Tensor, # noqa: E741 + ): + m = 2 * l - n + return torch.where( + m < 0, + self.zernikeradial(r, n, -m) * torch.sin(m * phi), + self.zernikeradial(r, n, m) * torch.cos(m * phi), + ) + + def compute_support_vals( + self, + r: torch.Tensor, + phi: torch.Tensor, + r_cutoff: float, + width: float = 0.25, + ): + ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1) + iidx = torch.argwhere( + (r <= r_cutoff) + & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device) + ) + + nshifts = torch.arange(self.kernel_shape, device=r.device) + nshifts = (nshifts + 1) * nshifts // 2 + nkernel = torch.searchsorted(nshifts, ikernel, right=True) - 1 + lkernel = ikernel - nshifts[nkernel] + + r = r[iidx[:, 1], iidx[:, 2]] / r_cutoff + phi = phi[iidx[:, 1], iidx[:, 2]] + n = nkernel[iidx[:, 0], 0, 0] + l = lkernel[iidx[:, 0], 0, 0] # noqa: E741 + + vals = self.zernikepoly(r, phi, n, l) + return iidx, vals diff --git a/fme/core/disco/_quadrature.py b/fme/core/disco/_quadrature.py new file mode 100644 index 000000000..dac51540d --- /dev/null +++ b/fme/core/disco/_quadrature.py @@ -0,0 +1,101 @@ +# Forked from torch-harmonics (BSD-3-Clause) +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. +# SPDX-License-Identifier: BSD-3-Clause +# +# Subset: only the functions needed by DISCO convolution precomputation. + +import math + +import numpy as np +import torch + +from ._cache import lru_cache + + +def _trapezoidal_weights( + n: int, a: float = -1.0, b: float = 1.0 +) -> tuple[torch.Tensor, torch.Tensor]: + xlg = torch.as_tensor(np.linspace(a, b, n, endpoint=True)) + wlg = (b - a) / n * torch.ones(n, requires_grad=False) + wlg[0] *= 0.5 + wlg[-1] *= 0.5 + return xlg, wlg + + +def _legendre_gauss_weights( + n: int, a: float = -1.0, b: float = 1.0 +) -> tuple[torch.Tensor, torch.Tensor]: + xlg, wlg = np.polynomial.legendre.leggauss(n) + xlg = torch.as_tensor(xlg).clone() + wlg = torch.as_tensor(wlg).clone() + xlg = (b - a) * 0.5 * xlg + (b + a) * 0.5 + wlg = wlg * (b - a) * 0.5 + return xlg, wlg + + +def _clenshaw_curtiss_weights( + n: int, a: float = -1.0, b: float = 1.0 +) -> tuple[torch.Tensor, torch.Tensor]: + assert n > 1 + + tcc = torch.cos( + torch.linspace(math.pi, 0, n, dtype=torch.float64, requires_grad=False) + ) + + if n == 2: + wcc = torch.as_tensor([1.0, 1.0], dtype=torch.float64) + else: + n1 = n - 1 + N = torch.arange(1, n1, 2, dtype=torch.float64) + ll = len(N) + m = n1 - ll + + v = torch.cat( + [ + 2 / N / (N - 2), + 1 / N[-1:], + torch.zeros(m, dtype=torch.float64, requires_grad=False), + ] + ) + v = 0 - v[:-1] - torch.flip(v[1:], dims=(0,)) + + g0 = -torch.ones(n1, dtype=torch.float64, requires_grad=False) + g0[ll] = g0[ll] + n1 + g0[m] = g0[m] + n1 + g = g0 / (n1**2 - 1 + (n1 % 2)) + wcc = torch.fft.ifft(v + g).real + wcc = torch.cat((wcc, wcc[:1])) + + tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5 + wcc = wcc * (b - a) * 0.5 + return tcc, wcc + + +def _precompute_quadrature_weights( + n: int, grid: str = "equiangular", a: float = -1.0, b: float = 1.0 +) -> tuple[torch.Tensor, torch.Tensor]: + if grid == "equidistant": + return _trapezoidal_weights(n, a=a, b=b) + elif grid == "legendre-gauss": + return _legendre_gauss_weights(n, a=a, b=b) + elif grid == "equiangular": + return _clenshaw_curtiss_weights(n, a=a, b=b) + else: + raise ValueError(f"Unknown grid type {grid}") + + +@lru_cache(typed=True, copy=True) +def precompute_longitudes(nlon: int) -> torch.Tensor: + return torch.linspace( + 0, 2 * math.pi, nlon + 1, dtype=torch.float64, requires_grad=False + )[:-1] + + +@lru_cache(typed=True, copy=True) +def precompute_latitudes( + nlat: int, grid: str = "equiangular" +) -> tuple[torch.Tensor, torch.Tensor]: + xlg, wlg = _precompute_quadrature_weights(nlat, grid=grid, a=-1.0, b=1.0) + lats = torch.flip(torch.arccos(xlg), dims=(0,)).clone() + wlg = torch.flip(wlg, dims=(0,)).clone() + return lats, wlg diff --git a/fme/core/disco/test_disco.py b/fme/core/disco/test_disco.py new file mode 100644 index 000000000..5a1729560 --- /dev/null +++ b/fme/core/disco/test_disco.py @@ -0,0 +1,120 @@ +import torch +import torch_harmonics as th + +from fme.core.disco import DiscreteContinuousConvS2 + + +def test_forward_shape(): + """Basic smoke test: output has the expected shape.""" + img_shape = (16, 32) + in_channels = 4 + out_channels = 8 + conv = DiscreteContinuousConvS2( + in_channels, + out_channels, + in_shape=img_shape, + out_shape=img_shape, + kernel_shape=(3, 3), + basis_type="morlet", + basis_norm_mode="mean", + grid_in="equiangular", + grid_out="equiangular", + bias=True, + theta_cutoff=0.5, + ) + x = torch.randn(2, in_channels, *img_shape) + with torch.no_grad(): + y = conv(x) + assert y.shape == (2, out_channels, *img_shape) + + +def test_matches_torch_harmonics_reference(): + """Output matches the torch-harmonics sparse-matrix implementation.""" + torch.manual_seed(0) + img_shape = (16, 32) + in_channels = 4 + out_channels = 4 + conv_kwargs = dict( + in_shape=img_shape, + out_shape=img_shape, + kernel_shape=(3, 3), + basis_type="morlet", + basis_norm_mode="mean", + groups=1, + grid_in="equiangular", + grid_out="equiangular", + bias=False, + theta_cutoff=0.5, + ) + + fft_conv = DiscreteContinuousConvS2( + in_channels, + out_channels, + **conv_kwargs, # type: ignore[arg-type] + ) + ref_conv = th.DiscreteContinuousConvS2( + in_channels, + out_channels, + **conv_kwargs, # type: ignore[arg-type] + ) + + # Copy weights from fft_conv to ref_conv so they use the same parameters + ref_conv.weight.data.copy_(fft_conv.weight.data) + + x = torch.randn(2, in_channels, *img_shape) + with torch.no_grad(): + y_fft = fft_conv(x) + y_ref = ref_conv(x) + + torch.testing.assert_close(y_fft, y_ref, atol=1e-4, rtol=1e-4) + + +def test_grouped_convolution(): + """Test that grouped convolutions work.""" + img_shape = (16, 32) + in_channels = 8 + out_channels = 8 + groups = 4 + conv = DiscreteContinuousConvS2( + in_channels, + out_channels, + in_shape=img_shape, + out_shape=img_shape, + kernel_shape=(3, 3), + basis_type="morlet", + basis_norm_mode="mean", + groups=groups, + grid_in="equiangular", + grid_out="equiangular", + bias=False, + theta_cutoff=0.5, + ) + x = torch.randn(2, in_channels, *img_shape) + with torch.no_grad(): + y = conv(x) + assert y.shape == (2, out_channels, *img_shape) + + +def test_backward_pass(): + """Gradients flow through the FFT-based convolution.""" + img_shape = (16, 32) + in_channels = 4 + out_channels = 4 + conv = DiscreteContinuousConvS2( + in_channels, + out_channels, + in_shape=img_shape, + out_shape=img_shape, + kernel_shape=(3, 3), + basis_type="morlet", + basis_norm_mode="mean", + grid_in="equiangular", + grid_out="equiangular", + bias=True, + theta_cutoff=0.5, + ) + x = torch.randn(2, in_channels, *img_shape) + y = conv(x) + y.sum().backward() + assert conv.weight.grad is not None + assert conv.bias.grad is not None diff --git a/fme/core/distributed/non_distributed.py b/fme/core/distributed/non_distributed.py index 4814efb68..01a44d60f 100644 --- a/fme/core/distributed/non_distributed.py +++ b/fme/core/distributed/non_distributed.py @@ -115,7 +115,9 @@ def get_isht( return th.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid).float() def get_disco_conv_s2(self, *args, **kwargs) -> nn.Module: - return th.DiscreteContinuousConvS2(*args, **kwargs).float() + from fme.core.disco import DiscreteContinuousConvS2 + + return DiscreteContinuousConvS2(*args, **kwargs).float() def spatial_reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor: return tensor diff --git a/fme/core/distributed/torch_distributed.py b/fme/core/distributed/torch_distributed.py index 632c12b1c..9c6da212a 100644 --- a/fme/core/distributed/torch_distributed.py +++ b/fme/core/distributed/torch_distributed.py @@ -207,7 +207,9 @@ def get_isht( return th.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid).float() def get_disco_conv_s2(self, *args, **kwargs) -> nn.Module: - return th.DiscreteContinuousConvS2(*args, **kwargs).float() + from fme.core.disco import DiscreteContinuousConvS2 + + return DiscreteContinuousConvS2(*args, **kwargs).float() def spatial_reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor: return tensor diff --git a/fme/core/models/conditional_sfno/localnet.py b/fme/core/models/conditional_sfno/localnet.py index dad4a4be0..1d7bddae6 100644 --- a/fme/core/models/conditional_sfno/localnet.py +++ b/fme/core/models/conditional_sfno/localnet.py @@ -14,7 +14,7 @@ from .sfnonet import DiscreteContinuousConvS2, NoLayerNorm, _compute_cutoff_radius BlockType = Literal["disco", "conv1x1"] -BasisType = Literal["morlet", "piecewise linear", "zernike"] +BasisType = Literal["morlet", "isotropic morlet", "piecewise linear", "zernike"] @dataclasses.dataclass @@ -27,7 +27,7 @@ class LocalNetConfig: kernel_shape: Shape of the DISCO convolution filter basis, passed to the filter basis constructor. Only affects 'disco' blocks. basis_type: Type of filter basis for the DISCO convolution - ('morlet', 'piecewise linear', or 'zernike'). + ('morlet', 'isotropic morlet', 'piecewise linear', or 'zernike'). Only affects 'disco' blocks. block_types: List of filter types for each block ('disco', 'conv1x1'). The length determines the number of blocks. diff --git a/fme/core/models/conditional_sfno/sfnonet.py b/fme/core/models/conditional_sfno/sfnonet.py index 1dbc5e7a8..36f02a1e6 100644 --- a/fme/core/models/conditional_sfno/sfnonet.py +++ b/fme/core/models/conditional_sfno/sfnonet.py @@ -119,6 +119,7 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type): theta_cutoff_factor = { "piecewise linear": 0.5, "morlet": 0.5, + "isotropic morlet": 0.5, "zernike": math.sqrt(2.0), } diff --git a/fme/core/models/conditional_sfno/test_localnet.py b/fme/core/models/conditional_sfno/test_localnet.py index f9eab8aa7..f6480fad2 100644 --- a/fme/core/models/conditional_sfno/test_localnet.py +++ b/fme/core/models/conditional_sfno/test_localnet.py @@ -9,6 +9,7 @@ from .layers import Context, ContextConfig from .localnet import LocalNetConfig, get_lat_lon_localnet +from .sfnonet import DiscreteContinuousConvS2, _compute_cutoff_radius DIR = os.path.dirname(os.path.abspath(__file__)) @@ -234,6 +235,8 @@ def test_no_big_skip(): ((3, 5), "morlet"), ((3, 3), "piecewise linear"), ((3, 1), "piecewise linear"), + ((3, 3), "isotropic morlet"), + ((5, 3), "isotropic morlet"), ], ) def test_can_call_localnet_with_kernel_shape(kernel_shape, basis_type): @@ -265,6 +268,71 @@ def test_can_call_localnet_with_kernel_shape(kernel_shape, basis_type): assert output.shape == (n_samples, output_channels, *img_shape) +def test_isotropic_disco_conv_commutes_with_latitude_reflection(): + """An isotropic DISCO conv must commute with latitude reflection. + + Latitude reflection (θ → π−θ) is an isometry of the sphere that maps + equiangular grid points exactly onto other grid points. For an + isotropic filter the convolution kernel depends only on geodesic + distance, which is preserved by the reflection, so + ``flip(conv(x)) == conv(flip(x))``. An anisotropic filter also + depends on the local azimuthal angle, which the reflection reverses, + so the equality generally fails. + """ + torch.manual_seed(0) + img_shape = (16, 32) + embed_dim = 4 + kernel_shape = (3, 3) + data_grid = "equiangular" + lat_dim = 2 # dimension index for latitude in (batch, channel, lat, lon) + + def make_disco_conv(basis_type): + theta_cutoff = 2 * _compute_cutoff_radius( + nlat=img_shape[0], + kernel_shape=kernel_shape, + basis_type=basis_type, + ) + return DiscreteContinuousConvS2( + embed_dim, + embed_dim, + in_shape=img_shape, + out_shape=img_shape, + kernel_shape=kernel_shape, + basis_type=basis_type, + basis_norm_mode="mean", + groups=1, + grid_in=data_grid, + grid_out=data_grid, + bias=False, + theta_cutoff=theta_cutoff, + ) + + x = torch.randn(1, embed_dim, *img_shape) + x_flipped = torch.flip(x, dims=[lat_dim]) + + # Isotropic basis: convolution must commute with the reflection. + iso_conv = make_disco_conv("isotropic morlet") + with torch.no_grad(): + out, _ = iso_conv(x) + out_from_flipped, _ = iso_conv(x_flipped) + torch.testing.assert_close( + torch.flip(out, dims=[lat_dim]), + out_from_flipped, + ) + + # Anisotropic basis: convolution should NOT commute (random weights + # activate the azimuthal modes that break the symmetry). + aniso_conv = make_disco_conv("morlet") + with torch.no_grad(): + out, _ = aniso_conv(x) + out_from_flipped, _ = aniso_conv(x_flipped) + assert not torch.allclose( + torch.flip(out, dims=[lat_dim]), + out_from_flipped, + atol=1e-5, + ) + + def test_unknown_filter_type_raises(): with pytest.raises(ValueError, match="Invalid block type"): LocalNetConfig(block_types=["spectral"]) # type: ignore[list-item] diff --git a/fme/core/step/single_module.py b/fme/core/step/single_module.py index 9eb94ee79..cefe87739 100644 --- a/fme/core/step/single_module.py +++ b/fme/core/step/single_module.py @@ -20,6 +20,7 @@ from fme.core.optimization import NullOptimization from fme.core.packer import Packer from fme.core.registry import CorrectorSelector, ModuleSelector +from fme.core.registry.module import Module from fme.core.step.args import StepArgs from fme.core.step.secondary_decoder import ( NoSecondaryDecoder, @@ -53,6 +54,15 @@ class SingleModuleStepConfig(StepConfigABC): prescribed_prognostic_names: Prognostic variable names to overwrite from forcing data at each step (e.g. for inference with observed values). residual_prediction: Whether to use residual prediction. + secondary_builder: Optional builder for a secondary network that receives + the same input as the primary module. + secondary_out_names: Names of variables output by the secondary network + as full fields (used directly as output). Must not overlap with + out_names. + secondary_residual_out_names: Names of variables for which the secondary + network predicts a residual correction. If the name is also in + out_names, the residual is added to the backbone's output; + otherwise it is added to the (normalized) input value. """ builder: ModuleSelector @@ -67,6 +77,9 @@ class SingleModuleStepConfig(StepConfigABC): next_step_forcing_names: list[str] = dataclasses.field(default_factory=list) prescribed_prognostic_names: list[str] = dataclasses.field(default_factory=list) residual_prediction: bool = False + secondary_builder: ModuleSelector | None = None + secondary_out_names: list[str] = dataclasses.field(default_factory=list) + secondary_residual_out_names: list[str] = dataclasses.field(default_factory=list) def __post_init__(self): self.crps_training = None # unused, kept for backwards compatibility @@ -85,16 +98,57 @@ def __post_init__(self): raise ValueError( f"next_step_forcing_name is an output variable: '{name}'" ) + all_secondary_names = set(self.secondary_out_names) | set( + self.secondary_residual_out_names + ) if self.secondary_decoder is not None: for name in self.secondary_decoder.secondary_diagnostic_names: if name in self.in_names: raise ValueError( f"secondary_diagnostic_name is an input variable: '{name}'" ) - if name in self.out_names: + if name in set(self.out_names) | all_secondary_names: raise ValueError( f"secondary_diagnostic_name is an output variable: '{name}'" ) + if self.secondary_builder is None: + if self.secondary_out_names: + raise ValueError( + "secondary_out_names must be empty when " + "secondary_builder is not provided" + ) + if self.secondary_residual_out_names: + raise ValueError( + "secondary_residual_out_names must be empty when " + "secondary_builder is not provided" + ) + else: + if not self.secondary_out_names and not self.secondary_residual_out_names: + raise ValueError( + "at least one of secondary_out_names or " + "secondary_residual_out_names must be non-empty when " + "secondary_builder is provided" + ) + overlap = set(self.secondary_out_names) & set(self.out_names) + if overlap: + raise ValueError( + f"secondary_out_names must not overlap with out_names. " + f"Overlap: {overlap}" + ) + overlap = set(self.secondary_out_names) & set( + self.secondary_residual_out_names + ) + if overlap: + raise ValueError( + f"secondary_out_names must not overlap with " + f"secondary_residual_out_names. Overlap: {overlap}" + ) + for name in self.secondary_residual_out_names: + if name not in self.out_names and name not in self.in_names: + raise ValueError( + f"secondary_residual_out_name '{name}' must be in " + f"out_names or in_names: {self.out_names}, {self.in_names}" + ) @property def n_ic_timesteps(self) -> int: @@ -151,12 +205,17 @@ def diagnostic_names(self) -> list[str]: @property def output_names(self) -> list[str]: - secondary_names = ( + secondary_decoder_names = ( self.secondary_decoder.secondary_diagnostic_names if self.secondary_decoder is not None else [] ) - return list(set(self.out_names).union(secondary_names)) + return list( + set(self.out_names) + .union(secondary_decoder_names) + .union(self.secondary_out_names) + .union(self.secondary_residual_out_names) + ) @property def next_step_input_names(self) -> list[str]: @@ -267,6 +326,21 @@ def __init__( dist = Distributed.get_instance() + if config.secondary_builder is not None: + all_secondary_names = ( + config.secondary_out_names + config.secondary_residual_out_names + ) + secondary_module = config.secondary_builder.build( + n_in_channels=n_in_channels, + n_out_channels=len(all_secondary_names), + dataset_info=dataset_info, + ) + self.secondary_module: Module | None = secondary_module.to(get_device()) + self.secondary_out_packer: Packer | None = Packer(all_secondary_names) + else: + self.secondary_module = None + self.secondary_out_packer = None + if config.secondary_decoder is not None: self.secondary_decoder: SecondaryDecoder | NoSecondaryDecoder = ( config.secondary_decoder.build( @@ -282,6 +356,8 @@ def __init__( self._no_optimization = NullOptimization() self.module = self.module.wrap_module(dist.wrap_module) + if self.secondary_module is not None: + self.secondary_module = self.secondary_module.wrap_module(dist.wrap_module) self.secondary_decoder = self.secondary_decoder.wrap_module(dist.wrap_module) self._timestep = dataset_info.timestep @@ -329,6 +405,8 @@ def modules(self) -> nn.ModuleList: A list of modules being trained. """ modules = [self.module.torch_module] + if self.secondary_module is not None: + modules.append(self.secondary_module.torch_module) modules.extend(self.secondary_decoder.torch_modules) return nn.ModuleList(modules) @@ -355,6 +433,23 @@ def network_call(input_norm: TensorDict) -> TensorDict: labels=args.labels, ) output_dict = self.out_packer.unpack(output_tensor, axis=self.CHANNEL_DIM) + if self.secondary_module is not None: + if self.secondary_out_packer is None: + raise RuntimeError("secondary_out_packer is unexpectedly None") + secondary_tensor = self.secondary_module.wrap_module(wrapper)( + input_tensor, + labels=args.labels, + ) + secondary_dict = self.secondary_out_packer.unpack( + secondary_tensor, axis=self.CHANNEL_DIM + ) + for name in self._config.secondary_out_names: + output_dict[name] = secondary_dict[name] + for name in self._config.secondary_residual_out_names: + if name in output_dict: + output_dict[name] = output_dict[name] + secondary_dict[name] + else: + output_dict[name] = input_norm[name] + secondary_dict[name] secondary_output_dict = self.secondary_decoder.wrap_module(wrapper)( output_tensor.detach() # detach avoids changing base outputs ) @@ -385,6 +480,8 @@ def get_state(self): "module": self.module.get_state(), "secondary_decoder": self.secondary_decoder.get_module_state(), } + if self.secondary_module is not None: + state["secondary_module"] = self.secondary_module.get_state() return state def load_state(self, state: dict[str, Any]) -> None: @@ -399,6 +496,8 @@ def load_state(self, state: dict[str, Any]) -> None: # for backwards compatibility with old checkpoints del module["module.device_buffer"] self.module.load_state(module) + if "secondary_module" in state and self.secondary_module is not None: + self.secondary_module.load_state(state["secondary_module"]) if "secondary_decoder" in state: self.secondary_decoder.load_module_state(state["secondary_decoder"]) diff --git a/fme/core/step/test_step.py b/fme/core/step/test_step.py index 5d98df1cf..f2271eaa9 100644 --- a/fme/core/step/test_step.py +++ b/fme/core/step/test_step.py @@ -374,6 +374,41 @@ def get_fcn3_selector( ) +def get_single_module_with_secondary_selector( + dir: pathlib.Path | None = None, +) -> StepSelector: + normalization = get_network_and_loss_normalization_config( + names=[ + "forcing_a", + "prog_a", + "prog_b", + "diag_a", + ], + dir=dir, + ) + return StepSelector( + type="single_module", + config=dataclasses.asdict( + SingleModuleStepConfig( + builder=ModuleSelector( + type="SphericalFourierNeuralOperatorNet", + config={ + "scale_factor": 1, + "embed_dim": 4, + "num_layers": 2, + }, + ), + in_names=["forcing_a", "prog_a", "prog_b"], + out_names=["prog_a", "prog_b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=["diag_a"], + secondary_residual_out_names=["prog_a"], + ), + ), + ) + + def get_multi_call_selector( dir: pathlib.Path | None = None, ) -> StepSelector: @@ -403,6 +438,7 @@ def get_multi_call_selector( get_separate_radiation_selector, get_single_module_selector, get_single_module_noise_conditioned_selector, + get_single_module_with_secondary_selector, get_multi_call_selector, ] @@ -666,6 +702,289 @@ def test_input_output_names_secondary_decoder_conflict(conflict: str): assert f"secondary_diagnostic_name is an {conflict} variable:" in str(err.value) +def test_secondary_builder_none_with_out_names_raises(): + normalization = get_network_and_loss_normalization_config( + names=["a", "b"], + ) + with pytest.raises(ValueError, match="secondary_out_names must be empty"): + SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_out_names=["c"], + ) + + +def test_secondary_builder_none_with_residual_out_names_raises(): + normalization = get_network_and_loss_normalization_config( + names=["a", "b"], + ) + with pytest.raises(ValueError, match="secondary_residual_out_names must be empty"): + SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_residual_out_names=["b"], + ) + + +def test_secondary_builder_with_empty_names_raises(): + normalization = get_network_and_loss_normalization_config( + names=["a", "b"], + ) + with pytest.raises(ValueError, match="at least one of"): + SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + ) + + +def test_secondary_out_name_overlaps_out_names_raises(): + normalization = get_network_and_loss_normalization_config( + names=["a", "b"], + ) + with pytest.raises(ValueError, match="secondary_out_names must not overlap"): + SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=["b"], + ) + + +def test_secondary_out_name_overlaps_residual_out_names_raises(): + normalization = get_network_and_loss_normalization_config( + names=["a", "b", "c"], + ) + with pytest.raises( + ValueError, match="secondary_out_names must not overlap.*residual" + ): + SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=["c"], + secondary_residual_out_names=["c"], + ) + + +def test_secondary_residual_out_name_not_in_out_or_in_names_raises(): + normalization = get_network_and_loss_normalization_config( + names=["a", "b", "c"], + ) + with pytest.raises(ValueError, match="secondary_residual_out_name 'c'"): + SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_residual_out_names=["c"], + ) + + +@pytest.mark.parallel +def test_secondary_network_full_field_and_residual(): + """Test secondary_out_names and secondary_residual_out_names together.""" + torch.manual_seed(0) + normalization = get_network_and_loss_normalization_config( + names=["forcing", "prog", "diag"], + ) + config = StepSelector( + type="single_module", + config=dataclasses.asdict( + SingleModuleStepConfig( + builder=ModuleSelector( + type="SphericalFourierNeuralOperatorNet", + config={ + "scale_factor": 1, + "embed_dim": 4, + "num_layers": 2, + }, + ), + in_names=["forcing", "prog"], + out_names=["prog"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=["diag"], + secondary_residual_out_names=["prog"], + ), + ), + ) + img_shape = DEFAULT_IMG_SHAPE + step = get_step(config, img_shape) + assert "prog" in step.output_names + assert "diag" in step.output_names + assert "prog" in step.prognostic_names + input_data = get_tensor_dict(step.input_names, img_shape, n_samples=2) + next_step_input_data = get_tensor_dict( + step.next_step_input_names, img_shape, n_samples=2 + ) + output = step.step( + args=StepArgs( + input=input_data, next_step_input_data=next_step_input_data, labels=None + ), + ) + assert "prog" in output + assert "diag" in output + assert output["prog"].shape == (2, *img_shape) + assert output["diag"].shape == (2, *img_shape) + + +@pytest.mark.parallel +def test_secondary_network_state_round_trip(): + """Test get_state/load_state with secondary module.""" + torch.manual_seed(0) + normalization = get_network_and_loss_normalization_config( + names=["forcing", "prog", "diag"], + ) + config = SingleModuleStepConfig( + builder=ModuleSelector( + type="SphericalFourierNeuralOperatorNet", + config={ + "scale_factor": 1, + "embed_dim": 4, + "num_layers": 2, + }, + ), + in_names=["forcing", "prog"], + out_names=["prog"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=["diag"], + secondary_residual_out_names=["prog"], + ) + img_shape = DEFAULT_IMG_SHAPE + step1 = get_step( + StepSelector(type="single_module", config=dataclasses.asdict(config)), + img_shape, + ) + state = step1.get_state() + assert "secondary_module" in state + + step2 = get_step( + StepSelector(type="single_module", config=dataclasses.asdict(config)), + img_shape, + ) + step2.load_state(state) + + input_data = get_tensor_dict(step1.input_names, img_shape, n_samples=1) + next_step_input_data = get_tensor_dict( + step1.next_step_input_names, img_shape, n_samples=1 + ) + args = StepArgs( + input=input_data, next_step_input_data=next_step_input_data, labels=None + ) + out1 = step1.step(args=args) + out2 = step2.step(args=args) + for name in out1: + torch.testing.assert_close(out1[name], out2[name]) + + +def test_secondary_output_names_full_field_only(): + """secondary_out_names appear in output_names.""" + normalization = get_network_and_loss_normalization_config( + names=["a", "b", "c"], + ) + config = SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_out_names=["c"], + ) + assert "c" in config.output_names + assert "b" in config.output_names + + +def test_secondary_output_names_residual_only(): + """secondary_residual_out_names that are in out_names appear in output_names.""" + normalization = get_network_and_loss_normalization_config( + names=["a", "b"], + ) + config = SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a"], + out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_residual_out_names=["b"], + ) + assert "b" in config.output_names + + +def test_secondary_output_names_residual_on_input_only(): + """secondary_residual_out_names on input-only name adds to output_names.""" + normalization = get_network_and_loss_normalization_config( + names=["a", "b"], + ) + config = SingleModuleStepConfig( + builder=ModuleSelector(type="MLP", config={}), + in_names=["a", "b"], + out_names=["b"], + normalization=normalization, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_residual_out_names=["a"], + ) + assert "a" in config.output_names + assert "b" in config.output_names + + +@pytest.mark.parallel +def test_secondary_residual_on_input_only_with_residual_prediction(): + """When residual_prediction=True and secondary_residual_out_name is in in_names + but not out_names, the input should not be added twice.""" + torch.manual_seed(0) + normalization = get_network_and_loss_normalization_config( + names=["forcing", "prog_a", "prog_b"], + ) + config = StepSelector( + type="single_module", + config=dataclasses.asdict( + SingleModuleStepConfig( + builder=ModuleSelector( + type="SphericalFourierNeuralOperatorNet", + config={ + "scale_factor": 1, + "embed_dim": 4, + "num_layers": 2, + }, + ), + in_names=["forcing", "prog_a", "prog_b"], + out_names=["prog_a", "prog_b"], + normalization=normalization, + residual_prediction=True, + secondary_builder=ModuleSelector(type="MLP", config={}), + secondary_residual_out_names=["prog_a"], + ), + ), + ) + img_shape = DEFAULT_IMG_SHAPE + step = get_step(config, img_shape) + assert "prog_a" in step.output_names + input_data = get_tensor_dict(step.input_names, img_shape, n_samples=2) + next_step_input_data = get_tensor_dict( + step.next_step_input_names, img_shape, n_samples=2 + ) + output = step.step( + args=StepArgs( + input=input_data, next_step_input_data=next_step_input_data, labels=None + ), + ) + assert output["prog_a"].shape == (2, *img_shape) + assert output["prog_b"].shape == (2, *img_shape) + + def test_step_with_prescribed_prognostic_overwrites_output(): normalization = get_network_and_loss_normalization_config( names=["forcing_shared", "forcing_rad", "diagnostic_main", "diagnostic_rad"],