Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
328c978
Allow stochastic atmos training in CoupledStepper
jpdunc23 Jan 21, 2026
eb19d3d
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Jan 21, 2026
8533d85
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Feb 21, 2026
81f7329
Fix test
jpdunc23 Feb 21, 2026
2584ab1
Add `optimize_last_step_only` to `LossContributionsConfig`
jpdunc23 Feb 22, 2026
b11d5af
Randomly sample coupled `LossContributions.n_steps`
jpdunc23 Feb 23, 2026
e5131f4
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Mar 25, 2026
604d22c
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Apr 2, 2026
7c9e35c
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 2, 2026
ed07188
Merge branch 'feature/coupled-optimize-last-step-only' of github.com:…
jpdunc23 Apr 2, 2026
1f39fb5
Allow for ensemble ocean training
jpdunc23 Apr 2, 2026
106c8fa
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 2, 2026
827dc04
De-duplicate gen step processing in train_on_batch
jpdunc23 Apr 2, 2026
2256741
Merge branch 'feature/coupled-optimize-last-step-only' of github.com:…
jpdunc23 Apr 2, 2026
1ceb8d6
Merge branch 'main' into coupled-ft-with-stochastic-atmos
jpdunc23 Apr 6, 2026
138e82e
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 6, 2026
10702a3
Merge branch 'feature/coupled-optimize-last-step-only' of github.com:…
jpdunc23 Apr 6, 2026
4f9585e
Add `CoupledTrainStepper._accumulate_loss()`
jpdunc23 Apr 7, 2026
cec0022
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Apr 8, 2026
dcb0cbd
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 8, 2026
b380c8f
Update tests
jpdunc23 Apr 8, 2026
570fa6f
Merge branch 'feature/coupled-optimize-last-step-only' of github.com:…
jpdunc23 Apr 8, 2026
8743338
Address additional review comments
jpdunc23 Apr 8, 2026
76c13e1
Fix unhelpful docstring
jpdunc23 Apr 8, 2026
c97c3cf
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 8, 2026
c992681
Merge branch 'feature/coupled-optimize-last-step-only' of github.com:…
jpdunc23 Apr 8, 2026
30ffd90
Merge branch 'main' into coupled-ft-with-stochastic-atmos
jpdunc23 Apr 9, 2026
3722c55
Merge branch 'coupled-ft-with-stochastic-atmos' into feature/coupled-…
jpdunc23 Apr 9, 2026
6e2aa96
Merge branch 'feature/coupled-optimize-last-step-only' into feature/c…
jpdunc23 Apr 9, 2026
d7e12a5
Merge branch 'main' into coupled-ft-with-stochastic-atmos
jpdunc23 Apr 9, 2026
c6d7f35
Merge branch 'coupled-ft-with-stochastic-atmos' into feature/coupled-…
jpdunc23 Apr 9, 2026
ab0cc95
Merge branch 'feature/coupled-optimize-last-step-only' into feature/c…
jpdunc23 Apr 9, 2026
d70af7d
Merge branch 'main' of github.com:ai2cm/ace into feature/coupled-opti…
jpdunc23 Apr 9, 2026
ccc4086
Merge branch 'feature/coupled-optimize-last-step-only' of github.com:…
jpdunc23 Apr 9, 2026
3e676ef
Merge branch 'feature/coupled-loss-contrib-random-n_steps' of github.…
jpdunc23 Apr 9, 2026
f14e2e3
Merge branch 'main' into feature/coupled-optimize-last-step-only
jpdunc23 Apr 13, 2026
475046c
Merge branch 'feature/coupled-optimize-last-step-only' into feature/c…
jpdunc23 Apr 13, 2026
82fa43f
Merge branch 'main' of github.com:ai2cm/ace into feature/coupled-opti…
jpdunc23 Apr 15, 2026
d5c00da
Add test with `use_gradient_accumulation` and `optimize_last_step_only`
jpdunc23 Apr 15, 2026
710e0ce
Assert gen_step realm and step
jpdunc23 Apr 15, 2026
2576d7c
Move `n_coupled_steps` to `CoupledTrainStepperConfig`
jpdunc23 Apr 15, 2026
5fc4bed
Update baseline configs
jpdunc23 Apr 15, 2026
05b4518
Merge branch 'feature/coupled-optimize-last-step-only' of github.com:…
jpdunc23 Apr 15, 2026
c175be3
Merge branch 'feature/coupled-loss-contrib-random-n_steps' of github.…
jpdunc23 Apr 15, 2026
78f7108
Merge branch 'main' of github.com:ai2cm/ace into feature/coupled-loss…
jpdunc23 Apr 16, 2026
64d0c4a
Merge branch 'main' into feature/coupled-loss-contrib-random-n_steps
jpdunc23 Apr 20, 2026
e7d411f
Merge branch 'main' into feature/coupled-loss-contrib-random-n_steps
jpdunc23 Apr 21, 2026
ac15e54
Merge branch 'main' of github.com:ai2cm/ace into feature/coupled-loss…
jpdunc23 Apr 22, 2026
6807733
Address review comments
jpdunc23 Apr 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions fme/coupled/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch

from fme.ace.stepper.time_length_probabilities import TimeLengthProbabilities
from fme.core.device import get_device
from fme.core.loss import StepLoss
from fme.core.typing_ import TensorDict, TensorMapping
Expand All @@ -28,6 +29,14 @@ class StepLossABC(abc.ABC):
@abc.abstractmethod
def effective_loss_scaling(self) -> TensorDict: ...

def sample_n_steps(self) -> None:
    """Sample a new effective n_steps for the current batch.

    The base implementation is a no-op; subclasses that support
    stochastic n_steps via ``TimeLengthProbabilities`` override this.
    """

@abc.abstractmethod
def step_is_optimized(self, step: int) -> bool:
"""Returns True if the given step should contribute to the loss.
Expand All @@ -51,8 +60,10 @@ class LossContributionsConfig:
Configuration for loss contributions.

Parameters:
n_steps: (optional) The number of consecutive steps contributing to the loss,
starting from the first.
n_steps: The number of consecutive steps contributing to the loss,
starting from the first. Can be a float (defaults to ``inf`` for all
steps) or a ``TimeLengthProbabilities`` for stochastic per-batch
sampling.
weight: (optional) Weight applied to each step loss for the given realm.
Each step contributes equally to the total loss.
optimize_last_step_only: If True, only the last step within the training
Expand All @@ -62,7 +73,7 @@ class LossContributionsConfig:

"""

n_steps: float = float("inf")
n_steps: TimeLengthProbabilities | float = float("inf")
weight: float = 1.0
optimize_last_step_only: bool = False

Expand All @@ -72,7 +83,9 @@ def build(
time_dim: int,
max_n_steps: int,
) -> StepLossABC:
if self.n_steps == 0 or self.weight == 0.0:
if self.weight == 0.0:
return NullLossContributions(loss_obj)
if isinstance(self.n_steps, int | float) and self.n_steps == 0:
return NullLossContributions(loss_obj)
return LossContributions(
n_steps=self.n_steps,
Expand Down Expand Up @@ -112,20 +125,29 @@ def __call__(
class LossContributions(StepLossABC):
def __init__(
self,
n_steps: float,
n_steps: TimeLengthProbabilities | float,
weight: float,
optimize_last_step_only: bool,
loss_obj: StepLoss,
time_dim: int,
max_n_steps: int,
):
self._loss = loss_obj
self._n_steps = n_steps
if isinstance(n_steps, TimeLengthProbabilities):
self._n_steps_sampler: TimeLengthProbabilities | None = n_steps
self._n_steps: float = float(n_steps.max_n_forward_steps)
else:
self._n_steps_sampler = None
self._n_steps = n_steps
self._weight = weight
self._optimize_last_step_only = optimize_last_step_only
self._time_dim = time_dim
self._max_n_steps = max_n_steps

def sample_n_steps(self) -> None:
    """Draw a fresh effective n_steps if a stochastic sampler is configured.

    Has no effect when ``n_steps`` was configured as a plain float
    (no sampler present).
    """
    sampler = self._n_steps_sampler
    if sampler is None:
        return
    self._n_steps = float(sampler.sample())

@property
def effective_loss_scaling(self) -> TensorDict:
return self._loss.effective_loss_scaling
Expand Down
5 changes: 5 additions & 0 deletions fme/coupled/stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,10 @@ def effective_loss_scaling(self) -> CoupledTensorMapping:
atmosphere=self._loss_objs["atmosphere"].effective_loss_scaling,
)

def sample_n_steps(self) -> None:
    """Resample the per-batch effective n_steps on every component loss."""
    for component_loss in self._loss_objs.values():
        component_loss.sample_n_steps()

def step_is_optimized(
self,
realm: Literal["ocean", "atmosphere"],
Expand Down Expand Up @@ -1731,6 +1735,7 @@ def train_on_batch(
)

metrics = ComponentStepMetrics()
self._loss.sample_n_steps()
optimization.set_mode(self.modules)
with optimization.autocast():
output_list = self._accumulate_loss(
Expand Down
141 changes: 140 additions & 1 deletion fme/coupled/test_loss.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from collections.abc import Callable, Generator
from unittest.mock import Mock
from unittest.mock import MagicMock, Mock

import pytest
import torch

from fme.ace.stepper.time_length_probabilities import (
TimeLengthProbabilities,
TimeLengthProbability,
)
from fme.core.loss import StepLoss
from fme.core.typing_ import EnsembleTensorDict, TensorMapping

Expand Down Expand Up @@ -254,6 +258,141 @@ def test_step_is_optimized_last_step_only_weight_zero():
assert not loss.step_is_optimized(0)


def test_stochastic_n_steps_sample_changes_step_is_optimized():
    """A deterministic sampler yields the same optimized steps before and
    after sampling (initial _n_steps equals max_n_forward_steps)."""
    sampler = TimeLengthProbabilities(
        outcomes=[TimeLengthProbability(steps=2, probability=1.0)]
    )
    loss = LossContributionsConfig(n_steps=sampler).build(
        loss_obj=Mock(spec=StepLoss), time_dim=1, max_n_steps=4
    )
    # before sampling, _n_steps is max_n_forward_steps = 2
    assert [loss.step_is_optimized(s) for s in range(3)] == [True, True, False]
    # after sampling (deterministic: always 2), same behavior
    loss.sample_n_steps()
    assert [loss.step_is_optimized(s) for s in range(3)] == [True, True, False]


def test_stochastic_n_steps_deterministic_outcome():
    """Sampling from a single-outcome distribution fixes n_steps at that
    outcome, so exactly the first three steps are optimized."""
    sampler = TimeLengthProbabilities(
        outcomes=[TimeLengthProbability(steps=3, probability=1.0)]
    )
    loss = LossContributionsConfig(n_steps=sampler).build(
        loss_obj=Mock(spec=StepLoss), time_dim=1, max_n_steps=4
    )
    loss.sample_n_steps()
    expected = [True, True, True, False]
    assert [loss.step_is_optimized(s) for s in range(4)] == expected


def test_stochastic_n_steps_samples_vary():
    """With multiple outcomes, repeated sampling should eventually produce
    different effective n_steps values."""
    sampler = TimeLengthProbabilities(
        outcomes=[
            TimeLengthProbability(steps=1, probability=0.5),
            TimeLengthProbability(steps=4, probability=0.5),
        ]
    )
    loss = LossContributionsConfig(n_steps=sampler).build(
        loss_obj=Mock(spec=StepLoss), time_dim=1, max_n_steps=5
    )
    saw_long = saw_short = False
    for _ in range(20):  # about 1 in a million prob of test failure
        loss.sample_n_steps()
        saw_long = saw_long or loss.step_is_optimized(3)
        saw_short = saw_short or not loss.step_is_optimized(1)
        if saw_long and saw_short:
            break
    assert saw_long, "should sometimes sample n_steps=4"
    assert saw_short, "should sometimes sample n_steps=1"


class TestOptimizeLastStepOnlyStochastic:
    """optimize_last_step_only combined with a stochastic n_steps sampler."""

    def _build(self, sampler, max_n_steps=6):
        # Build a LossContributions where only the last in-range step counts.
        config = LossContributionsConfig(n_steps=sampler, optimize_last_step_only=True)
        return config.build(
            loss_obj=Mock(spec=StepLoss), time_dim=1, max_n_steps=max_n_steps
        )

    def _sampler(self, outcomes):
        # outcomes: iterable of (steps, probability) pairs.
        return TimeLengthProbabilities(
            outcomes=[
                TimeLengthProbability(steps=s, probability=p) for s, p in outcomes
            ]
        )

    def _assert_only_step(self, loss, optimized_step, total_steps=6):
        # Exactly one step in [0, total_steps) may be optimized.
        for step in range(total_steps):
            if step == optimized_step:
                assert loss.step_is_optimized(step), f"step {step} should be optimized"
            else:
                assert not loss.step_is_optimized(
                    step
                ), f"step {step} should not be optimized"

    def test_before_sampling(self):
        loss = self._build(self._sampler([(2, 0.5), (5, 0.5)]), max_n_steps=6)
        # _n_steps = max_n_forward_steps = 5; last optimized = min(5,6)-1 = 4
        self._assert_only_step(loss, 4)

    def test_deterministic_sample(self):
        loss = self._build(self._sampler([(3, 1.0)]), max_n_steps=6)
        loss.sample_n_steps()
        # _n_steps = 3; last optimized = min(3,6)-1 = 2
        self._assert_only_step(loss, 2)

    def test_varying_samples(self):
        loss = self._build(self._sampler([(2, 0.5), (5, 0.5)]), max_n_steps=6)
        seen_step_1 = False  # min(2,6)-1 = 1
        seen_step_4 = False  # min(5,6)-1 = 4
        for _ in range(20):  # about 1 in a million prob of test failure
            loss.sample_n_steps()
            only_1 = loss.step_is_optimized(1) and not loss.step_is_optimized(4)
            only_4 = loss.step_is_optimized(4) and not loss.step_is_optimized(1)
            seen_step_1 = seen_step_1 or only_1
            seen_step_4 = seen_step_4 or only_4
            if seen_step_1 and seen_step_4:
                break
        assert seen_step_1, "should sometimes optimize only step 1 (n_steps=2)"
        assert seen_step_4, "should sometimes optimize only step 4 (n_steps=5)"


def test_sample_n_steps_noop_for_float_config():
    """A plain-float n_steps config is unaffected by sample_n_steps()."""
    loss = LossContributionsConfig(n_steps=5.0).build(
        loss_obj=Mock(spec=StepLoss), time_dim=1, max_n_steps=5
    )
    loss.sample_n_steps()
    assert loss.step_is_optimized(4)
    assert not loss.step_is_optimized(5)


def test_coupled_stepper_train_loss_sample_n_steps_delegates():
    """sample_n_steps on the coupled loss fans out to both realm losses."""
    realm_losses = {
        realm: MagicMock(spec=StepLossABC) for realm in ("ocean", "atmosphere")
    }
    coupled_loss = CoupledStepperTrainLoss(
        ocean_loss=realm_losses["ocean"],
        atmosphere_loss=realm_losses["atmosphere"],
    )
    coupled_loss.sample_n_steps()
    for mock_loss in realm_losses.values():
        mock_loss.sample_n_steps.assert_called_once()


@pytest.mark.parametrize("ocean_config_kwargs", [{"n_steps": 0}, {"weight": 0.0}])
def test_null_loss_contributions(steps_thru_atmos_7, ocean_config_kwargs):
# test LossContributionsConfig with n_steps = 0
Expand Down
50 changes: 50 additions & 0 deletions fme/coupled/test_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1760,3 +1760,53 @@ def test_train_on_batch_optimize_last_step_only_with_n_steps(
# ocean: step 0 is optimized (n_steps=1)
expected_calls = atmos_n_steps + ocean_n_steps
assert len(optimization.accumulate_loss.call_args_list) == expected_calls


def test_train_on_batch_stochastic_n_steps():
    from fme.ace.stepper.time_length_probabilities import (
        TimeLengthProbabilities,
        TimeLengthProbability,
    )

    torch.manual_seed(0)
    n_forward_times_ocean = 2
    n_forward_times_atmosphere = 4

    def certain(steps: int) -> TimeLengthProbabilities:
        # Single-outcome distribution: sampling always returns `steps`.
        return TimeLengthProbabilities(
            outcomes=[TimeLengthProbability(steps=steps, probability=1.0)]
        )

    # Atmosphere always samples n_steps=2; ocean always samples n_steps=1.
    train_stepper_config = CoupledTrainStepperConfig(
        n_coupled_steps=1,
        ocean=ComponentTrainingConfig(
            loss=StepLossConfig(type="MSE"),
            loss_contributions=LossContributionsConfig(n_steps=certain(1)),
        ),
        atmosphere=ComponentTrainingConfig(
            loss=StepLossConfig(type="MSE"),
            loss_contributions=LossContributionsConfig(n_steps=certain(2)),
        ),
    )
    train_stepper, coupled_data, _, _ = get_train_stepper_and_batch(
        train_stepper_config=train_stepper_config,
        ocean_in_names=["sst", "mask_0"],
        ocean_out_names=["sst"],
        atmosphere_in_names=["surface_temperature", "ocean_fraction"],
        atmosphere_out_names=["surface_temperature"],
        n_forward_times_ocean=n_forward_times_ocean,
        n_forward_times_atmosphere=n_forward_times_atmosphere,
        n_samples=3,
    )
    optimization = Mock(wraps=NullOptimization())
    train_stepper.train_on_batch(
        data=coupled_data.data,
        optimization=optimization,
    )
    # atmos: n_steps=2, so steps 0 and 1 are optimized (out of 4 total)
    # ocean: n_steps=1, so step 0 is optimized (out of 2 total)
    expected_calls = 2 + 1
    assert len(optimization.accumulate_loss.call_args_list) == expected_calls
Loading