Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion fme/ace/step/fcn3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
AtmoSphericNeuralOperatorNet,
)
from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig
from fme.core.corrector.registry import CorrectorABC
from fme.core.corrector.registry import CorrectorABC, CorrectorConfigABC
from fme.core.dataset.utils import encode_timestep
from fme.core.dataset_info import DatasetInfo
from fme.core.device import get_device
Expand Down Expand Up @@ -177,6 +177,18 @@ class FCN3StepConfig(StepConfigABC):
residual_prediction: bool = False

def __post_init__(self):
corrector: CorrectorConfigABC = self.corrector
if isinstance(corrector, CorrectorSelector):
corrector = corrector.config_instance
if (
isinstance(corrector, AtmosphereCorrectorConfig)
and corrector.ocean is not None
):
raise ValueError(
"FCN3StepConfig manages ocean configuration via its own 'ocean' "
"attribute. Configuring 'ocean' on the AtmosphereCorrector is not "
"supported."
)
for name in self.next_step_forcing_names:
if name not in self.forcing_names:
raise ValueError(
Expand Down
13 changes: 13 additions & 0 deletions fme/ace/stepper/single_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
VerticalCoordinate,
)
from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig
from fme.core.corrector.registry import CorrectorConfigABC
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.dataset.schedule import IntSchedule
from fme.core.dataset.utils import encode_timestep
Expand Down Expand Up @@ -133,6 +134,18 @@ class SingleModuleStepperConfig:
residual_prediction: bool = False

def __post_init__(self):
corrector: CorrectorConfigABC = self.corrector
if isinstance(corrector, CorrectorSelector):
corrector = corrector.config_instance
if (
isinstance(corrector, AtmosphereCorrectorConfig)
and corrector.ocean is not None
):
raise ValueError(
"SingleModuleStepperConfig manages ocean configuration via its own "
"'ocean' attribute. Configuring 'ocean' on AtmosphereCorrector is not "
"supported."
)
for name in self.prescribed_prognostic_names:
if name not in self.out_names:
raise ValueError(
Expand Down
3 changes: 2 additions & 1 deletion fme/core/atmosphere_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Protocol
from typing import Protocol, runtime_checkable

import torch

Expand Down Expand Up @@ -40,6 +40,7 @@
}


@runtime_checkable
class HasAtmosphereVerticalIntegral(Protocol):
def vertical_integral(
self,
Expand Down
124 changes: 12 additions & 112 deletions fme/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,21 @@
import math
from collections.abc import Callable, Mapping
from datetime import timedelta
from typing import Literal, TypeVar
from typing import Literal, TypeVar, final

import dacite
import numpy as np
import torch

from fme.core import metrics
from fme.core.constants import EARTH_RADIUS, GRAVITY
from fme.core.corrector.atmosphere import AtmosphereCorrector, AtmosphereCorrectorConfig
from fme.core.corrector.ice import IceCorrector, IceCorrectorConfig
from fme.core.corrector.ocean import OceanCorrector, OceanCorrectorConfig
from fme.core.corrector.registry import CorrectorABC
from fme.core.corrector.registry import CorrectorABC, CorrectorConfigABC
from fme.core.derived_variables import compute_derived_quantities
from fme.core.device import get_device
from fme.core.distributed import Distributed
from fme.core.gridded_ops import GriddedOperations, HEALPixOperations, LatLonOperations
from fme.core.mask_provider import MaskProvider, MaskProviderABC, NullMaskProvider
from fme.core.ocean_derived_variables import compute_ocean_derived_quantities
from fme.core.registry.corrector import CorrectorSelector
from fme.core.typing_ import TensorDict, TensorMapping
from fme.core.winds import lon_lat_to_xyz

Expand Down Expand Up @@ -136,14 +132,21 @@ def __repr__(self) -> str:
def __eq__(self, other) -> bool:
pass

@abc.abstractmethod
@final
def build_corrector(
self,
config: AtmosphereCorrectorConfig | CorrectorSelector,
config: CorrectorConfigABC,
gridded_operations: GriddedOperations,
timestep: timedelta,
) -> CorrectorABC:
pass
vertical_coord: VerticalCoordinate | None = self
if isinstance(vertical_coord, NullVerticalCoordinate):
vertical_coord = None
return config.get_corrector(
gridded_operations=gridded_operations,
vertical_coordinate=vertical_coord,
timestep=timestep,
)

@abc.abstractmethod
def build_derive_function(
Expand Down Expand Up @@ -208,35 +211,6 @@ def __len__(self):
"""The number of vertical layer interfaces."""
return len(self.ak)

def build_corrector(
self,
config: AtmosphereCorrectorConfig | CorrectorSelector,
gridded_operations: GriddedOperations,
timestep: timedelta,
) -> AtmosphereCorrector:
if (
isinstance(config, CorrectorSelector)
and config.type != "atmosphere_corrector"
):
raise ValueError(
f"Cannot build corrector for vertical coordinate {self} with "
f"corrector selector {config}."
)
if isinstance(config, CorrectorSelector):
config_instance = dacite.from_dict(
data_class=AtmosphereCorrectorConfig,
data=config.config,
config=dacite.Config(strict=True),
)
else:
config_instance = config
return AtmosphereCorrector(
config=config_instance,
gridded_operations=gridded_operations,
vertical_coordinate=self,
timestep=timestep,
)

def build_derive_function(
self,
timestep: timedelta,
Expand Down Expand Up @@ -392,30 +366,6 @@ def __len__(self):
"""The number of vertical layer interfaces."""
return len(self.idepth)

def build_corrector(
self,
config: AtmosphereCorrectorConfig | CorrectorSelector,
gridded_operations: GriddedOperations,
timestep: timedelta,
) -> OceanCorrector:
if isinstance(config, AtmosphereCorrectorConfig):
raise ValueError(
"Cannot build corrector for depth coordinate with an "
"AtmosphereCorrectorConfig."
)
elif config.type != "ocean_corrector":
raise ValueError(
f"Cannot build corrector for vertical coordinate {self} with "
f"corrector selector {config}."
)
config_instance = OceanCorrectorConfig.from_state(config.config)
return OceanCorrector(
config=config_instance,
gridded_operations=gridded_operations,
vertical_coordinate=self,
timestep=timestep,
)

def build_derive_function(
self,
timestep: timedelta,
Expand Down Expand Up @@ -521,56 +471,6 @@ def __repr__(self) -> str:
def __len__(self) -> int:
return 0

def build_corrector(
self,
config: AtmosphereCorrectorConfig | CorrectorSelector,
gridded_operations: GriddedOperations,
timestep: timedelta,
) -> CorrectorABC:
if isinstance(config, AtmosphereCorrectorConfig):
return AtmosphereCorrector(
config=config,
gridded_operations=gridded_operations,
vertical_coordinate=None,
timestep=timestep,
)
if config.type == "atmosphere_corrector":
config_instance = dacite.from_dict(
data_class=AtmosphereCorrectorConfig,
data=config.config,
config=dacite.Config(strict=True),
)
return AtmosphereCorrector(
config=config_instance,
gridded_operations=gridded_operations,
vertical_coordinate=None,
timestep=timestep,
)
elif config.type == "ocean_corrector":
config_instance = OceanCorrectorConfig.from_state(config.config)
return OceanCorrector(
config=config_instance,
gridded_operations=gridded_operations,
vertical_coordinate=None,
timestep=timestep,
)
elif config.type == "ice_corrector":
config_instance = dacite.from_dict(
data_class=IceCorrectorConfig,
data=config.config,
config=dacite.Config(strict=True),
)
return IceCorrector(
config=config_instance,
gridded_operations=gridded_operations,
timestep=timestep,
)
else:
raise ValueError(
f"Invalid corrector type: {config.type}. "
"Must be either 'atmosphere_corrector' or 'ocean_corrector'."
)

def build_derive_function(
self,
timestep: timedelta,
Expand Down
52 changes: 50 additions & 2 deletions fme/core/corrector/atmosphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
compute_layer_thickness,
)
from fme.core.constants import GRAVITY, SPECIFIC_HEAT_OF_DRY_AIR_CONST_VOLUME
from fme.core.corrector.registry import CorrectorABC
from fme.core.corrector.registry import CorrectorABC, CorrectorConfigABC
from fme.core.corrector.utils import force_positive
from fme.core.gridded_ops import GriddedOperations
from fme.core.ocean import Ocean, OceanConfig
from fme.core.registry.corrector import CorrectorSelector
from fme.core.typing_ import TensorDict, TensorMapping

Expand All @@ -40,7 +41,7 @@ class EnergyBudgetConfig:

@CorrectorSelector.register("atmosphere_corrector")
@dataclasses.dataclass
class AtmosphereCorrectorConfig:
class AtmosphereCorrectorConfig(CorrectorConfigABC):
r"""
Configuration for the post-step state corrector.

Expand Down Expand Up @@ -116,6 +117,10 @@ class AtmosphereCorrectorConfig:
total_energy_budget_correction: If not None, force the generated data to
conserve an idealized version of total energy using the provided
configuration.
ocean: If provided, ocean configuration for SST prescribing or slab ocean
modeling. When set, the corrector's ``input_names`` and
``next_step_input_names`` will include the ocean's forcing variable
names.
"""

conserve_dry_air: bool = False
Expand All @@ -131,13 +136,46 @@ class AtmosphereCorrectorConfig:
) = None
force_positive_names: list[str] = dataclasses.field(default_factory=list)
total_energy_budget_correction: EnergyBudgetConfig | None = None
ocean: OceanConfig | None = None

@property
def input_names(self) -> list[str]:
if self.ocean is None:
return []
return self.ocean.forcing_names

@property
def next_step_input_names(self) -> list[str]:
if self.ocean is None:
return []
return self.ocean.forcing_names

@classmethod
def from_state(cls, state: Mapping[str, Any]) -> "AtmosphereCorrectorConfig":
return dacite.from_dict(
data_class=cls, data=state, config=dacite.Config(strict=True)
)

def get_corrector(
self,
gridded_operations: GriddedOperations,
vertical_coordinate: Any | None,
timestep: datetime.timedelta,
) -> "AtmosphereCorrector":
if vertical_coordinate and not isinstance(
vertical_coordinate, HasAtmosphereVerticalIntegral
):
raise ValueError(
"Cannot build AtmosphereCorrector with vertical "
f"coordinate {vertical_coordinate}."
)
return AtmosphereCorrector(
self,
gridded_operations,
vertical_coordinate,
timestep,
)


class AtmosphereCorrector(CorrectorABC):
def __init__(
Expand All @@ -157,6 +195,14 @@ def __init__(
else:
self._dry_air_precision = torch.float64

self._ocean: Ocean | None = None
if config.ocean is not None:
self._ocean = Ocean(config=config.ocean, timestep=timestep)

@property
def ocean(self) -> Ocean | None:
return self._ocean

def __call__(
self,
input_data: TensorMapping,
Expand Down Expand Up @@ -226,6 +272,8 @@ def __call__(
method=self._config.total_energy_budget_correction.method,
unaccounted_heating=self._config.total_energy_budget_correction.constant_unaccounted_heating,
)
if self._ocean is not None:
gen_data = self._ocean(input_data, gen_data, forcing_data)
return gen_data


Expand Down
25 changes: 22 additions & 3 deletions fme/core/corrector/ice.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import dacite
import torch

from fme.core.corrector.registry import CorrectorABC
from fme.core.corrector.registry import CorrectorABC, CorrectorConfigABC
from fme.core.gridded_ops import GriddedOperations
from fme.core.registry.corrector import CorrectorSelector
from fme.core.typing_ import TensorDict, TensorMapping
Expand Down Expand Up @@ -185,16 +185,35 @@ def __call__(

@CorrectorSelector.register("ice_corrector")
@dataclasses.dataclass
class IceCorrectorConfig:
# Correctors here. Can add more as needed
class IceCorrectorConfig(CorrectorConfigABC):
budget_correction: IceBudgetCorrectionConfig | None = None

@property
def input_names(self) -> list[str]:
return []

@property
def next_step_input_names(self) -> list[str]:
return []

@classmethod
def from_state(cls, state: Mapping[str, Any]) -> "IceCorrectorConfig":
return dacite.from_dict(
data_class=cls, data=state, config=dacite.Config(strict=True)
)

def get_corrector(
self,
gridded_operations: GriddedOperations,
vertical_coordinate: Any | None, # ignored
timestep: datetime.timedelta,
) -> "IceCorrector":
return IceCorrector(
self,
gridded_operations,
timestep,
)


class IceCorrector(CorrectorABC):
"""
Expand Down
Loading
Loading