From 4da9aeef8746e83e4d13b135a3503db2ca7fdd6d Mon Sep 17 00:00:00 2001 From: peterdschwartz Date: Tue, 10 Mar 2026 13:33:52 -0400 Subject: [PATCH 1/4] Add test for backward step spatial parallelism use graph-aware all_reduce inside spatial mean --- .../distributed/model_torch_distributed.py | 3 +- .../parallel_tests/test_backward_step.py | 140 ++++++++++++++++++ scripts/testing/test_spatial.sh | 17 +++ 3 files changed, 159 insertions(+), 1 deletion(-) create mode 100644 fme/core/distributed/parallel_tests/test_backward_step.py create mode 100755 scripts/testing/test_spatial.sh diff --git a/fme/core/distributed/model_torch_distributed.py b/fme/core/distributed/model_torch_distributed.py index 731cdf2dc..7f84e96b4 100644 --- a/fme/core/distributed/model_torch_distributed.py +++ b/fme/core/distributed/model_torch_distributed.py @@ -23,6 +23,7 @@ import torch import torch.distributed +import torch.distributed.nn.functional as dist_nn_f import torch.nn as nn import torch_harmonics.distributed as thd from torch.nn import SyncBatchNorm @@ -331,7 +332,7 @@ def barrier(self): def spatial_reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor: if self._h_size > 1 or self._w_size > 1: - torch.distributed.all_reduce(tensor, group=self._spatial_group) + return dist_nn_f.all_reduce(tensor, group=self._spatial_group) return tensor def weighted_mean( diff --git a/fme/core/distributed/parallel_tests/test_backward_step.py b/fme/core/distributed/parallel_tests/test_backward_step.py new file mode 100644 index 000000000..dd8467128 --- /dev/null +++ b/fme/core/distributed/parallel_tests/test_backward_step.py @@ -0,0 +1,140 @@ +import numpy as np +import pytest +import torch +from torch import nn + +import fme +from fme.core.distributed.distributed import Distributed +from fme.core.distributed.model_torch_distributed import ModelTorchDistributed +from fme.core.gridded_ops import LatLonOperations +from fme.core.optimization import OptimizationConfig +from fme.core.typing_ import TensorDict + + +class TinyConvNet(nn.Module): + """ + Very small conv net that operates on [batch, channels, nlat, nlon]. + This is just to ensure gradients propagate through a nontrivial model. + """ + + def __init__(self, n_channels: int = 2): + super().__init__() + self.conv1 = nn.Conv2d(n_channels, 4, kernel_size=3, padding=1) + self.act = nn.GELU() + self.conv2 = nn.Conv2d(4, n_channels, kernel_size=3, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv2(self.act(self.conv1(x))) + + +def _build_latlon_ops(img_shape: tuple[int, int]) -> LatLonOperations: + """ + Build LatLonOperations with simple area weights. + In spatial-parallel mode, this will slice per-rank tiles internally. + """ + nlat, nlon = img_shape + # Use cos(lat) weights (approx) just to be realistic; could also use ones. + lat = torch.linspace(-np.pi / 2, np.pi / 2, nlat, device="cpu") + area_weights = torch.cos(lat).clamp_min(1e-3).unsqueeze(-1).expand(nlat, nlon) + return LatLonOperations(area_weights=area_weights) + + +def _build_model_and_optimizer( + img_shape: tuple[int, int], +) -> tuple[nn.Module, torch.optim.Optimizer, LatLonOperations]: + """ + Build a DDP-wrapped TinyConvNet under ModelTorchDistributed and + a simple optimizer. Also returns LatLonOperations for computing a global loss. + """ + dist = Distributed.get_instance() + assert isinstance(dist._distributed, ModelTorchDistributed) + + model = TinyConvNet(n_channels=2).to(fme.get_device()) + # Wrap with DDP over the data group only; spatial model parallelism is + # handled by the model/layers and the backend. + wrapped_model = dist._distributed.wrap_module(model) + + # Simple Adam optimizer via OptimizationConfig, to go through the same codepath + # as real training. + opt_config = OptimizationConfig( + optimizer_type="Adam", + lr=1e-3, + enable_automatic_mixed_precision=False, + ) + optimization = opt_config.build( + modules=torch.nn.ModuleList([wrapped_model]), + max_epochs=1, + ) + + gridded_ops = _build_latlon_ops(img_shape) + return wrapped_model, optimization, gridded_ops + + +@pytest.mark.parametrize("img_shape", [(16, 32)]) +@pytest.mark.parallel +def test_spatial_parallel_backward_step(img_shape): + """ + Test: run forward + backward + optimizer step under + ModelTorchDistributed with spatial parallelism. + + Asserts: + - Loss is finite. + - All data-parallel ranks see the same loss. + - Parameter gradients are finite and data-parallel-consistent. + """ + dist = Distributed.get_instance() + if not isinstance(dist._distributed, ModelTorchDistributed): + pytest.skip("ModelTorchDistributed backend is required for this test") + + torch.manual_seed(0) + + model, optimization, gridded_ops = _build_model_and_optimizer(img_shape) + + batch_size = 4 + n_channels = 2 + nlat, nlon = img_shape + + # Global tensors + x_global = torch.randn(batch_size, n_channels, nlat, nlon, device=fme.get_device()) + y_global = torch.randn_like(x_global) + + global_inputs: TensorDict = {"x": x_global, "y": y_global} + local_inputs = dist.scatter_spatial(global_inputs, img_shape=(nlat, nlon)) + + x_local = local_inputs["x"] + y_local = local_inputs["y"] + + # Forward pass + loss + model.train() + optimization.optimizer.zero_grad() + + with optimization.autocast(): + y_pred_local = model(x_local) + + # Compute a global, area-weighted MSE over [batch, channels, lat, lon], + mse = (y_pred_local - y_local) ** 2 + mse_spatial = gridded_ops.area_weighted_mean(mse) + loss = mse_spatial.mean() + + # Backward + optimizer step + optimization.accumulate_loss(loss) + loss_before_step = optimization.get_accumulated_loss().detach().clone() + optimization.step_weights() + + # 1) Loss finite and the same on all data-parallel ranks. + assert torch.isfinite(loss_before_step), "Loss is not finite on this rank" + + # Reduce mean loss across data group and broadcast to root for inspection. + # ModelTorchDistributed.reduce_mean reduces over data group only. + loss_reduced = dist.reduce_mean(loss_before_step.detach().clone()) + if dist.is_root(): + assert torch.isfinite(loss_reduced), "Reduced loss is not finite" + + # 2) Gradients finite and consistent across data-parallel ranks. + # For a DDP-wrapped model, parameters are identical across data group, + # so their gradients should also be identical after backward. + for param in model.parameters(): + if not param.requires_grad: + continue + if param.grad is not None: + assert torch.isfinite(param.grad).all(), "Non-finite gradient detected" diff --git a/scripts/testing/test_spatial.sh b/scripts/testing/test_spatial.sh new file mode 100755 index 000000000..676fec76f --- /dev/null +++ b/scripts/testing/test_spatial.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +set -euo pipefail + +H=${FME_DISTRIBUTED_H:-2} +W=${FME_DISTRIBUTED_W:-2} +NP=$((H * W)) + +export FME_DISTRIBUTED_BACKEND=model +export FME_DISTRIBUTED_H=$H +export FME_DISTRIBUTED_W=$W + +# torchrun --standalone --nnodes=1 --nproc_per_node=$NP \ +# -m pytest fme/core/distributed/parallel_tests/test_spatial.py "$@" + +torchrun --standalone --nnodes=1 --nproc_per_node=$NP \ + -m pytest fme/core/distributed/parallel_tests/test_backward_step.py::test_spatial_parallel_backward_step "$@" + From 08647d83a6a0b3e561d8e882e4c4525b7592e896 Mon Sep 17 00:00:00 2001 From: mahf708 Date: Wed, 18 Mar 2026 20:50:57 -0700 Subject: [PATCH 2/4] add backward pass for spatial prallellism --- fme/ace/stepper/test_single_module_csfno.py | 245 ++++++++++++++++++ .../csfno_stepper_predict_regression.pt | Bin 0 -> 12162 bytes ...csfno_stepper_train_on_batch_regression.pt | Bin 0 -> 64819 bytes ...n_on_batch_with_optimization_regression.pt | Bin 0 -> 65279 bytes .../distributed/model_torch_distributed.py | 82 +++++- 5 files changed, 322 insertions(+), 5 deletions(-) create mode 100644 fme/ace/stepper/test_single_module_csfno.py create mode 100644 fme/ace/stepper/testdata/csfno_stepper_predict_regression.pt create mode 100644 fme/ace/stepper/testdata/csfno_stepper_train_on_batch_regression.pt create mode 100644 fme/ace/stepper/testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt diff --git a/fme/ace/stepper/test_single_module_csfno.py b/fme/ace/stepper/test_single_module_csfno.py new file mode 100644 index 000000000..25c64d6a1 --- /dev/null +++ b/fme/ace/stepper/test_single_module_csfno.py @@ -0,0 +1,245 @@ +""" +Parallel regression tests for the SingleModuleStepper with NoiseConditionedSFNO. + +These tests verify that the forward pass and loss computation produce identical +results regardless of spatial decomposition (nproc=1 vs model-parallel). +""" + +import dataclasses +import datetime +import os +from collections.abc import Mapping + +import numpy as np +import pytest +import torch +import xarray as xr + +from fme.ace.data_loading.batch_data import BatchData +from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNOBuilder +from fme.ace.stepper.single_module import ( + StepperConfig, + TrainOutput, + TrainStepper, + TrainStepperConfig, +) +from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates +from fme.core.dataset_info import DatasetInfo +from fme.core.device import get_device +from fme.core.distributed.distributed import Distributed +from fme.core.loss import StepLossConfig +from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig +from fme.core.optimization import NullOptimization, OptimizationConfig +from fme.core.registry.module import ModuleSelector +from fme.core.step import SingleModuleStepConfig, StepSelector +from fme.core.testing.regression import validate_tensor_dict +from fme.core.typing_ import EnsembleTensorDict + +DIR = os.path.abspath(os.path.dirname(__file__)) +TIMESTEP = datetime.timedelta(hours=6) + + +def get_dataset_info( + img_shape=(5, 5), +) -> DatasetInfo: + horizontal_coordinate = LatLonCoordinates( + lat=torch.zeros(img_shape[-2]), + lon=torch.zeros(img_shape[-1]), + ) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ) + return DatasetInfo( + horizontal_coordinates=horizontal_coordinate, + vertical_coordinate=vertical_coordinate, + timestep=TIMESTEP, + ) + + +def _get_train_stepper( + stepper_config: StepperConfig, + dataset_info: DatasetInfo, + **train_config_kwargs, +) -> TrainStepper: + train_config = TrainStepperConfig(**train_config_kwargs) + return train_config.get_train_stepper(stepper_config, dataset_info) + + +def get_regression_stepper_and_data() -> ( + tuple[TrainStepper, BatchData, tuple[int, int]] +): + in_names = ["a", "b"] + out_names = ["b", "c"] + n_forward_steps = 2 + n_samples = 3 + img_shape = (9, 18) + device = get_device() + + all_names = list(set(in_names + out_names)) + + loss = StepLossConfig(type="AreaWeightedMSE") + + config = StepperConfig( + step=StepSelector( + type="single_module", + config=dataclasses.asdict( + SingleModuleStepConfig( + builder=ModuleSelector( + type="NoiseConditionedSFNO", + config=dataclasses.asdict( + NoiseConditionedSFNOBuilder( + embed_dim=16, + num_layers=2, + noise_embed_dim=16, + noise_type="isotropic", + ) + ), + ), + in_names=in_names, + out_names=out_names, + normalization=NetworkAndLossNormalizationConfig( + network=NormalizationConfig( + means={n: 0.1 for n in all_names}, + stds={n: 1.1 for n in all_names}, + ), + ), + ocean=None, + ) + ), + ), + ) + + dataset_info = get_dataset_info(img_shape=img_shape) + train_stepper = _get_train_stepper(config, dataset_info, loss=loss) + data = BatchData.new_on_device( + data={ + "a": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device), + "b": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device), + "c": torch.randn(n_samples, n_forward_steps + 1, *img_shape).to(device), + }, + time=xr.DataArray( + np.zeros((n_samples, n_forward_steps + 1)), + dims=["sample", "time"], + ), + labels=None, + epoch=0, + horizontal_dims=["lat", "lon"], + ) + data = data.scatter_spatial(img_shape) + return train_stepper, data, img_shape + + +def flatten_dict( + d: Mapping[str, Mapping[str, torch.Tensor]], +) -> dict[str, torch.Tensor]: + return_dict = {} + for k, v in d.items(): + for k2, v2 in v.items(): + return_dict[f"{k}.{k2}"] = v2 + return return_dict + + +def _get_train_output_tensor_dict(data: TrainOutput) -> dict[str, torch.Tensor]: + return_dict = {} + for k, v in data.metrics.items(): + return_dict[f"metrics.{k}"] = v + for k, v in data.gen_data.items(): + return_dict[f"gen_data.{k}"] = v + for k, v in data.target_data.items(): + assert v.shape[1] == 1 + return_dict[f"target_data.{k}"] = v + return return_dict + + +def get_train_outputs_tensor_dict( + step_1: TrainOutput, step_2: TrainOutput +) -> dict[str, torch.Tensor]: + return flatten_dict( + { + "step_1": _get_train_output_tensor_dict(step_1), + "step_2": _get_train_output_tensor_dict(step_2), + } + ) + + +@pytest.mark.parallel +def test_stepper_train_on_batch_regression(): + torch.manual_seed(0) + train_stepper, data, img_shape = get_regression_stepper_and_data() + optimization = NullOptimization() + result1 = train_stepper.train_on_batch(data, optimization) + result2 = train_stepper.train_on_batch(data, optimization) + dist = Distributed.get_instance() + for result in [result1, result2]: + result.gen_data = EnsembleTensorDict( + dist.gather_spatial(dict(result.gen_data), img_shape) + ) + result.target_data = EnsembleTensorDict( + dist.gather_spatial(dict(result.target_data), img_shape) + ) + output_dict = get_train_outputs_tensor_dict(result1, result2) + validate_tensor_dict( + output_dict, + os.path.join( + DIR, + "testdata/csfno_stepper_train_on_batch_regression.pt", + ), + atol=1e-4, + rtol=1e-4, + ) + + +@pytest.mark.parallel +def test_stepper_train_on_batch_with_optimization_regression(): + torch.manual_seed(0) + train_stepper, data, img_shape = get_regression_stepper_and_data() + optimization = OptimizationConfig( + optimizer_type="Adam", + lr=0.0001, + ).build(train_stepper.modules, max_epochs=1) + result1 = train_stepper.train_on_batch(data, optimization) + result2 = train_stepper.train_on_batch(data, optimization) + dist = Distributed.get_instance() + for result in [result1, result2]: + result.gen_data = EnsembleTensorDict( + dist.gather_spatial(dict(result.gen_data), img_shape) + ) + result.target_data = EnsembleTensorDict( + dist.gather_spatial(dict(result.target_data), img_shape) + ) + output_dict = get_train_outputs_tensor_dict(result1, result2) + validate_tensor_dict( + output_dict, + os.path.join( + DIR, + "testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt", + ), + atol=1e-2, + rtol=1e-2, + ) + + +@pytest.mark.parallel +def test_stepper_predict_regression(): + torch.manual_seed(0) + train_stepper, data, img_shape = get_regression_stepper_and_data() + stepper = train_stepper._stepper + initial_condition = data.get_start( + prognostic_names=["b"], + n_ic_timesteps=1, + ) + output, next_state = stepper.predict( + initial_condition, data, compute_derived_variables=True + ) + dist = Distributed.get_instance() + output_data = dist.gather_spatial(dict(output.data), img_shape) + next_state_data = dist.gather_spatial( + dict(next_state.as_batch_data().data), img_shape + ) + output_dict = flatten_dict({"output": output_data, "next_state": next_state_data}) + validate_tensor_dict( + output_dict, + os.path.join(DIR, "testdata/csfno_stepper_predict_regression.pt"), + atol=1e-4, + rtol=1e-4, + ) diff --git a/fme/ace/stepper/testdata/csfno_stepper_predict_regression.pt b/fme/ace/stepper/testdata/csfno_stepper_predict_regression.pt new file mode 100644 index 0000000000000000000000000000000000000000..c1ebe892522c3469893d243c29d072e510ec7e64 GIT binary patch literal 12162 zcmeHtd0dWL+xFc&(JU2`CY9u_G~U;FY9LV*l9a8PP*jS9REm^R8KOyvBqW(q*IF`^ z29zlYX%j-olp^}>zP+F4-Fv^^^Vq-d`~G3$L&bGxO%-rF;3)4X5FRR#=;ya8C@8=$oZr9IlFniN z0sL-H<@Z}MR8ntls8o39u(_eb|7fvasB|FTpg^I(-w2IDWdeUAlKnH0+#iYL69b3v z6v9Il=Z21OAI7(EWk6&&f3$tV1Nft$^e0T?(2;!1=X?W|f5#N~&3F}EzU9B$uKF9~ zsPNFyb3?~O7`ix$i66MCFya^fFbq|fe>vG)j>yj^+i=OMuwb8XuQdT-Lz8cBE5PwE z8sa$j?&0R4kshl5n!|*n_BS|wpHVNLpk+%}1_!JRx3^XOH%rTG@{j>j)`q?mHQ4GHj6MU>A!RR6UN<2l+tw_#{k zP2$?yj{CU*TDd)}eZ3HVSJcMp4c9Bx{XX;5`yN8oRyC~h7Q~AhQ!r9W5OrpcL+heF zaCzS)a_L1TNWAt2x3PglUY-Tq)d;sU+aT=29uS^>oh@$O%al4TC&9NB@m7Tl-c?(O zKT=C!d7C&YnM}mV4UYI?;c>DkN{sgCo@4r76p^INonR4N1h0d-$k#<0C>-6xtbZ{J z-lpl|+b?^d+GZZ_TrCGbA8jF=hU%(c-D)64DaI;1*#I+i^sphR4gB*iFh|O+lg|68 z%*t)q5Wn#{`{Rohj^U((c3mOywpdJcl*=K%U_I}7sU&oMf6qKuNCNe93wSz?lH|3u zD_jv&!lb#6AkO(LB-}d*`Ni_Qr+0sW=H@zH;JIVKifw^gLxuWM6DA1$=*o59+< z7Qmt_4!B3L7&!H_*-vBr@W!W)V7){d0(M@6gC?=eLyIK#Q==d)Kf<4*Wf??7E)zzJ zZ)RO%rLZ)o6Xr@7QtuxYwDRr}5?2&W%5=sfDQ_oRkH(X?GGnpZ;REPK?S#Af$B5#K z5qQ2-3U5q_gO?-cpvx6YdThlLHotFFRdsbJPcHp9$S8#X^VJ5=)V9I=Z51Hgc^b^L zoS6~S6sL9jJfd;UzI0Sny ztKk7J4(y zComrmK+JeYo`IDdVwVP7@tTOQ*Z_R1CWwnCE>g!CAP!p8ydkI@>E5SpR!rTS|J;Vt8o9}{b{Sk2W zPGr8H&4)4P-T|}dBseuMg8}VQW(W5sOzLg`Zd)DC`cgELCxsZPIu&}Jeq$UCmqPR2 zsZ35tC%dKBo0R=PMz~uSC7rGV6&nY22WLQi=L~q9ZVbH^78p^q88#e!$n=d3We4Ue z;_!vnAVX6gV?HZEv57h>^6ooKJ1T`;&L6qb>c|++YcM_u(=vGAM^1wEEMt7qzY;e3%tx`Hkri^9?z98%1Rn3#rz||~5XIbm1E8hQ$=LcEVQA22`26T7ymDOuW@Fkw zt!xkuuR8##w&$RSyA5JmA3;}-DXJxl!gayv(39~9dZPvKbl5_;k%{1bnS-0)GVJT! z1WqLj!DYA|*iKjhMbkxa)&euEeK-dNLlWp(i_^dzDyTUu7pCScMa7jaP&al58&a9UM0iibfzOF( zHE|A!NiKwoI~-8b*_*0!jd|*g_EozRvWSMg4#o?xaN358E-_rcdz z<}pM%`5|nDl=b8&9`Aux}9PNukakqaxD$!+lz5>#0VU5am*vlmrpyXzY; zKA=G&AB_NZvo&f=HGm531MHN@9#S9W1*M_3r0K3Eet0k(s`RIWwfG=R6*I&;hVPg~ zvCr5eIz1qpo(7pU3$f$q;i_k0eMGkW6r+4`8?(7B816B?kn~)I#5cYM7mdffW71-D z@U195Y;0lOHyNVrIAyecDhM4}wd}e`Z#-q3PLv9)$+g#&EXMg^w@w>0JN?R&&d`OL zq8>0aG9Uw~w}EqdiDii6NXA-Y4Kqd35Yry2vEnWAZ1;vAtkIn+IBa5$eoMZRx{*5U z=YnFwU0%VC^^MpV_swiiWCL??qacPXh=ma%BXH7=PjGzXY^bcZK$$fJ#+XIG znG^b{Z~AD0eqk)+DTfl#flQcF_#O^Er~uPX zw;(BeJ6zboK~Z)CSpTrV)A21#XnhyRn2rUfSRK3u*5Kr>2od!%m{b3Poj*SYL|0D& z|91z#Ltqw(5ml;eF5SwEGMbFSzB&-2`ktqx`V?XWe_>X6X8?D~Cs;?Tp&(WOJ2u2Y zzS>l5o>eXy6qY)4IGW? z4%)nRE5*@ZF$*ipD!@RLg`gxU%zt)|iNAY-0UJeZ2)+aUJ8L0T?LL%UsArdKivr!$ zEXHl*ekl4OLkenLpgT$zW1elc>MtGxrNdb`5VsZPq-6l7qmBu_Cx`~(&e-a368iY} zsM?YgkXKKJw6?8`;j^`H)youR^$xbW>_MPg5VMA@g%6kxHSxpIcuzNU zSIFX&QJQe1GXaIyB9_k(A>+f;K;*+G(j_SmMO|yjM(1;+Q>li%vag**7*)_v1Kdar@Z;y$0sR^l1>j>K7DB_ySe_qp;Q31Qt(Hg~Wadnm;Romp1!J z6>sh_(r8{sew>H_C#i7QaC#g#cpicHm{iEpS0JLVc7yUWPvkkZ5!w6IRaIrZBqn`9 z<=OeFc-nafGx6(5IPvWQjOdI2({Fc4jIyCsl(aB5=~`mmu<3B4xtxC|5l8Rnl~u2A z?gq2_GT8e{9wU8<5z$_{!%gkLq3YpM|L{2+KUd=>>uAx3Z;Zl_V6aA!uAdHl=v}aoy*yh4om#8 zOB6O}>0!oLB{cB$0`3}BHs3UkzKwFjH=gAn?AQwh2X6u4dBC8a2`m?Ff@X&q_<$tC z{P>+j|3M9C)C$l~@1#+Uo`yFsbg^XSXIQZ58&S-xVP@TWL}W)w!#RyYvRBiGRTog; zefgyyJ}S0D#^muZU1l*@YE=-z zPF(Kh_oIRlPVI-Fz}YBOEss{a;voNI7VBG41m0>Q*yV7ZHPak{F{QJsmb=G*K~*2f z*mKZ9lgpOKZvw7!0=wX18dPM?hB8w@Ht}i=C_e24t;0{4q&p>0w@wQpU{T!}YaV0q|maQCT$xC~MJvRMOst%Sfu zbScwtI~eBCG&ueJ6EiJDAGRJ>z{K=POkDG3&~kVTUCj!x?6?>%Q4|6R9t-k)g<$Ei z6SnZ?0mobyaz$3dmM}H^6xzZXT1>$8B31AxbTJ>SD2&E|5I+Jys9rFa$?Eq+2v!ASoa z40-C4#d-(0GW@_0zy3QMu6Qd{ zaLt)y>6i4YWIy~^A4q0)ijeB_n=4D4Rne`@kRxPRLfQBo+Mq!*Skc)v!i}GEpq` zb~*^7R~1rezazBb!*H}LzeTUgT%@rDy0B^gHxNIV3BzW+W78H%QOB$-xZ~9Ur8AW9 ziB>8q_g-Z_7G4A`|09rA;|KdU-lG<^4v?zqPiD2Nqx#7+?BuXiI4s&hIL)^$eQO$N z?@|e{e)|wEA2>rgB;9$VUsO{6##9j3u4O!LG?2F`bBMz0L9%761o7N=m0VK44DWAm zhg^RzymELdatjjJbvFGVem{~aY?=X4ZnMbNvu)5=b(?gY^Qht!UbV9G5XUhS+lYGg zPWFWK38K1YI+-}(2^fmSVEbn`8dWa`-Yb@}i}lCQ@6M^TZ_*UTZtq>I#QFInErr{NVkDyGKZ+zR|iy? zz`>X=c2??3;vpzVkKP#=PSxx(DEC@KGR2)?TcxbIM2Fntw>2i5H%(1jWxzu;W_)rk0way2wyEN>YOBoH> zI$Ck{42?{kk23E)adLq;maI`Co>5=O!H)HlZL}#nZ#7#dFJhlx&?KCmco1F`}GZrJ~wHa8i(!-hEFTwtf8*LYvfw^Is zu&7@KA2d!!i2`+W$c&^lC(>z;`3T(CTS6R#XQJNry>#`fRBF4*23w4Ei5bs^E?;U# z-&-RpzMYQSYeKQ5dKyV9-cdE`*);UsH4Bd@43M)8u~4kE2o^n(q~{L|Lx<}gbo$p6 z^7>;NNKP9@Yi9;ybG0Fge^FrW3?E5NH>?Dv_5+)crGah-@56fDHhN}f1MGfNLT~8= z1AYDx3P02nd;XcS|JwuN{!NqG_L)(^cw5Xq^%7j33d3TD4Rm_wDD>^vMpc>^oXe^Z z*|!t%#py1%Q54JX*BrWK>1dj!^b0e8|6`EZ=Y)wFs`%udHxadP!Qsk>s7>$*+I;OB z!!;Vm2y2Pbkinrbj-jukk3hopY_e{B1?)O_gH_(lKlgswLknN!kgnEww6#=;TDZ4B zTxwW+x ze`us3y)Su%n~qW;i$&k}4&Bbdqy0<3dt($$z&&6*eKavW_Kk24 z7+{*PLOrnZjV^JbFmv}AK#dx3Ztgg*U=Hvc%bex+hgx#43C6$w~MA8LY z!!?*2mnPD$PiNtm$QtPHd;p6~@3JYvHT2P{PCC8x1s#Y=fQ?}iSaE+J1A?E3-{5t4 zJ1~cy9_NB0!*gMm!f>h)Fojr4$uRlRbxh|o6O>KwXVfpvsv7m<6Lk{4!`!oL0G%1P zK_D~_(h{#S0|^Eo>1l~ktK_gfeJoM7E(PyRxgfZR(w$Aa;i6O+cwZA{A_e3aS1~iJ zNxBDGCCQA3<}9-5s2Uat#X&7|jjqYL4r?OJ$c={Wkj!;}mT6}|WbGhHSIHy6&%e=4 zR~o>s<^r)*O~nOGcK#?0B`0Xx$7qMok|9$Ykxx-JsK_9+4MZD%o+ zp0Vl0yF|HR^e;oo@aZN25cnsYi<_9{8$j0cED1@9c~NhkNWxdbRL=o z*^yA~L~Oh3Kp5XnX3?1-T=PpA-C4Mw*pITKv7hgfX;co+maN4qn`_wTUfMM2IR_s_ zT0>E@FeZk_gJHV{UYTx$I}T*fm0N6(`>49=yOK6$No)tv=K`1=%fH8Oz6mS5M?vH3 z7^-d31AZA3u@Ym!N;ru=w(o{GkL@6slS>7sGH7m_0X13`@Mhj&dOIS99$2A=-d@w_ zY@ewJHe0c|ArNl4l!8T$G8XBY(P!e7Fzwh|Vmtc?7^v~%8LAf;EJ{SF_33D)rb%Tj zldGzCyH+(wdSda+NTeCsxPA3Uw%Im@ml+_5;)$({Z_0QYbRrZ-=LMqcrZ{$&Qwg(I zI+xnHeJ4RBpICu8AK;DBX6l_bhR*OWrt^Al6V+Z}NO}-Wo-ZDT&quCAiNoiJajPh) zUwR5kw=2+-UuJ)a2!I<&j7iuYNF(}nLhG8MfXM$rfI4?7R{F?oSvKJO9RWP za#%h4dTSAlnioi0pDmt&} zE}DxN!(K%q+dLlG`(ha4*1>Ejn@dG|!-;&31nK&Ek>93Gh7gJDm~SlRcp*_H?9G)p2H1#iT{k&(xVbg&1y zE7)Ov+<2<6KL$-+S5uvJd|rcNIE?2-o9}Nzm3mb?Q0+*KqirEh&xNSJiDH-5Ceg6! zWuY&<-P)mRq8ky90^p zN`=`ylPR)1nnq_(rrKhNdx^Tf}j zc)TKgS~3>z$!FjVNmp7r>k)`|rO}q1B_JLWjdg2n>5S$k=scdp+-rLTU)Kyn4zNuhFQO5y*SC^lZ4v4wbev9LYUo;_ee|2_b#m0Wi}cyXF?GilLgEDn zs?oCo&2rY0k&BXO^-DL5-(w51np|{0BtyX|o-95$jRx}b5#JhLfdr?8esnqwfrnjL zUla$Oswf=)A)e;o;0ja0eO3+6Hg^#DH(#tv2ts@v8^lT^EYB{t#cs2>t z@}*r*l)+`w;J*r8#oJY)xFHJOm&rQb4 zL6W3p>9Xd$>-5e-MC;ESs%X{B?w3-oLgGA(XH#j{v}yFli4tm`6pL+JRB%uv ziLNtkCtI==(e+s(eOZ!9)|w<>{8Vw2OGv;Urkb>`zlMH)yo<(DE~Ns>*c8h5M`l{m zBNN1MhEONDGkGJrwENR$y*XGAB!*g(H0Z|j--%hT4C*hg21Snz)cMzV(n$B9Tyq*- zd*dU_6*y1lXBShKQQeT*X@kQ|jnLC_3$30RjcV_dNcPB0pr5;hdMcPeK;>qtY%Pxm z+gq8B&A(v$>WOr#ZYp~{{>eVnip$l`4epCv4C{8g9OhHkyf{h+i_u@^f4QYcuIfvy1exNftfp5l4?* zA4A1YsnAGgYs{ChV6t`tWiEqmgr#uJKs~o9*^f^I|c)@U0E<9wySbH|n(G zfFx8NHNvMyo|CeXyQ%RAL%LJq3K6*Zo}4``hZnNbU~)_z>r!Zh;fpEFH9G|Z^U~Q@ z(i8D$odW*q)e6g&8o^+71sRMvNN&rdp#No8>iOa!7|F$>@`n-B`qm4pA^$yW>goso zKmB*({|Ep5KY7pp3>4L=|6G)|HS$Kr^NYVcND^2zDjhc{_8lu zE=T_7asGcw9;N>h=r@rZx{=(>TRQ*e`a^S$%Fme$|F`&WbwJoMz5+7JXGn$Wwank% zcEn$Ze@`PAam`IkjE%TPmS*P0W~OEqrd+O(1((Y;H8GIWyL)FDmV94NqDu4ed6V)5~;!ovo>wy0X`$uIc_fL>-_+rtY!u~~Z`d9Qn%05m0 zgg%Ggxu2u|H?8Sk3I3?*)c+H~t-nL?Pc7+RasQZs8viF;H@;H&r-uKg5B)3XAFsuK z2IVOHT~NN9_s{F%k9+9Ppw$|G7u5K#pnu#hLyPghN*8~Wf7<=uH{)MX|G2!({|l1; zY3_u-f%<>mcrK2jq6WW%sR>Q|C#=NppZ@XG9a5O{OCO(VsPNz8|JUJg)F%!7{)aSL RWJny$k>LLfwg0W{{{q?qz&HQ^ literal 0 HcmV?d00001 diff --git a/fme/ace/stepper/testdata/csfno_stepper_train_on_batch_regression.pt b/fme/ace/stepper/testdata/csfno_stepper_train_on_batch_regression.pt new file mode 100644 index 0000000000000000000000000000000000000000..8e50478ab933e46db7e4c285281fe16d5d532f4f GIT binary patch literal 64819 zcmcG#3pABo_y6sDs!$|JC=!xT4sp%7HcC2|N5_kDi%{r;Zk|ND*czSr1m>@oJZ_V+W_an5V5b}V$8@xQvFq&5mAjNG+)@qb{vEzDN>hXx1wg_s4c4hi`O+S^ z*e}4$J1jIXC`73j?Hd*tw8T5qe^topVDE5CC83m&k-PjqR8|KsUgj^9{_`(N&Oxge zhtBE!d+(1*LK%0--h0f2vhLEo-+pVtgmP0QLxu9Qr%D(b$NTxM4hr)33+=5^h|EQI|6{((^b6A|!-l1ivh z)m7NvUG2ZGwQ#^+BrJt$e@dwT7m0zPLJe2pAou?Nl33=y%KIN3;Vabqi@;dn-~hSl z#j>G7tzHGAsnSy=dyC@K0inVnM$=zP_I{=wo;u)f?mN{;59v@9O%0Q#bfieMG3x&{a6nUHxCShb|6Y<{$c}`HO`{e*sttjsFQS`8#0L z-vFlZ|1f`asL;$+IOg90e1+zJ0ayzy{t2-BJ7DbJ0OS4yScM9$U4=IP4&W!W{R_ZG zIR2jiyT1eM{|4ax1VE?|U4`UdW6kovofm??U~GjBe-0GKUP|w9{zqFlg$gIQ3Mc+& zB@3bRUzElRC;h49@?VrDhYF{-3a9>S)c(~b;k5rpLaMi8EzE_}|J0cAUo>Wh3TL?r zXa8$pS(^QqVe9%=1qU6Pzc<^wzcriNpCfC2sLjRo+tQ`A>PT zzsoQDoBX0b<-J3NKCZ&W|LS&2v%d~p-@g*Z3jO{`Sn_v*|KA8p|Lm=0p~3)HVc@?b z_zIW*l`u}Y;-7?|zY|vejj-xZ!s<}r8dst4-x2(T!G9%K2}Ax#2>m-D>~DndKM4_` z!nLl#bzx?+rphZsz8F6Hk*tKoA7$3R^))sU;reP~fsvEB#l6Iyq-8gOBHHO?+rvG{wV*6uJ^yATR1yv{|CB<$M$fy zfAFiKUH=G#{=ep1{-5~o)gS(YUw>x)>;E74WB*Tly*(jp^~(+rTN;pCHO}l;raH48 z9!=!e9>l__FHzQ1h0MEg0d`frCNGsw;^qN6+3%N?q5&g^u`CH^W)NJ08O6s47wk`V zx>aH9!lP`;rfK-lViAkB_at{NClCYG>u5P#l3%Rz3RidBB)ruPk{F>Vw!L(ojTslt zPrvTZX5We@ZV&9}mwHq2RCy_4x;BjY4%ov73bqk7mk?sV_5$5}){#8?9mj6AS2OP= z$3?#1HP~=pds5aIM@nu4lA;n<=B`t~K06P4`#sl{S@h#CW_6FtMFhewzj+ zABNG7#U7+tF`T{gwIqQ$bLkfcMV9b13B%OY+3}@ySYz7Fh4hdN)}zS`@3e*_iVbg;rl3NH>P zxVDcJ8qOGog1o&@x&IP>sVyB8-}r!=ZU8@EfCzB+HMo=B2*IEBg3P2kQGWeCu5j9N ze&ro?yj`Y(_jDFucXA;tZ&W~yv7>QZ?Nn@AaEkYel&2j=rQDCUJU$_P7uasfgEv9# z{7@df^v|0kc7Buh@wc?7WZdDMa5oH&eXo_iu##o!s z2ult!?s!oh-}*3_TNR%HF&pbd-QVnRs6;A^sLADhY!^{O&0@&TS#STcPzhRpf8<`O zC4gS(e0#&GO8gtaT(~TyfeEgUA$sO{h`oOXvhxSnKfkvS>g#Ll14>VVNInj357Xc+ z6%|2s*=hUtYIFFJ({wO<_X7L3YxCfa&wbuBD-EIsIMSS#ulQ~US$_EK#rzDdD3A?Q zK%=@5RB5pR3cgPkInT`jugfmDS3MsjuFepB(e=X{Up|9isWSNQx(G+cZs8u;CWyXV zlcL4Pd;2IbjhF484%!M^MRT|G!NSZ|a8)#;KHavo{N7SNI&TwSWH=J};%0tZK@9(1 zMHkPDCsUj{O*=fK)v8g!5%&i=_KkT}}_FTMuAm=YlfZpGWrDB-a9$1MJ`t0sOi zJTJPTGk~Ud)q-=`QP@|hgNMB(aDS*4ykFi7m!>OXwy6vl8t76Bu??>Hevfawxd4Md zeTMYPj}T&91KVW{L{Dyb0{8haM2(zkZ)(>cv3)pP_8yI|MgI6+M+&`E2VrqX0DpFU z3<#okKvdw+iWc>OqPD;xFx*TAJR1oJ`c#%bICBiz1#h5LQ5^-&uQ;QCqY&i%g*#t* z4h$=QgW>A2_&#+HwT-wAP z?1Q*gYXWpU|HZi+D}?%e6S&MZt)jS2AHJv?IhnUcs5Gq(sQfUfIWifpwoZm8sg}@b zYl~rdTVX@NBksGdP}JqBj{O!~hcx{G81+>H^2h3lWIz0di3NSIedZ^a>8XfkuH1tO zBXYQJVTSm?CL8X4d&?P+17I*dkz3@}1lg{o6)kBa?Q1R1z{E6r@Z-;bV7etf`>_f( zE}n<-L4zCBa0T3_f9J;CIm8)U*#I6Jl3>xHI+)q&1U|9$_D`RS;7N-dX6^3+1GQGp zaft;AgTBJo#|7|u?h3FT+6X#DJ#cK@VMul?g%0z0hXJCs7XSq$WW}+GFV4 zB#GyO7r>2l1ouh_+zgelzjHH8JFx&}^>YHpQ7a&Ck}OW2Z;jQDW}{^Pb(* z1j9X*$qZL>@ui{j*{W;FUFF3VtX%)6zIj5T#-6K=H9w6=>FJ=~8h zpHs}5gYD?Tl>^u$^F{1#a|ut$9#VAk3{7l}C2M?@SmXFoCXP)dHJkd1zAre(w8!KV zht$o)`EWGbrc}l5_v~X~UlVAhgCf$Ti7W$-l8X&Zbm8(krv36WdF=fTpDt7ppNwl_ zJ{2nB)k|EM)?!A7>EGdH_EoT)=oXQKlqu0{)n{$gh>Y}{j)%g1>FG&Z*+{7=d{gNT ztUkJ(Tyw~!JN&cg{13B9M!A?2pQuB(hV^Wy<}E7R)l7w%56QQ|C+MVmy5iZn1ge(R zvg?VH@!YSelv$ZEFUQ@~kAKXm@4Lu8J=iAVrs@-|F9vW?JdhMjmEy{rq*-fyU(&d8 z4VxDsrZY<))9w>MM%%8%qW<;trJ@~)8GDlNY97R9#aWQO^V8Yf4huHkScxf*ct*wr zcF{Y{PBc0|mpxUUOopXdl9tsK?EKqh%+rNuixOtANGoX;eCYyH=!qo*w<@sti6t1T z=8ppp_9dBR=g9c-Xu2Tv1`+SDX56+7#6=uMwO%TTS6!B53(^y*{`Cla{w)FftnLx9 zd)92tXEoX-r7QN1Yb1|#exUj%Il<7Yk#t?EBUAo3leU#vux#%nvhTA2>+zbxhD0XN zftx3@kJ7rVWJn@0X|TcYu1k>gDFKWG3N&EnVASuM1rsutq57&>P@}t3w5B|b3-g(X zU0=7O-RRkTRAMe%-06Z!GkvI@xuw0{HRp;wu?P6!&W0ExDS~qj<~S>AsK`gj92c`1 z{zFF+zwcKpNTfJjuymTlD;*z<(~Q$NjSsWIZG|)@O+5zBjvMiUU`1~BcuUw&I0r)a z_D8Kj+VJ&mHY853g}#Gxxv9UK`H1@i(CGFH7?N^@SDzw=13otRrs@ITaz}s1&sNm)cNCKNpz=a!SAWNy zzc`pS&wU4$hllfP9}fi4Rsjy5U{tmJkk* z_-i?!Ew-7&12Cc(0VbbU)pz3_q{(SgHSb2kotC2#eIU0!J*8Rb&!vxFk&WAOFE^u`k zNg!y*D);&N3dVLFgciF*==(YW5>1BS_1IF#{*=OPE~|xM16KkH`s1y|`e2f~1rBNo zdAY81n4S9(jx?2l)tB3l5V`}{P6?C~Z2&>HEuM>M;DlG(LB&cJrfo6A>mYz>?&=VB zRRuGzwu$D=ivqcjF|g#rVepWg&PT~CKuc7rEqNEmbt9{bI zoP2?GR17&=B(Y^fG_1Wek1Jm-3sGG{@ZISKE_+VFomIPg?|BPxZfWo~E)HZePQd!@ zHLyNW1$D2?2YbTlh+CZ~JnjYP{-PpcVWq{ zYDm_32u19wXlZ;T7$qOz+y)(jylxdfr+OB=jWoij7hCOq_r4z@wP%$<0CM%{SzoQ*TOGTXOMg= zg$MdZz$Z+Fs+fLgx%Vx+EmOrn`)D>q^(o2GiiCcxGtgsS2yfrNl&zx$)aZgeTM>VS zz1qK$Ex)TuF89%7>sCFW_cs=BYZYe_Ri!YdWY@``(CT9PI(ZnU{)s4!UdHa1>QPbU zZo2MU2`OK;7>aiuCSx^?n7cTEEj>JtL>yiMW;$Q+&{8{k<>4+&zIK4Tlo~3KUtEv- zzO|!z&}F*Bw;t<-eZ||B4rVE_%53byRJLWw7S`!5MK<~Ru{ks5u=O7`F)y_b87Py& zTFm9?ov=HobnHC4pO?hMMK{>v4t)|SwT!$v{(Z?}OooY*w~tDmDx$ z#Sp7_BA;bOxQV%>Y{XF9c&LjGja)7I=7RKO@NLq*Z3~`0H6HiXhp~GDtNHz|DWV+f zIc%$$HW@q5g|E@AW~Ie}T+c;yR${q>rR+>5HnV!y{@`fVT)K%_o)5;R1JY#Tse`Bz z;YFT|w8HPx%b4NtX8OJSBsr=p!(MOR%beeth-L0yA#tOPF|$*Fg};&&SXs89OOgUx za$qOvyt0-ZmFY{`b7k0t_50XTQ!lb-(O`Ba{tdaMS&Kh2)>7r_C-g(Oidbn$CODqA zB+Dxp363LVk;f%E^UHcX5;c-YsjniR1FT3|ekOB%mW->8Tx6AzjrfR#)5`LCSbuaY zxxeBPn(1p%;meEc`sIgoB6u^YdvZ+uvk_@)6l21%H*~`nNs*n+b?p0IO+5WY3beLb zlX>CosO)dUem`}hiCGKixRyi2FZ&%W^i6`}7sj(dl|oEh*h$mhK4jXYnmNqaP6p@I z(vJ?dlnl{_{60}5%w1FnAG%M0f%y+nmvJrEHfbV+u3m`JvEQI#$q;;CITjX;(Sq$i6lwPK zF#D7l&noO)Px9AnYWVKcQ82AfC~P=43|u^qLrhdM957Mi|%Kdq5heJHM| zDC*>+Qso0s0qch<*!g+@9$3^LB@EjxWbeBmnsy@!9xN{wEgPJ{7aj3|Fh4OA>ksD5 z=LNyLev`p$gED_;d?m~;Qo^I%b-dlr1^nlOe(-hkQRrIK!Pz>fp~ihrm|T2;AE~g7 z@4HME&CC6`DTmCc+10f)ZTtdER2z!vb-v)eCJMV`d*IN>AAD{fX;eM7nJaUA2OTH; zxT4nOq7yCSv3s{1Y%nm!G+hle_4Nkxa4k``RWyAc>4xt-i$P{;C*&Nt4ZOVv^catY z<#IQn-eod2@rf`mW*2YLR0YGUCFz$B%BVxn!MipiJTc`f%-{TrS5L3vrr&??QHO z-X-XE4ZIBOgZI6&;dJHO-g=aAGRZ$6C}0Nm5f4DS-O-SJ=77kzEDwBiWU<}lg2-Bb zAVw9=s95eE1*R3>LB&}DUG&XGCkAW=^O>=t`4>~5EPVzPSxJeuU#SB1=bd11>=~DE z_XO0eGXRP1nH7DfdF!H!@<{18fM8Y<7)4$ggGOBa%#1~Qe}_VfLuG{yI}8(BZdxhDAFEmi;Z6wpMqiJHT~D47iMOLyvXUz^tFc+9{vm%gbR< zIeHb$`tXqR8S)#(@!k;pWDqWqz7H)Ai|t1Q?S`mt%4q9!06ex_fsLW=AagK5bjs;A zRFs>8z1u1Hm7fpVilLyrA_~5bJ;1%E>tUY(!O&elME%WvgU|XU6$blNF*&H$zMH)< z+o&PXJ-t6(a?b_HEK{hzeaT+$il_ZFb46S|Z@=?Ztv39wHQU)b|3f6~{Y(7as79>UeB;Y4H&J;rT{2fOnMqrB z*uOrhiO=J=)BY`8EcKlXU9mz|Tyb0(9|=;KgJlN$Rm@n?tz+E6zR}Fn@*`Uv_#B_z z>7iG2b=a_uVD{W1jg-9Tg6HM~Sm0bg7HY7II(@i|w?1l#$9*nlW97@(uW`RIr-Y-j zHx-%RiedbGZaSIkr$n+&ts+|!OR+jYME8`5S)gDklm9u0?bj@&`|r7sMd8w9z4lGxaW_e3~WgTz=2 z7Pkx+5iwWSML+F3?ay@oPt{T(Fbi=tqn=M{EhO*Q?M+sYy@ zIFM@JrNq)CjG8SCA#)ZR;^FihmOp1FxjpS9+hV~)Yt^%eiot0XYvjphm$u`+@YTe_ z_BL@{u#HN4UBr7o9MEneWf8>@q%HqDxY%k6UdxA&l-k?ezDE&sO{zQ19=?TpFzBo( z^Y%16v-l?OA~ObuEY9QghT5P>-%z}DE*4c~^kH>uBHiv84Q>+3AbuQ#`QtCr^3{&G zq&LqlI@u_yo*ITjjD2D5=`?CJ;srHYD@W}F5xjzdCZ-4Prb|~y(;Zh$>5kurB#*}U)WlHVx|_lHJb5>oe!w+jXoZgt)?B9=W(YTuJB43 z{rOPiMv-q;F;%@knp*v6fX%n(LVJdSsNVYr49t{dLw~eFhI%<)*1KLy%-nBR8?HjR zND=j!b_BFn=Tc?A$x!-xW-|RghHy^%?(tJ= z!XN|g!oHeckh$`SD29dN((mJ;O~VN_2efd7DT`pr_7GTkC7+iYq|cw8)IPEQK41)Hu-Ld6_CbV*-Jt4^oV4x53vzw-n?Rb~nr@7PB}UMEw>)ehKT zX~bLGJJ99JoajdZqWb$uxT8vl4dRJ>O8(A@AulGP@9yb%T&;^gU%LhJHN3#zcemdlYPDe%aMhngu?L2u+mVN`-aej| z?5c%5?@rL$hAV--{0zCDuJX>kG0dS~O}zUred_q#no7kuV#e85FzdMtEOOaECkcn3 zZ%aHKe3L_0(O_Qn{b+1E*A6%Gw)EC(HjP`RO;a=$a`O&70hRsJaC@2-KD+P3%h}Gt zews(A!^+dN{`xP@++r9fV<1P@^!(ArQS{BG;}Bby!LM6i2D^{k5NU4hjk6Z+rMa&& z`Su5M=z~HHYU|zr(OY6^Q~NRO(bb0bxBcmXyeViXuZZ(X()b5z1~{XS4xKu;1uivD zqsD>7{Mu`ooY?$;-KKF8lsrns-8)V3TQ zN^rB`KJ?CbPAlK*kyk1OWd7_pBp!v#y5I^M_O*{d6hDo=er`tgY;5A|AKKCTTPgSR z<50o6-61$RqK3v6-R8)HxpbFt2@CnLl|)7NXXyq}Y%14zkUVc8CGCiB9yck2?zTQD1+xiJA6gRMw zZ{L&0KcdN~^WS=N_Q#oCQ#i3)lf{;oWU=4Sm)JBKVMbCs(>qzmFH`C#9zAjwOY#Q1yI3JHSkF(f1&BwUIrj9w)|Dd|N7LaeI`-p+ql58pK#?y{s zc3!dykFL7ReAaAWP7c)humc4i$Jp-1Ix@^b#7^9M!}Ply z@qN>8oEs`lEvV}i1>EQHBlu*;RL9CS%{C10?UsJV_k7#CYWnQJ#x+MM;$h1ar7fkpYDM-(XUDO zgmdKU)e|IX)j@V!ZxI{X5>1Gu9)6Vg!e6gikN| zz3mZ2-8QXw`Xrqnx89%9l_yX=!;88^^#_TxesHv68()40`74#@SWeRBtxo>p z%@3Pmiqv~<+LCgr{3?ll`sv6=4dgg&Qp|Ne9EB2HN#*XljKF2dFF00J4C+d6MfcN9 zAxGmfwLf!@AJ_d_G-cEns_w0eiYv2t??px!v-r2j$YB-t`OpVC%*_gd@6Le}#2kHU9;)AVY5)SUl#0E>qm$C6JF3qh0ESl!?nH`i>j$VIK4~LD~5D`q0?mUa`&BT z!Eo{&kQ5$-l*_>8_i5;9{R(@VPF-t(EN0 z&6T&ts)YMsa3YcO(4Wq4F3`a|>1e3tuG8?$ItUN5=5N&QfJAc_XqZ?6vJpLe>fnR? z%9p?B=F7FeZnALzDV0 z*bNI}o`dS`!*JR(7j*5Cc#pU*ARZ=1gB)`}DQ^J3Qa_8Y*ww=iJktl|`Z+@C;~%~G z3462-a^i&}wqxTx7oPKN<-AIQFnnPV-IaTYcOK$Iw|u?FPo(|v{D}y>ytPX7(t89= zcqxI8*9su7UIw>^#(-J#aJ)Rp0(TxxqpRW^(EPEu;h$g>XfY1+wqu|5 zsc5aEPgTb!R*3h^t*BM<#QZ61k*1Bn9U-4Z^^Q^Y>HbQnu>Aq&n>3OJofe|@!2q1Q zIa;)P+6iu-au#)R`^^WP_#%>={R!S_Y^6RaL+Rut`E*X_9bT(b1`?V!@h=zk#g~It zq2jSp-tvJQe|6beDBPh&&wO)$<^Br%)|^uyyf7L1w~KiVx2^QC?^(KUEzjMYFaY!B zDf1E?w``UL6jROMtD-mC@@V9o0Q%s?0y^#aM3fD`2a7Xv>D)JZSg4~&Kij%HtZh7H&h4E0yaE{@NLd+ z#bo$Z*zF)CtkuQmO6hX!)h?&0r5N3FN7{BtS44csYm~9$w zN!r)Nv)Mo8+2FXNyuRrtdeHO*tZ11_{DZs6*^~lu+5avH+0lWsiu{QAg%3E`kHd`m zB=%|UR2J^LkWF2Fn4R{V%|?~)r!PP7FGq zlLsD2 zG0%lo?9A3|YVLiB7(QCY3$_@t->IAE6o~<>|N2>^T)dw|(4AzHfi4+3c>sB|=^AU$ zKhJ`zO6c{BslDreESqDYCb0abMM{-#Vx_`vQn#oEM@BKWTe_9@Oc2tu=8xG2PLfK@ z9Lr{$G3DoH2NH=vr`VLYli~cRZA4X78UswGbLpScapCN5EMe?Q!LRKPiFZFq7CvYv zc@_QzN387E05wUjfj`D4ke*?ZUlx7|ckp+ho=5g%)wicUVE%yyiDz^wNW z#^kWxJ;zE-(mYE|kZbXiy!uwh^;{hyuJ4+_bm)Gf`_P+AIrD?)y^9nrt4^T7;^n;h zDJeE-oC0ROEaG!Qlb}s56>g$%<;O|JtCODw0t~TySwm_?a45s zV;n`1Jx!sLDVKAMBN}q(<=4#F{nDbQ?g|Kl6L~NEz-5mJ0@JTrz0WuO@$2ZXeEvvv z`uv10-XD;LH^cDoYc^p`gd?3?e-m0yC2;o}AH&b^zUU!60ACxA z0kOnd|vMUI|qE$_-=oI)bXcIl3QJp)f6mUsO7g2K1hf`gQF!C`vTY&(qF9z_Ga^ zUsM3Yibx#!DTZcu?!bMC+vuhJ4*cy?s<1R;FkA?DLtVbbaM|%EIP*tJwD5xo&h$>; z-P4n~?I+gJ(9CG6ck~cg?>j(cf{P*ga|5hVxo%&P8O>KTZUE!KS3xO#GR+tO@K$3l z7qjUuHz7om-WgEF_ca=c`^I0Rn%BJ1Il-FmkxioT`5e`FHJiG%ti+~{9sCHLOj^8h z1|MMHOP7~s?M-^Rj@{&mE zSTo#x!eQqZH5w7Ti>}`qg^dH&(Aj49`G!}oLBe`Iy*m2<8cbKg!UBEXT;ege;@eG{ zz4sPp=c2@?9552u*w@j!3lIfgC8)Yxz35OMEi6#&UAsq(qjq6od`r8zVS@5fI$t9nPDroA!E&B7KuUKV*{P} zG={%M_hSG06dG~kGq_4#pz|{F>8v4dA-UB7`&wC`=lD1(p0WvbK4|b6gEoUn)>7)J zHWvKLw^B{P06fzCfcsp(5Mx3{(``musDZyaZknn{x0{dPa{WDf^W+(H^TTMAc76)w zb0Vq8coN+w-vC!{_QUV*0^m#9NIbZ4ERFN;f$du(aq!;n{JHpOYVp0AYL%qGEVGkb z{uW;@{q6&P$*Gx8w0E$`+(Q}*-^lZZnRBU{@DMKwT0lc>*3oO$@+j+<2&%nj?j$49 z=_~7t^vT!*^t?wjJy|!DDx4im*Ul7RwxTU(Fe?GP*WaLzD$BUxndfQv;?=O#=@7Na zmq+%i5e`1uPNU!H(U!wXP+nkx&yT<4iw5nXmIKY`F2&2d{5zx7q!XKkt7oiT?kPKezf%`SbUxdhk+qJz4$gGHQKsBgeL!ClB>^v(v{Wp^N(m zYHM-Leu2#!Y7~E;4V-h5#c8F|g6K+;w`wTwoh?N_ok?d}K0@+Xaw=PYO_mvl6|sn@ zW_Gi~3yu3eL$!Tl$ptlMyw>v=bKx~T5O^LJ?7l_~tcS5Bw++Z1S9f;#Xd;%)K1|%Z zW$|^=0eWbfAw9lrKYbs5kCNrbsIo-_8A*h!O|^{ts`tU)gLkkERhOw;S1{eRQIkj% zpW?R-pTbuBoJgu{rjoZ8oS4nk+vNQWOVT`h19^QXirm^cmRl1!PEh&uQ?Kuz04u z~MmpB$uFSdd#*`OfE=>}pWT1iGkoabGe z_R%HPk+k*1CZf3y$hg?IxOvD%(Y5+h{HMtVv~o@Z>zO{Dh4r4<;oD|_`Sb~7mqjug zIWUY3R@p*Cu{E?oc$;4km_+LXwQ-bS9_yGGO2%DEA@iyY5fa>3>gCUr_SL3$r(WX2 z^j9*SS9=Bw5(P-!Bc_d^n7P8H})fo+^Wg4%qn&E((Dxvf8Wg~62WF5lvsT+?i+!Kk5Dn!cWayE- z#Cr2hTJz*9F)ZFYb0akG?Q%H3kNRc0Z3)6kC`3B1VGWbWnTA6G-^3Kz~G zZXX$<89-83GS(to#s;2!M3?p+DVAN(jUwNJXwz{LgU0ETYrk3%vIauH^7}80nb(3xzfWX#@)OuPoqD)!UrBm;-m!(Ff{E6x zP{w^ah5>gg*}ST5Dzk3|X?-Xms7bs;K1_Iu@`6Zqcx5V^c7G_lu_zA{gZ-G+L@B`) z$t)84BMGnT`Gfkb(Il~bB>S%N2oJVT#O~ii$RJLRd2K$4tEQGyub_J3CchlRvIh{; z=e_x&CTVu;Tp976WF@}4U&IefQX{{|eqaYz9A=ldWZ?Ag`s~bscI-cCB(u_`DMC{~cWLT!rG zXiZWCW-PGhKXr%Tkdw1W?uq@>pt_Q*R2)Uh4mvQ0sjp}VsbdaY6>;dJO%Js>vAPYV z@Y5w7Si@yDZC);us?DVnt?n_~?Uz~ds+H`eu{Q~OAWd7yQnqmEX7=cens}H^4m;Dc zo=L^%RaEPZBXedABW{PLlkjpx6|GM+fzPL|L!Q!@%9*6LR7pI(*_SyjF2(V8&XTz~ z$*k?34mopfB*}ZPK(?jWlKoZ$ zliawy1syYN$q&C{*s(O846!+ew!Qi1&e3uL)#d_xAur1&e{Z2H->b6U7Y?D*XMM6n zMoRGER}&fb?l@^&)`xrtp7@9!(kI53=m|3wVzlKUdt^8ZXKcHJDrX(A```tp-;mC> zPuap$UvHv;`9E;}SxbIk@47s_paIUeM$@wiQP^+fT5{E;x2_Ux*qMBljf8jPc;;bp zg)Jw8?5o&~M;at=Xd4)>+)Z{#{h$wi>*L}#C2abWT&ntPAG=j0fe)?Hu*cm-Y!M`2 z%CMIOIoOkD?g_|i8j=9_7&iRPccx) zZ7ufnw+^}LQNlV#EMQ+39VI)v&#{1*Vyb|0;&H1EV#xgIAdp*35>B?meBE~biDdSGy0jyXkWSMD4@zyz2qkSO}!% zm<&l}dSZ{PPpUYLXEKrfuy#zkZFChxDT{{m!tL z(+7x4b_&__5^q$Gzd>GxZDZL>ZbQ#0FKSbD8tZaoV0Y_ytnG~{Cdt%NW}8WuS*Ef< z5ptwGqLyVmn8Yp0&cOLUAF+*3lS!-c54^Bq1TnL5z-tR^*tCOBSoxY(T)6QrdfNFi zm+ThmXZ?z`7ZlNpnOoVz8WXndh^8R7VH{ijrU{0W?4@e1M`-2Bp{yocfvk|bN?X>n zpm^9uGU}!Q)7l=*{PHFEYpWNN$2(+LV9q3vdpAI^+`ob;HuIeLpK$+`FRNG8Qn~8Eggs*84B!CK^jxcisi9FL)^XU z4pZuX7bQ;wqV^qoaL^kru6wB=9^d1|1kU^MeB^|R>id%c8&A^vDGNzMTpX*(Tg^VZ z9)XI8Fsyl)NBq~D5S?eyc&M)(9kxG_(8N1fKQ)So#*QWfLw&&I^d}UPDXcwtZWNrhF4V7deJ3B>Vy$W#8gk<9UOGR)jz5v%345K42Y(>wU zvw%5CWZ@!f__EBNR9v4x((L9Dx86ObhBsC8(m+GOAxC*uIZc9u)ZItR{>tKy&Msu< z6KOW7`4^4u)FFu;c^GZt%Yv6!vVE|ZWEf(9jyzo7b}<+OOMLWB5HT z(l;D;JW6L-$K#k)^-NOp*oU1-e@Zs{-=s!4{lpQkW{?R9k66|kU-rE$o~YDs02%c( zd>nn8-BU3XFEBoX7v>O*d#u53e{x{;vpB)TuFd50oeyZ}@5aEcf*PK>Nlcq81nX2^ zll#B2$?RQ`c)Vdb2{b^`CSGL^78#jC3r?EVR?-FfT zaHO_i&-$b6@rCpBiFQ2eALvZ_4|V1~jy0k2sX1g|IvM5^3>{t)XN<{#rVJK^a7NO=ATns_feO37jawknJv! zVWD-a*qa{%#hVr;vH9os(wDVSIH+|phU|zXv8FcIbg=@@j95$VPA(w%R(UXR&{F34 zOOaer2uCKh2fwNhCbyb;_wS1;*|V9hB7Tf5X-;d0nO|kdxh3jk{`*-d)h7?D`={Y~ z=Yhm<(=J{RHW~-H&Sg9Dq{Khv8psWu0?hLAC%!RptVG6$+9ssZXL7fQSF0~gIlYEO zg$1$2k9jbbImRYCX4qes{6wZ_bTXeC@AwU;KVs3DRQC^v6Ad#t5HO5LJZE{r0p{6>(lxkhx?jaa4{7){rGam1ldMlr?qDSVG1 z$F}O?ei_)|lb z57>0&F7bfgxG&`;H5n31Lk??GU7fddMR6E&+4+?9k7y-}GGDOT^NpDOXb-!$S)b0& zwqvp5I_T1a6?n+~JU!jjLyp{%n4vp)7z+T zsuXgfGbHU*JZmtoV?})p@%M^$x^RXTvDH0+A17a6?zz?UoYWDvYJE8Q9OFs$omMBl zPaaaJJI?;E_Ra(vtFP_sQb7{QghtUC&y-_rK1u4%Tw}?!C^v_c{B%_H}(f zyHjI8c>$kB(-JJ-qz4}kq*IxeU=kOik2m%s_Rn1opTZ|XQkQ_{c0K{)A)8o=5RgStva37R6KG2ihmsLD7qERDb+xvZNpt&P|gO zoS!WUMT1|U*w{jHO1OiO(NYC@pU1Sha}+o(O9o|=YkZq8T0GtAL@X~zC!=Txypf+k zH;tYTis>2jqKPoLoEKvdHF3k!^`O5ho{9$R;FyK&=<@a*|Jmy^w2x>eX$y#e7=wYl zGPT1*d!9NP-tmL0eJb(xgH&jh+6*P5tl|ELkwo)U8Ah*_qb~|Z!4lmFa%XHkiCa8d zaPyl#{pPd{P5T0G=pHG2hBt`T$YeB`{)wtQEWjpNF<#fFZ4l-)2k=@w&8%LBXZ=6Y z1(G|!Dd#4h@M|RDXZ!M;R1RVN$6h#DA}v@MB#UqROAw7Pf-8+8VD8{Xs?L*v6LX~n ztDhS~ywfd|E_4PDw^+XNEipX&_!f0ZZY6u3EG3rKwnSg(2wj(Q0JAo3g_9AF$@(1% z#BJ*^d~!4h%>&2MnmwyXXGRM(ceqW=C%(i{3Rw`N+d%>Xec|+y2B@xY}2_C7Q zrgLjfLS3JQSZl!M^<7S*b6VR_+`pa-k^}7cuE5irV+hX!6T#_KA-WzuMO;ve_Eai-IDC|rPF{nG>^*DBVFidhl!d%og=D?yINH%&1IyQJ;;(cLft6FnWB8d} z%(WUj6)< zR(7glYR`MBApQ!pc$4wM{uDH9i6UONCc~#CV|fK1MEIjK&0sE@&#d}N9PU?znoXMU zg86Jb4sC|Uk!ckN;K10oBv8cxjH9CALR~Ut#(d__>y*YT*Ove&6~igDDhOD8nyr^y z3YB+cK`&$hPTD$_gc_Qm;hN8IM%|V;JU9djT8XsRe<7P`DvG;qnxb6xUQmY|5D^~D z@2J^?$*Bf#Q0WP%YMTHoUQT*vR*~KL`|(y&C|~_vGv7{nBkpiH3hzGYW2V=3BHbTNI` z8@w*kTl(qPG~BuwifqbR@)5t z)>J_X)FQEbcqe~e(=@7)d4N#})`W80+7Hd}23$;HL!h@yhYtLpH}KEf>C5QS0f+Of}>G#Iv< zz283xfjXVxAh$0cS39i27>iu8U1=$J(z|rSLJ4vnrJ>_SD9Kw;K8KH>=sx=tewGcI4-N(tF*MjJKDMs2pH%ez8>~}?M02J#1L1zdU||a zDv>&!MH=oGV6BPeA7!b@PZ^9vlL+MJqr1+mSMY)BG}$e23VFvLYxio zmT5EudaNaHo?j<(mz<|v`Rn=i=M2g50X!TN{TXBHN=Za~GAXcGk2*4kK_JKmP5XNC zd7llcoKQkZ>4W&}!*DP!{6<@)^D)JifsUbpps%Qcfy%}7R#FSoxhMvnducQJQwy1- zOJAXR$pRR3>=v0B-b)*__t2s}srVy)o~0P%@NbBpPx`a?Ip^QP`M3NP{9BC5(`aAs zBzi)3AhfB~^WVJ6rCM`K!0gBevnh-Gac7mCAXm~0mbRq;gxJH5l4uO*KFBmmNx}O5 zhnemDD{ygGCY*}T=Fe?I^7zDBGQlGY4bl`)NQ#s;5E2&T3+9^vd*N<`61e z?g&bQvWfV@!_doUl2doRAUe4W>+OO;^^qzy3;9CZwfB&9Z;DB@+Hm^i}j>4B06T;!<$)g%wvnU^g701$+} z0^SaK@V+(K%wou5l$kCBXGW=zXEVO?^JzHwc5O9YlX%O7-{KRoAqC7Uk$8GO!yCg% zKODt$5+ZAe{@wjydQAe_<}HVYDYt38_hG)Wiv+A#VT~8M1L*t}d6eHdjAVEy@rpb> z>CSd}p7^#hdfL;A6xC>ZPM_c(eru^;m;#g$5s@qsYTM?HI5=ntHFR zz%{WHX185KK4V9|txM(a8yZ5sxKf~fQs{wxS$Ht`d}*^=Et(iig0+vrVUlwcj(Cwq zrRN-?YBx4e-P_^B%dih`fSv{2{5lfi4JX1J@eD{eAIz8~@=3T!3GJ4iM)So-Ky}_p z+;zU4@T8K!a)B#c-qi*5Q&*tz_!f+exsJ)xp2N$SU>r8ziSci+fg<5}(7$MhF9rnQ zjI$P)vLlDoE<1qJvL+J$lj?*Yo(j$jQ|Mk%Yg&D*l6g5L7ETr_mkwGX2_yPHXLT!^hZjM24;V--X6yiKw`MXiKN4rR4B>~RxU#+@hw0F~?f7wR z4dMJ-IRBQvf`5yKL@@Rh>&yBH04^Kd$s{!l<*8?6(8)5hsax|G#?svhC*2nyR5dc^ zD%$y@6ooOe+=J>gOvZ@^Yw1z#Jly#vjMO|F2g_TFvEYm(zGxTXoqwH6?!LZ&Z`GpA zLJrwOS7bVD>=$O1m=#Bd59^}Z`B(UjI|c~&-nFztE(C9ezG6R<4IrJKX&~q^H4$x0 zrrEQ2=sVIK3s3jqmF4Xt?P+#bfVEYte+$U6HT2+)d)4nJ7SHUp$Di`&M~v+m9Z`UKn-Dr< zr7O{&AdFwbqUdY3u5p`q7$^j`lkv*Jf;DkrtT)12x;{4w6x>piPktyc-vU^w~k!W~+*F-;Pt|%l)C`{wi4JAHn3^ zX7G8F6_7R^(9`r_ulEo^eyJiS*wS9J$wRV#mi`1+&b3F@f*?Hsso4DSkk7$eR+YC%u#vo zE=q>iKvmKaG;%OyoTBQ9&QNFEUO1SSKS_yviW0z+Aa&jl-%WrbugHlFmJt3<16p@n z#S6_R$x*>_`et7*Y-r|#wqrEBOk-d~N<8g%pdZO1LolNHZt1m`1L=p%8#wC4U9!<| z5O(f7L9Md^E320it7~UT_l#vYrKK5mnx-OYw4((YeuAR(csN^_!*omdfRolT{;~1i z*d#fgjeFFg`?X$Nb9X2fZZpSw;|g*4MI$KJ08pKO2gYsfMh&4vdelXqcvqW~xtgCa z`I!U=EDylxE^qv(dy=WUwHRi-5)pWA@uj+_Ebz77Xy}@GpPHNL;^QzQ>Soi(?0?Y@ zb2A2$1K}Yc_Dz+yLvtMVf4GvixFGIr_8_Ax-V)s>pZM;K91IGaNUB#pGu_&eKrJtL z!A-5B@N7bVxU!WUuT$f|Bl{CI_FMrmmK}TtHdfWEUjP*Wp%{DcBz(Cg0qqw~(Bp?n zG513{ev4@W-8u6>_x27Tp#(D??gek&OpLjafm!uIbdrw@9LiaZw_IGwJNHA_IBXeY zh|93!w}SC(Nyh|<7Mz)_h>{{5RPvx7m}grM%i;YP5Z9mUG|M)tB2Dia&9)5iXZ98@q~XSAaZ*M*Ro_-hA8KE~j{D}NWuokU%(A{P zZI=yPup3H~7aoBjOC2#@?=6vh97btFIFv^@g2Xa!I97EAE|;%`>E=9J`J(P`+O=Hf@b|@|DDJb3U z$>x`oHGpl^Htdy>Ve>Tg;H}RabX(d?zV#wb%$5RMWHWN;HCwrmof`i2Hhoz#b?N2{XMAYrO4}|9>m!_2yQ+xfsE86B!KkC zsh7mbz0${w*uc-=o0tynjt6nUjZ*v~KL!`a?4d!`+u>x}XZk{U5s{zbNt{Ep@M(e# z&R5nV1F#&H?Gk|dn@*Cyy#`d@mXN~(W}?plCsg&&y zA?(aZ`06$a58SteoR)S{^KLM2nROV1n8m~1IWypRjta$p}a)Ri@uSBlF5xX*_c&A#-h{CkHq)1Me zniUS@DR16~7EgPy`RPD`znGSwJ+cfBjHv{b5k54u=o>LUdIr}EJn383GqCH{KAb(c z7dF-=((;FrIQd~OQ#F1aW*MYm_COu7O>aFEL-!pLJhgiwkg4AfVD3ri>o)=#$J}IkWCr2+ z)PYP1MAN%rgGrW8BCZNxbJI17>DdqYuxH9^vtum`k+)N3_c!yf-9v~s;?`0)v^p90 zZ@x#wmx<%?;@v1W_d7q|w}5^y3B`+7j$oxpHrao}6E117_Djz1nFT$yo~B9WQR_%3n9^U5hqxQNQbNP)GI6H;9h{nKB1ls- z#tU^<$=IlHtSEA!(&=t+Lt_kH8Cpg|0xv`2E)#}{r>~QbX^Hq@#jd}_&)?eZsQNd=&yD^p ze$M%~aQ-cS1^p$@)7KP)Y2a~|MKMx(3mC}Y%7p#3Wi?-gD=bw%_Pc^)QfTt@i zI510vGP9Iub)!1gvpGl-{_3Q0I=NUTC;0lro?3|7;+@6} z;5j)2l&JBaOZ0@tc~GV1ai*imyAi5QPwh*?6$`@}LWJ!hy% zL^0mXyGw^ZHiHjmRmdd31?NJF#mWdW~z#SV_FUtRiwgnHb4LR zlff|C?hO^O?;`mRV<4h$Bnj!5f!Ql|&skqIoieLU|9a+ zH7TPU;9^EJu4>3)ida91O$&CD`!7_%^qVv^ZNANHdZQrNRDY5vw(i5UW9*jf%;&K0 ziz)L|MiXW$W>f9UFVI0_A2Ve1Bs|eqg12Dy44m&fx9)yD(5+JEY6?&GGVUUL<)l7;Ypx#O5 zO7$}Dq*gHcMa|^4QwE+nkV8(h^*sf%x~QdgGS(?45w)a9{@oi1u)xWgM((%ar%dhU z_nqMn1xFkp{BsUT3@Ii-DH5nAs)=QRN+gu^mCn^1gIzijC3^Kzyye%|(2t9cqsi!4 zGF+enUnC+?rr3eMW5NmMNRufx&Qj!UIUz((NJW5N>sKm1n~ks4WJC7FI0!Ci<6Bx< zz?e-_iSQ{il)bSN*9I1&&Vd+sSD($Fx@`@>>QKykxF1~hUE<$fx)jQeSVPSnW9S`H z1o;_RBv+~r?`v=ztqj-U&FtGw7fyD?O!=L(!6JlNEgDP`{O2+5Z*y>-mKKam$tQ12 zk3yG2e~>i{g((x0h(7B-=Ch`T-nuRW@&nUhF)tTqMi)T(_d)3Kssz_OXX`5-ej#Dr z(NNxTfI3xfBgZ0xp?SnJ@+eLl`H8NOjM;3PUtz+`Aux2h7Cch6!V_$*xLn&LW|#-T zjbG|%m?)13FJDQ%&(NmE=R|l`a>2m)w{ZR~e+B;*ozK=}NOK^z3SAk-5|PO)`gK|KUE4V-}TwqNNn*_&v8CY}1_r;~l_ z>(H~$Lnt}aNtar(J_~MVu}^m-sf)aWn#v-$Wbr*hMn};HaS}w`DH=RR@M()zFsQI8 zh{EzJ5Hor#@->8cCTuR%a*u5gXx{*)9TDW*s9L(uKMFq`S%O6y6>&#OF7h=mGjA15 z=?9DL@aCEtTnJ19ANLM2)G8A)7fz!?#Gk;z3=Kiw-2||24V_xKnN}$`keMb0sMxrQ z`J^;JuxL~|1Se!8UKs$PmxEC+(v~#1$MEyEsL{svx_GxC1Ol}u!fBBiL|@7SMb*_w z!mo`&CuIz0bcL^DeRfHC`s@oZg$Y2HSn=KK;1gOj0p z!)nyL>;y`dcWKb)6`*f6jvkHmgmaqC)Tm%T+-n&t2<5S{`iweqG`AjXmiC8TFGI0w zLmXq{XNfKwa!^#hhbic;K0&8$lAN&I6Ha1wMCNIn<=cfdD3m#su{~) zV;N7j9C1aZ<_hxW^*ea1`E@MipVnz)V!K^tUXQRUlSS`dHb-O88L~yKo1&~2 zo*Iw=2bTAP-J-P+r_xOZwI{%%#uQwxJC4o2_)b#4QTS#h&+iNo=V_=FLLb`*Seg<5 z&F2#7q0LV~C~gdV=o5~^#27~Wtty-{tRRQKiecI=c@!L$ME$;cpczTf@%dI5l`#os zw}-(Q%{63c^hFF^KAL7Uio-oe5uViiEZB1B2$h!JO3%!zfpm5~(81=ee_3oomNY#= z)MjhTf_*^N&mOAHcOdirFx;R%)RTZIw}n79xP+WPunyEb;$WEN zeBh}(EmqflkMnPTvwIN*hzfi-P=;0(sn(q>^^YoZd#c=Q38gEN`?unG6Cx`aVv zhcQ;e9$?A6Rgm;J4%$r~lAV|NG|l`9an8|(5u?|@mG7sp?CD1IT@{K+mxOqDcq+=( z??$h;DJWm+N|qF=ps9Ez2D$7craK>Cz=;^tdA9}4r4k6|-@^H~{1yCLYL_&_WuPw%x!p~|otGHy;4_7O_PTxVT4 zEf-2_6DT~qa1&;0ZU+h3#cWQEEm(G3W;6pH!Lj`kyoqut&~J?{oJwH*Zq~gf(%%$V z7R}jGt+BPl`SlK(8KFt0SkJWP zBprmZ)e3YCm`pQJl+C!@5bYX$H z8}=BM;pb*yGG|jPyASS5hM5e2FGq8+R;i8HD25T6yw{{C$ehN+DD&Qw7W1=tY8WN= zf<_qGlAM5J{K{&!wz;<-*q%!v>ESvwJI#t(Dz=x0u?@(=;jQJhpfoqsy;cnmj5b5s$8rz&nZkcpZWA7_FHA*)5XnI%y6i zAF51KpM0T;ycpb{sz*8;HK~@HIb&dOgKS^wMV38!fdN-{;=7GWJ-(J!%&I0!Hb!ET{4C6lZwKuh39LIC1|`poU`u^JnDBZ9q}Qin z#4!&Ll2_mryida|tJ2A{#q(gKOamNhq&V1sfhOORFj%^ff1}T7jBn91s!Z8n#u zrmYugi%u;#+)m>A49OwVN(IapRgwum|M)Z0V_ zN(aw@35N0K#Z?o_*hhHdP&`Tw=%D7!{#c~G8`re&rI$qJ68nUe=;~PrLq}AQ(>b+d z>M&y>E;55=oNc4;l>^Kitb<@fxCN|X_bMuv$AaqKG#WOe6gHKW(Wv)B%%sP~fXI7x zKVKt@jP`;$^mzWrS$xrR(R}o3EW(h z|Hcr#U{P@E-w;1H{7j#f1}qLu=v_P<>&q$zxDUqX92c< z+CG>6&*lIB74rW%ek;ds<@l`}zm?;+a{N|~-^%e@Iesh0Z{_%{9KV(0w{rYej^E1h zTRDC!$8Y8MtsKAgf6s5#6B_z^{8oeU(<~JK8Nijx|L5}m{|fp49KV(0w{rYej^E1h zTRDC!$8Y8MtsK9Vo$(x%JEw{ek;ds<@l`}zx993Z~guFx#6G0 z&pCc8$8Y8MtsK9VW|kqG#+m%w-(&DDR|N*7=E6~69%KdzJB zk4}&CS?%?!dePgaP+?Z_$7%GFvX{8fzkW?TynWZW`n&wsyYg!z5U}p$f4>Qne`-Wl z;x`)c^Y`|3UE%J+UiTGi*SN3s|FxZ+%+n!kuZ4b`20yhE^&9QD1^T;t`!09){Z(G% zl(G#wfP{XW<9?DK%#Nwwc_;Mrf7PxCne$_C{sZlp-_Jeoy#SJ4-Rh;U8-&8ULpC#9y`F+U=Uz)>2b^k|w@PGaG-#>@24~!qjdA7wL2m62Q(@#jK?(2{5|M{r5v+DEX QAt^MN{rd6x|9tI#0PNq*;s5{u literal 0 HcmV?d00001 diff --git a/fme/ace/stepper/testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt b/fme/ace/stepper/testdata/csfno_stepper_train_on_batch_with_optimization_regression.pt new file mode 100644 index 0000000000000000000000000000000000000000..143c3c078bb9c0abfa8031808b5268f0dc0e8db2 GIT binary patch literal 65279 zcmce-dpMQd_W$pEs!$|JQY0jy9O52hE|gR{sU#GUqA1I;bdper&_RVFC8Z=KV%>AD z5=kkYRFqVT&N`p@diJxg@BTc`@3Z$m`}%#?WnRm*uC?B;F~>dU9OE8y-m|7j%ScEl zC`kN&UJeqv5`JMzSFZ973-=ER@elP54_zF%(tFiPZ{Nk?egWQ5f#Ct(t3tv9g9D=% zhX?k49_qg=)ITh&_usKg7Kbl34OtQ7QtO2hhVELu_&+e-=BB~^;h}+k zVWvT=!ovQ6_BInr`h~9w^$RfdjtCD73RCI@`9=f=E%6TbUm3P4)H~8bNhoD#=q~>c zl~tjOm-!2&|NM!PQ_!l#;d6R_-ut1FP{v)d_d8}nS$FB)cfXJbq1-gdaH0I{X%a?8 ziGF^of`a`0{^(1X(#+5${-OR$CI9!PV4QGpfZX&_*>Is&uY%z;>1mR^O>x?QaN!Wc>CYv5|E3+BHsB9OXm9<8 zfkVTEItDUw`u*Z)&}WVq14RXECB{a?0+FAiPiAO5HL zi-m@N0ayx+{s}PtJ7DzR049n5Fn>(A(9~5p_TK?~g=T*NSP9Ml39$G(VBFsTW|5S4RA4*fgg;QOH)BZJT|7w%a<^M`Z^$|jhW%XS+2s_{~A~prvGEuy8czcTsY^SZkzjex6S*T+uZ&fS@Xk%?ykZG z|Es*8(BrT27DCT|%6t7?e&OHb7yT*k9WL~76)yhQWw$W>>%jH>D`A|_@1KMve<%3= zjj;63tFclMwWGLh#=REB_>{3KxdB3WfiU;3o|IE5TA2 z_D@3i-w6?aBSij5hzb|3b``FPFr76`ULpF~h}jQhB_#f6v;J+YnIbXfzlb%aOIL*k zFAn#P^bh?b_Bc5h{l}2=muO`Ehmqy)9&G+2F8tB{y^$^wvj2^dVSmOT@5MoZ%T@;a zuMBr`kd*MS8vlm@Bue<=k4f`K`}YPoO7#0L4DekO?!W4fnC9f5_%Hti1%`(Q`Tt=< zJUksN*USH-{d*e(5`F&*8~zASW6hi#6#lOPb5DDcKj>1$TGM+I5s5$Af1)e@w{&wS z2c`c@_wd*n>GltPO|0u5iO~Q5e2c&G->TpL2fy*u{Fnbv_~ZV{*WDGyRz2?mv4uXl zQRl>dW~(!+5wS#m^?od#_8etRRLHz*=U_+8OY&U#IBp!Uo&A1ZEgCRtILnc6V)~(F zm{od&aH0NWyIT#$FFeeqZg9c-=8IUYohP|n^v zvCW0EZ0z_*e)?5^Hv2{*al2SfsoHF=U^N4Z+-8T5(7TQrx7vfFi4nZ-W*M9vPJh?7cI($bPA z_;v0mR<}@%PCsqLE2@-}ju#!Yr+yi0TW`plYH@7p*^{{9WFT2$mdEZH-R1rCk23C! zGK;S|!pt_7va^j!M3yXLhu_AE8k%;JL1CrrO?4wlTHZnrjy_ATd^t?Cy@SZKM+3uPwPVkl*Rsu`y=IO87Wp=|OgUOr^A@unPKRV`IEU)!T5}!=%VUCpnbfjG) z8JOHjtG~K{@w?&leW?d&Rg7eBd@V?z&RqJ*UXdj|PQeIub#`=V1J;>zb7B4F)2)$m zXtiPl8PqkMAD-O6;zI6_Z6nXK`^M)XecBdA@@@56p7&ivT*H6Qa1t)l~XnERYSk8U#DBzPa zcYw`?0(ceF$^Y;gfifGqxV0VA;cc2BzWurz#P)OS%4;R?b5Rp7Q7^7~>{bIY8Zoxv z$tIX#V1)Ha&9G!YfE};B4 zh~yLC=5P((LQxS^mz}hGt2T!p<)VYRI~Ul!UR?mUeD3llIT;W$z=7sHf5CU#%km>` zF6L)w#ei&}0va}qq)LnRQSfbw$Z2jKcwKbH-RgxPae0R5(=b20_UR)CmMVk)j`MJ6 z+$QdUO_J!-6)9SJwD%eXX7IB8GeKKnvuN(7K3JUH4z7x()Ti5qR^D04#};hhOAJOK zU)suVDT?FYstm)|&L6-qdI#JwKF+Im48(KAeefEIg%^WnO zzt7??x@zJlgR`P*Is<6tuX=E*I1GELb?~6K1nv#jg15_C;lgx9%r%h#1N~vtTx^Xi zzTM%QuP?yR4<8}3`aOgh)xlO-ebJ+9p1^%P2r;9k*_qh(N9-H{7rn>eOOZdm)saFk z)j?R=6~Lcf8wY~eZ4eVUw5m;gpr|8o2#he50ncUvf=uZ?#6cc+uGs}y70PHRZr}=}uS4

& z%$n=$1Q#}N`}-iS)|v=iPk(aGM~b0w&qOXeq+OKo#fLBHMo#9nAu72v0F@sOb%&m`)&jY% z*fK{7uN+myg1dwHc@mY7Jm3u6~Uu6 zIn3Gn3-s06Ifo_YC=B`xpC1;%%egDSYG^a)l=Q%nH3uQpp&Yu*5+UZ^L+I?XM4e%nXLuH`ft$}7eBd|FC8aQ1xh4xA#+(4OJLq@4P=~&E1P(& zon~~L$Jh~mWci#@)*5O{7X}Ytlg$>f+pT3hCA&z;^;0ytJ)VU4DzWAXdGSC}G);=D zaFk~4jeSXTa0r_hC8jgWAJXn)K*rdt#*+Sx^tqxfi5qvE|J6E(%}OvQyXR-Jxn1UL zf{_wa9{Gff5Bx=MwK~$+fMM*h@)R;W!-BM}s$yqfFJqq0JX@4BgGF0Pv(O9Ym_koH z8Ms-2%}*}Fcr||DP#On-$}>tRv3i7^?MLNxbr+BwLV~O!clt z;nS~4*k@Ibh~2SbAs^M~FR5W-?}TRZQ0F_Ue~=Rly&O%~q&qO>_cLing*nUhP9b|f z>a!lNsccAe5*@g43VSa-jFk;ZCdN(H82Rf0q}9CFau(DL z+b#;J%-|w?CgHEoThVsRY(6GAAI@)gMx~iPRM*VHPWOsa)vow`{0JuljFS|>8GAFF z6*E-iqhyAQSsnkbD}~?lGae+;9M4%ePUe-44n`NF3{K}kfydYGO zn?1n-))miz@ZJ4UYmhd4zMTun)9az{;Cyb{?^ZtQ?f^8r`3#1nUE*p;XbE-_~`33pXBv#~O$7%E=F)v*ZwW;OBQ{SAv(Rt6YUycC^L!$=xj7i_ za=wuCd@vt(P0dnxg7(O%yzpg0_9NqBX01@U%rbuaPI< zufD1jVXPm%HfV-Mm&bO>8HP|(&;?fJCj76|TOe^}>4cDJgE+y6NRB9(V%h^8kwVh| z(d%{HBJQ1=MU6AjGs@tr&gCDIiTXGBc`l<~`Okn3G4!R+- zn{}C=xaKsJl*!?oey<^Mo(x(=X2aw$k3iMwwB6ar_Yi!Ihs)7Is5=~pkyib|tIHTG zZ_kI2LFc#zjT8_x555gccGN`@L`|6!#Os z!PqS@J1qkw+UmIAyHaSPFca@NpMr0_`F!oM6d0hJ3~9|0E|8lm5 z7E!}gU&~14vc*uk{U8~qX~^8gNo?uCfh6kS5-`>Iga?+|(o6SuVCt29 z6`Ika7&OOeFF-ri7c6Pbx+Z#q|e%(V@|+L|>he9uK`qI=5`XlP4zNp2i4v zXJ9SA*ELO)XEldyHq|EM<~j3q!)jT1X&~2gUY(U$Y-4HLQ;GGg-nBn8mbI2|U>0XX zv1Ok$nRH@5sziB_N24t9+w=-%Frt-yt2|B)50hapH|}OmZ;Zt3w4rl~0$Z|gJNa^HH9IWRmvrXKuybqou%#wmWY?m>>{Q|_aznEoe`Kwu%C(Q^ zyGRwW(voa&IBP+cS1}TrK*%DG3v}kEwRk9I6p>P2Nj?Txl8nM^=JX^LR~|ags-v6n z0gI&7m5s3W@Mdy%#RW9g)1t!X=h@YZ_vs|?W>RZ?_`zB0Evp-`0Sy7SQo+2Z&$p8(Qp}0!PnHV1X*dn7r@{&3t{I zX_H!JKVvHyTu@Kn+uKkwL=TR($D>RXV(AoFepILq$bR_5cPb5lg3c&@{mgQ{U86>H zac?V^UXafFx(w$J@%50^bRC{vzr<%1Jc11(jz2oCcYeChz`Mogd8IjKysF{XDk~dJ zgfC+ube}ca&*u42rw;RbM9naFQ8B#hJ^}h>-$lQS>bZ`|lOTN6LX?jG3RO#n;600R zuxP9nZ2hiCbEikxrOkL!W#@XFzhYg-cb|*_mpztmkEVb=8&+1my0Hta?yKOJmjiI$qW&mh&~Ywz&pDCHwHUa!yi~Moa28*3 z$Oj_)#89d?m^Ygj1aJCH0n>HL{DBG8Fuz0z4|g~4wm%l|ANTvg=Z%Nq*P<@Y#$F9I z?s~$M(sTSMg)MyFWwK~i>BmhyU`kCducjFj7GSd4P|R%b1@Dj;{3Y812S$D8^ZQ7n z>XD6Hg~J=@I_Affv@aJOYny=GJLO=Vz7b{&(?AnnZ!nA066IRP(znrW_{OsoWTt(A zyhAsExATA=qj9iY?m9F&Pr(*G8Ro_9;Eh{qU_`AX{q#;5b?6y*(_x6mrhbO`8-McZ znKj(>n-6)_LCR1*BA?%*w^*brsb=?e;dl6`-U=DxN5N#3MKD3Xf|q#r*6#SL1E4gl z5cUTQr*UfMxCgV}iApnH0-n==I?D;Dt#ewmtoI&AP>{t%-dCaIMIij_>VwhuR3WXY z3gSopf@iI_A-A{g5_Gc;o(J~9yWY8QvifyzKPotx)b9`!Fa!ID2cYfFSjauKPvl!s z06sdh*y(&uWTiI{V~S@~Eq9Lrld5l^;v|92dS;?y12%%$%y`lK^J!3#IRi>8r9@jV z)qwibFQ9+q372&H7}TxN2Z`>PReh&^1J&A0h|&B~DYsxWDtLNwJH|Mo!qO0Ksd^>E zL?webXMEMCzEu!n&$xS%TDbVi1JG{0%8l^<0pEJ-MTONWsGYSC$;fJm*w+D{4DCTY zHMtV6&4T5^Tj0L94q`Lp@ayIg;A<-lv*ea>^|yjy4o!nIzdv!4LX2U{2{qiBK9-AZ z+zk58kD#+r4FXTd<5G2LP_z@lfN%LQ!D9y`*v$b6YeUG94TFSG9sDG05}DeJ#B| zVL7=Ve!U2uyz7f5KgQsOU@87c%^HYwM9?$*41FDKaBhVfu8-LUt~27md8`|Htf>WN z^%Pc5{Rp3)4~Ob8D`D2V`<%~^-!Pu{hR{cYaEbI?XuDr(H!^4^#C%mo8^?X%vFQ@5 z4|fNd{Yjz|jyIvI(hTg}PQcH?LeN$W2kjLx@NL{a?k!yld-Ms0@Axk2Z~7a2)-I{i z->ZtLLA~}}@6Fjp4}tFK{qcf(K1k-6K;z8|cDk25?Oejk?SqN+k%3*@m2U^4`Jv+vzv7q5)yy3|Fnmo^xy6wjCXR7 z`{y3Qu|#K97jn`*c_rY*XE#%ZYfRAJx+qcg3&U~nB4v4PJF(bv*PPEShI|hfvcwA zoudN;!6RC^Ltl$Q5Iur<#2K>_YWLZQ-TlZ1+k5au%zToa!rZl(R(ezEj7GIYfXS#i}-Wqcq=XZ99Z>}M%sB{z<63;V`0 zPmA|#Rp3*6a;t}48m7aBcZIU2<{6~y*)Mo%Hh=}r^<&}sJE-Hki+JO`mU#TfQZ`P$ zg8dx-8}rIIDtld#`K=hv&*!F-xqeC{=fp~~Ik_Bb14MLJg_s2jmNNMtgVbF4aa0N$ z*YuVM$7zr_^TFb_LF4Fsn>@^Ei@-xY2{>-xVY0e>Al^;%wvUfwv(GK$Jr)xCR(kYLsaxnvUo#J zHoLqN_e8ED9yT|L>w+y*+Uq>t`EHN4lPHTSjUpX|-@w^MQ}9wgjHK1yq9N(- zGkhR>;s9a^y2=xLS_d3Zi%geND^^-AR|Okfz%% zo6v2`2f(|+El^=*%_S?ppg)wpL-*PMermfcFFv=q^4Lr*bZa){lfT@f!dH5DShkjS zU7W|Au)oABW%cL7jhaQiIi*zf?igzMy$LqnoC}><3Zh2u?=Uc1k`4Xd4q57zd`0hi zEirSiZGEH)<)TH@$K?=cuga&&en)A=hkiJr^d`Nea-MF=GlY!?euBcGOz1oPohZ$# z51qDeAKdnCgW@R~_(VSyHNRZqKIWeX{Ut{st;P=ytiMZbYMmifYY9KSRTp(nm5IiO zros`qHeRCf)&$?0EA-1UMG(Aw0M!S}_%F1>I_Re`_1$Eb+#RM<$|KUc6|H{dZ8fznaUxJ_VxPrt@3QHp7*wTYOu& zN0mhWC0j=i1spo2nb#HX5S>&$$!kSU=EsbF0;cjY*!tOxMqlm^J}Z`q78wtvzh|b> zZ(|APxaST(tu6wx;5O{3`w7{>k3?}S9G8BZ038~Rs5zjGD^6PkQ@4ge@TEdtZjc^- za&ikDm~Ib&Du=1g4`&=aS^{IfI@;dBYOK+KdR%HLCvm~RtX|U;S=`>sB~Bg zzsa?5ZM&9mMWc2@vCDiuI@OfEZt#Fkx|}9xB8o2+RSt6g1d}{Z(&y5fs8rb@h?ser z#+1}oF8EP^%f{@bok^KA|HKIzST`2(B2uV#u>ubEw?o0;p;)q2OqUOE#Ow*S*gy3h zZ*e#X)9%d!Gk0a!=}Nf~s&%yDQW;&HIuBLed*b*!1w0n1!Fxu3EnB)wtXJ22sm(D}k&1Ad25N z0^JVXhqZQzv}{K`?0R#I-ZTgX`urp0f4Iy$_2w`Kezx%LKlP}?H!CU?=YUzKU%;%V zGO);b9i1#3g1&8ubntZ!T}6X=)wg4?<4h-9E7;WAuh}$VnKn(+Sjf#g@Ca1)y5QCf zEqrp00f*F>6| zd-JS?yJ`N*Y`*i}9D1)mf{nnZT{knCdxcwX)ki_N}Ea?Zt) zkd6Cj>R`fx(g|)f*n{3#PiggAUGhSuh|HgzfW)JiSruJk!$0>Ch!S1s%crJf*ZLN| z@xCp+yP0x7-VYV5*%^jYqUvaT$xV*jn@e{Xm9enzn@LP;f0n5q!?qUvCNqXL(_d~+ zu=?fN%*dOrp1O|jdiaiaI$ck;olGac zUiH@G7kRPusj={L$zhh&|2j(=Y0NHgJ4tS$0WQ|KiOPqSF;RX#n>sz0s5~1>UccN% zqC5Htsub6;q#=(AsYYxYN(Zc8Mw2+3i~%W~Ln=u51d4KXVvk?9_9;Fl@& z6OS3SgJpP+Wba*fv7xio1k`jn`FUE6M4e50Cg-10+eveVi+`F+QO}7lX#DAF+T`U;hF_mZPV&oH+r*)K zdEqd6rQb008!t=dy)lLtUdD7s(Jtz?wF6C>B3aQ-OL2Hz1$o}Nf|z}5BI`>=ioDhy zubLNQO;(JaNQ&jokdcXY`?L@r9tKvJ5W#*Wh?!unCH2TRvm1TTh$YbX zJbkhYsYe*9X0N`sG~dX^V;j_D5)8U5gJ)fT?;7V?*}HPE4I zAZ_%{u)E`a5XOJFLPNg1u*=_AM5S%K(4%WK{qocw*L?wy@QbiDo|6Ix^%gQwclhz$FGW*FkEQC~!%#6ehxcA&h;fU5iwx~oavu-8qr=@S zA@uefI953pk15Tbs9S`e6 z6|v&}UJj%_@qRrG@b=eidS>`6lbF6%nKLv^ zDg)W59zK2Wem?m5PrC79Jvi1d-a#uJhmkA@c(jlEYM;#891rAXCwRd2^gU?gtAdBT z`qH^x^4LoJ>D%@~syr!?h6+Q`ZGs%{Y1;>vMC(w`@>aU}V=Ubu6i|;QIXw8o2p!hw zQ6KeM(EhGTeHZM61#wS7_2xl1X_61aY*ToTgijzIE=PkL@<6Fz03WQE!&mL-;Rl}T zgL3^GApPO@-ui?cS_L`s!jW6C`HnNs`L=UjWkDFZu!Qc&KfpT;aip6*-{B|G{&@CS z6kgn1BYN&Vk|sTuz=x{^P|zrYTf^hPv~>htoNSKU4`$Gn3HE6AP+av}VHOLB~FcBxiqsHyWF%PufsAWl15O^W_$=^+g7f zS~l>{7xl&GgI1#Ak#gSRo*aL9*=Z==rbbVFwTI>Y3jF506CgY{1^Rc2c@4MC^r7!* zx@R@dU7t7r3+5^F5?wc}mj#qk&CtuDS6d2b^qc^C@7V(C@^li)M&5zN+4*$tD_tzs zQKTQO-}03`+wjR1O}b|6bub>LjI;E-IMXlcylUep6y2A{5VtmNU5P7|`x4F%$X4V# zf1IcLXF1~Qyq(B>Q>0f$?WcQ+=HtGz1>m3(jMcJp@k{Phy4uzc*UUNrEtYqo?~@Xd z>iZd@rWR|uPjLfEUE2gl2c6)RgFVn)%@K2BM^Uu{F*xp(m>R6<<(eGFzINWU@%}~} zd|3+*ilc!Ce-pQ#C+zvrRMMDnmh4zwK-UG8kZ9jR{9YDDVy>nUi;Z0< zu-nf%OtNX@rx}FL;7PRPWVYq1KD`^^L3e%erTct8W6vUA67?X8j8zabJ-s2!;OcNb z___?6*!hrcm|#IV*CevpKjhiqgu}d^$p^aMAR3kTYiw8^qkE`SMH_NAvQG6eKiR$>0|}{-HCSUO)Tzz z1y4N|ut^qXWT}lN9e!Sq4K~eS0hL9ptUR3Y)C@7_JUXU((ERLJc0Fnk^UxcH8|E93 z%N0E2%>PF8zdR-PJW|MwBiZn4$#U{KZ4nViO~Q^r{dmt>9t&ez=*2sY^mCn{H=d=? zqn4RmLEqgtWG2Tv7h1AYn{%m|_XT3`U>PsiWWau>Z=h2p2C)8XXOT+rUJ^yOlMVXA z$j~VR$b$`6Sd-pa7FttAuVzi_UH{|R9CI~+#aAs-u6!M<6?T$_MRhnThOwQ}?X+j2 zke)Vs$lh_1RAS~hHsh2DKR-8+NDMl`roNs6XGd=#s;bf$U_716{FsRgXMbf$Ce)IG*gUp%^h;7|T|gRyWkhAXB$?b+$-L%|BQIv}V*A{75>b^7$s&e)yjeOr z`h+u^@d^Ud-a{BuB6{~6t2IgMEHy#C`495qYXjGFd5F02*F>g6_mW}vy~)&5--+&< zXwkCTBpNDS&a0o0Vw1-!V9xUrJ}*24I^@#fM&>zQJ%({l=cZ%MOvcZgFbum#hVmDd zPe3boXFj?!6=rmerzot?G@>#3* z{-!^E9`l(m9Hma59vg;t2V~$irMa|t`a@9YOruRZj)6kR2CR#6pi>&JL;Hy&?r!r# z_z~F`J){TVOQW$Mmf2`mcPSdKMx>x~nhJPlGf=#LmEX|Eh7V~!TGdu&P}QhcmLKGmM8z-MFmAU4sOp)a`(YIdE^+*#@<}wH_k7gPD=$G& zqKSTVIRgPl=8Ak#0Sv05any%6n)_uN?n&N4FYLAFZ=O(vrCEdFT)->p{56itO+3b# zJy4>>?~HM#cM|WOnaXWFwuXjh$5P$H2f%92J}MJh3b7xXAVlS=T~&50U)8)0j0Rr@ zrOYWbYXHD&jon<_hTGi4Fim=EKn35|a1`#DaD{4K@kXa4E51iIg~G=(RO7{L>ed#F zEnVCAkviG5GXS2?6(lL^H1%;m6eP z)N@-64LN>6Bz2?}Za?Di%O^D&6}p42-5i6>148I*)4P1riUcb5tF2qjYzYZw|&2`pv!;9Xt3){-+ zyvjSg4R=)Zx_mF4`q2d(&ZpDr;S=G}nd@}qyl$ANyp+z@D1>9uYw+L&bqZVOV4NtG zc21f^ubn(bosu?TbHZTkkxim&EL-`6EOngwY%6_nER~NM7mslh6|jGNJa${^(Qn^t z=0tXMsoT7p1pPQ zEV}W2EJ`~)hRQk7RAe-n?vZbT%h&tiw>JUsDPt7wUq6l}`1iop&Cxh`_c#7bVk|ZP zR!g3(oOeonsSY+lQjm59zd4ue^R84q*7X>Y#;nr*D6)Sm^ z^-BiT-ZOWSQJM6G)p`18+&+5NBbFX-7)ljR52mYU3NTmEhSQ&w1m0_}(FfHP+=%S6 zG;;AO*z9ZY@V0_Ge(X4!P3|~ZPj@4=SH78y4Qh5v()v4iQ?|Trq%p7{e6?{+3A^w(13NERh zOFcUtfO-Eo)ci1z3T}4T{!u^geUR*@2Pj2MMrnBU} z-cEM%$YgYOe@AW1uh=cHenkxv@3Mh&jv-zTWHXB;`F=7d*z9%4Scr27KT;)0!5sJ_*3w&bQh+2!ia zE*?(CirEK=d$%mUOxZ^dxERo*TlUhok#{IreuOHUN0CuP$U0Oj$j?R}{5^OZTUT?D z%KZwZJJxFwiP97NmJw6giXW3mjrBD0`kW)PzI>Csonb**XRjkKZ^e)s+sAPs(c=Zx zk3aOr{)y1G*otVy6%fmgFlOG8KrU*WB(_slk_Do3Ol948>bx|CcnZ}}J0XCbvr9v> zre*AHb1OU1l|fECawcIn^x5$neoXc0Siyk{quJ;2U+Jxed31?GG4*0A$dYvmg1l}Z z)}mlCGU_bv+Omf(sg0)X$2JhnJwV3Czs8M2){Cw*p5Q-B(WljOnpn^D2`r-b%nsi% z1I(sRBsM@Yk-jC#?mi- zq_nR#y*=#$AE6h_bYAGP7m}%L*n!J9&ny|wFOVmn-`bM_k3XZ&h|Nsnfvv#asSKAt z-OchRePqi%#gL_22aAR7!#qrmZe_csADq8 zOWe%F`=+v*O(l48P#<=~=v?pf1`>jsM#`jd+h@F_PMA&e8S3_OA{*maXN$i3snnnV zW-~C0jGf(1Ej}M3InLU2cEfU}7+uS|=kKL^;#35$?Tv`cov&2RA&ESyUs|QHR$J`* z*qUf4A0k5!?Iujwq4!F)r+AV=HDJScQWJ4yT(Y>JrBSY1Wa_Pdq&#LFAmDNt?I{ z%&=t_$uTWtx4esJ!yYxZ_;D;VTz!_g4H`$*m+m83m(AI@rDt*2gZ+#a*bB}{7m%7a z@2M)u!>4ivWWY=>5}+eb&J@Eq!nw7`smvp8oYeCWnm)bs)PQinE zYe`{1IWyalM;-(w;IQOGT0m`x>484X`Kk+Z_nfYELr<*1LC_?m&=L#g~z}B#JG8Fc=+2SW-C9Dt83J!R*?4l5`wzq3*_Cz$0#p|W(R}Qnakaw?AoFN zOb+#9T9c#%mn3sY{Pz^Rs_PHxv&N9*&Qa`}$^+crISIRe4ff8gr+leG z>6^|ZZps<R41HoX$ zG`dh#QlQe4Mv64TS>~XL#Lwv=+Zd=Pc8~55y&XQ5&>m~Lv;SIF|1p*wSDH`R;^(mP z$w0A!+$CyVsz&QlqA+WL9si*_42K+_Me>jBrTVqiBv^4Ysn~DN?5Dk;VWffCb2Y@i zk2XEf;m8`+mBSC`OkhnHnajLPxsf-CxF=29$Wpd&=|=Y8 zl$vI>yu0tNvxaygtzFbK>q1Bf; zE-uFjw@#C}d8w@9jt)6>W)vxSt3bA-*^s@KO)Pw~JpIttJ9kCxRBvVhc}9n`?{q)$ zN!Y_@1Q;Nvvz=VKxd|PzY{+-NBiOYxkqoguf;PSN=PzUA1gfn?_)K1wP5IVFgWsyM z-{%gX<3~NRL`F*R?q>@b{^lrYUe2HDlHYY#L?!O#vc3f@U}NPVaGe(T}lS7mJaqkO9RWDmPhBZ2oV zGqA_qT5KL9V9KzY1=-t?C+oaoKPbl(=R1^af9s^s(R;^SpZZd^59)(m2G8^*9vVarLw z%_6eWrj$r_^&RWsfq}HF5#_yy!65-hGAz#FbJ7loO9%xgW#kPX~eAVv=;c z6Xp-=@!Lr4m%sk*18mO_?0YC14Gh0 zRh?;P9%oHX^;8ZoF*S>k;wLZLaYe=olz8(KJ`CwXuUKQ&^(%tB3Q{7^%FdJWsB ztTL=0jq7)cJ)b^6T((`vrk8o6dg3+mJYoyWU2+q8PIyu4nv>X&Cj&d%Ct!VVPBB@g zo-&(ky38V-4T_Q@ol*5H>)vE;QEnE_|M7sWf1FC%mA~V;6(fnMwLM;0V9i|iKVp?3 z?YMCLZS=JDWzM;6)X(Y#>ntjv=d(An`*p@_%OOoce$#li{8bAKDceoeTo2Le;GwK8 zQh}_HyG+|c+E6@vJsEvnpJ{E4WqySc{FPOU$-`|jEHG~}$h{dLSnglN6kB}Q?xU^L zNM0a_81#eOIM%=#l_rsTTOU?%`2fo`MLt=t}F%ipeTbW z=EU<@r6KNKd5bCazm1Z|0#W;x9oXxR5H~#65KrjwVgjeVcs6=sRqfp=fX&D0-L!?I zDItN?6|7<(T@OK3R0P)DFChMFjfu{aSUk|zmJZ*WOla~gY@8NDMB~Pgf#E*jeDVW| z$y8P)yN(^I`H0aDrr^5Cl3LqTpzIn2Ix4>j$AwF>u`*YIY$$iW$jE-_t1x(%6v>V`d_DpdHuvuFJ_R5Ne@^~h%ftA zkw{b;*MW?B20n~E%I>Hbh!+^0!gF&7COp(&H$T|3##x+T(yxu=7z*6kO^d)o)mBY>L@C%S0J}mTSu`VV#D-j6e;iRpP{VD%{SW^O-L*GJ6G zz3pIqEN1qO6B)s!KD(L4kGpJ0Lp_aMdXQC|>qo*?-y!YGm8pi32fM|0^N}-Em=?E( z_q!Lx&(pp~UUynD{XTB=vU~3{2t#}5$dlKY#}aArw$0&WK;ddSe?bLAdAYD6V^wx# z?LY4q0|rB|nhqSznmXwKx2_lkc(Q@*VbJT_lmLT0-)& z&tF z$`B=qQZz{eNkwT8>OTM9@4ww5Kaq=wcVn}xrf&n8mu0w6Kal(vRbDpRxt|BN)lY2SIIX!tcYpL`65 zk#MNYD#AGpVzd`yI``;)xOh;5Jh3svR;u!tDO-f=AJ^mVJpnlN69b7dG4zR29H@>( z)0%A(^wf_~{BYeB`hLMRq_sYae|`?ak~=QZ+qFUX{fICv#s~5AtNZbdzHwqYXBT)> zWa8_K6?iw73*#G;`lyIeB@_z|A}Iq6thle6^bG4`*UjlzZQ*^^@wWlTq$!owRb+!c z3B;cj=aFNEYp_aoF1fW(5O2IGf+*Fahu<8+raO~}aDXwLvF;Py_+^NFJ5WL<28sGk9wt%n z#Hbw-H!SD1jr)>uw^P(|BIc-{m7s5FD;P{mrwi?Vk|{4M=yQ1y&i9|EaD>lFOz*do zyr%uM-1i4rD;|v93fpMT!KV;eK9S=#C4sj8_=B(3Nb#;-K4kyXuC@xux0Xv=JVuJ2A3S=rl9MSa^rSpn0*ze*gvm2>g?qp8^KVqQu(X+bY3Rx zY?tGSv=clqG@CXhq+_GvO#JQQCBn7XO;z~sESrl;I65JpavoK}AarY)#onK8JyxN78 zPS3+bCyvs4zXnLdum;T>{YsQXKVk!p1+5CtpmX2FfX^cf{L{mTQ}IoRHOyo1Dt?|> zUhpVxWQ$p{yaW!^aYMialz&>vAF7CIzcmjv8#tA>D>n&7^_y} zVy#B(x9b|eFX@RJp2*`_fotjflSU9Um!Wfa|H36xSAffl1gvC`O8)q+<0nmpX=s}z zRVWCT9eqeIdphw!ojIjc}$coIZLU#7^z%VVxw8(BO@i z@X$|Fn&)#G#5P7!o1@Xhkl%}SlxkoT(xzi4(L0d2MuW$my^2Za7h)G4MHTgw@qry_ zC}QF__RroVkmGl8OASqMz?Vuar`=1=rlt~Qkq$WD^9+7S9HFi{6FI(<%Ba$WY-}Va zLlycbk{zKYG&}S&eIWFQ`}6W*dUAak{gS7|sjMiX;zy>STQk;BgWyvzDmO+vBjq_Y zKT6RG>n^;ke+eFcUJn&o(X>u`m|gvRG11LCkCXy*sX?Oyz0$v#yz*(n-MoofJr?2+iZa~^FZTD@swJ#Tcy2=Hkeo5xI{Shp_p_;lP_>5Zo+Jl9^N+3zg>F}`X zJwa`q*xT+pao8wHl%8Fu`(ECp8q1FG{Dmg+cw=YDRQ3u@bX(0b@oRALrv0pl=`1QU z;6u!9;$V|gJGrty6C|$XL&viU+A>9#uCFwtZ!UhsWtO+e(JFC#*_KCJx(F`&xS4(u zP{Au6r(?W14FcU~(?^!E_^9_D`279>tn#=?zLy@CWbV`mY6_>m~-+t4^R+ZzVSF(&|tU!2B{U|n7nL>}MRg*_) z@6ho2IQ-scIx<~UiPCP5;U15*So!iJSQ`0Zk)}e|KI#PU&Y_$e`J0>|9{WV z|M6S?@mu~A_$>?SvdF}}X(Y!)4)<%dv!6c}5rb7Vm`VP|*sS-ZA&pMFB5@z=*`I-N zpfe7xiKTuc7tm7)34BQAB04QoPd7y5;o_tMc2z%x?wma^&pV&a&QhX6T?!Pt_0yW$ zk$B0IWvHZjDlyq$iT$tg_v}HoSlweX^K$EJ@?jwtqF>F#d2b`hy}e<$cC8vYHSGm0 z>^Y3x{z_nli2^vV>lb%m>SS(jZ+8L$(x;jKr`ehEz~il zaND1DT(PDlJKFJ?WqLf_qx-04z-79N_R_0Haa8(8139W8PsR0@;SC8bcFt|Pye57L9;P9GebUQ64Z0o4R1jiWW!vuIWsg+!ZQ zBqF>8!d)d{m%Al-UHS)3(MbF*!h?y|_ktDAci_F|&iK*DL%6q7ns-*Q4O`Zkkg^$R z^xTd~IPzXS&P;g>S#Kif67eHslXxH&czzG7pFe`nYS-cFrDxf7fotgcQ{GUuK@DrF z#=`rpxzu+&mLPr*Xn#^Z6 zeL<0rSP+?9fj$Z)k(;@DX(UKfeKZUpKbQKB$Y8tX6uP2x3+}LaOxpKeWYso`;ceR- zXw`@xS+lK_u)~^=>#fSUv111b`J~7ZJylDt?eKvc%?8wW`VA@}cm>#a8D7eImOSkZ z#Y=p3Kr3i2PwB>Y@_U^C@B2!M&zHSI=8^q$U*tJ_*2JH9g&Nb@*@T#}hj4GP2yfr9 zI*x75TNsWwikPyEz;m zl0Hto!}#}m2c2o-jS75^KW{vC;v6RZZuIvx7n);d0kbx&$Hp#J`1LwXXqSwk+OgKu zf7Ai5sMlq2Yz(}7@`?H#iY0sZ*VEnc1TXKuPg&#y5kMI=`y5Utk3#N=@#_{^Qak)35tjt@lRq`78zrD!hB zwwi>PR2Cu^){qg&#iUGB2RD^orJ*-J0Y@SY+pgV&?}UEG?TfZi^EvNmbld})zW6=< z5Eno-*SI0y4o7@LFbSL9a-w~*esoE>HO&YvgqF?c>Ee7d@Vz<}*vL%mzAl4=2|JLc z%MIv*O+3C@sm4{_CXRJv-m=#OmFd?>5!k@~8wkZF;p+Wixc-zgY%>p}bLyM0X^l7S zX%S;xca4&>@&VY+et>Rmm|Y1!K(+% z%A)D=cazzOj7@y*$VH-FdYb;&(+vOkE&uo}{|WpSU9kWYCxGc@!(o)xq0X6_ zn@cRDmJ_d@-^g~W8=e1*M{(m*biMu)tFIzRqwBnhafbypyU;=|O)sS(pCh387-y8haQ-=l~!Jxz^N@g2VceFv1Yz1-J`aE*7==f z_n(ZQ_uugEfKaXBaAE!sV zP&BTLPYVlB_uHRAs?vfw_x=L4iW#&vWHvprK!cbqxK7*WOJOrhH)zz+!llU$bU5fd zkto!}n*`k18X1H&4?Kg_MzQo1|D60)|06hrY15_dCy4ZfW-K_O!4sePp5Dvchrh1g zf_8kmPm}*<;mYAnXgpa7w>t)sCEGWF={!L?77;@R_`Swcq7hi>=qH$?Cdk{J7{O;F zd?AO5VyH^57In~Cfz_)k=%{ZXzO`u)xBc*7A7&Ge__ z?FD>lg$7j^ze3dR$l#i1JMm`UC{+3w(f7~oG4vbbS$f|5=RKJhJ>x2qGe?U)8(4{d zC-q~6c8cZGEO7hm92hQ00lkj*`0e7&c%k$GnG~Kn zkqX^|yc^j`xV*9ujfm~TZU&p#%X9Y9=i+ntbB~8~>-|4;cc(h7JY_|@W>(TIw-(?! zU5qu>Ji#+hj!<2JRB~yfDeP^sf>nAyY5H3+%(Fd@uYKQ3f0$fF42^^7<#c6)9llqh`>w{jKH6!4zU!<%8P{F5$QHWboaS{Qv8sMC@JglbG+=hU07p zSr`7S>W^s!uJ;R~@fWV*-;c!br>Y!sC4ozezGc(#xaZhp{yr(3P`Hae+PDdZwkFW0nwxR1s1*PITaR|U%cdz}@95G36)G+?NW?E3 z#8w5?V5=<+eo-wn;GjJ{bSx0t8|dOY2TQ1)(G{Bc*nzyUGp3!bKG;FHk4Twou#=z3>e^fX<#GOVG*1x8_-?JdZuaX|RKFky_jqrTvDRV;B7ll=> zBa!Cibbjt9GW8UfyqsP|2cKDSYlZpuF`Fmi#i5S4%1NE1uS>?0JzZ(itS=zm9YIJ( zB(95b#bTTH;>(RC_)gs(yd_zc4m=rxGaVNA?bdy`#mj-2?zj~*#xwAo(Hr<5zvUml zv!L$o-$pqb$%Az`>_Kpiug&$kLffzH;X#dl_CpqQ+#elGvxb7 z(SS{IJP$`3EX66Hmd|ooart`u&_#)G6lbI%f@f5bf`j^J)Ce` zfO@bV<;aTS;?i_n!mXx4cfP^%4fja8JU@45{lHRa{vA5wOfh{w(~hP$nzQhieV?h9 zXAg}3q13EE0fRC4$VbgPXZjRqCsH4=hzj3sGl`9b>CU!2+<9>zzHt);UvmeKA0 zose2x0vAosQWYN+&UfMs?pytF+v|lmH!~UhK!z^5Eec)SZX_c23m-_$##>!4(6y~x z`de`Z-4J({_&1%#SNngFKDG6rXtM*{0}be#6i2#7%>ZO+9o`(u!&^TOL)qzOtnsA= zF3K*Y`_8*j4PRYSP#{LHNfv;zFbjGxrqw+?{YW1(TUPc$BXFU$y51z zoK$*4vk8{JQsBj=j)6jlEB&4)!6|;nfYRblxS?P|m`XX0+VOML`pqcqc_YX36*1s_ zims*SXEfj`I{Qe_jd3u)R6-B&c91W8X5jZn=jifDfAEo(R8sd+oLap6gBs`Tr}?w9 zXn~wDc#H*O_J%8+eytve+PA`iNB%f&sw5|Foi5!l5(`63#x&OWI63>PleD&& zP<^`uTohD^rziN~{HQY8o6!ZszxDBB>3X`$u!1aj!{66*z9;VeBD@`zQ)tS}SloKC zk6jQO0P?p*$-}F2aDm=ddO*9JoJ}r+l&&(m_Wml8;oksfizD#E1FxYk?g1Xsl;XVF zHbNe~6ro###W-4BLO9P<8sk-0@x+5V`00!`G%BS`Z)VD&8XQYHBPK!qzErx?kDpD~ zttRE)%J5m60p{{M1d2{-{QJ#P`pH{>qw~lUC+te6;m5l`bh9YEQhkOhtoqBA9jGAR z76#E&cjJ~ z?u>8rkKtl^?#u!Fyl)#ExLZRaALh}e&QIu~%!Rxx6?0nk@E#b&MAG^j8;NAL7jD&^ zLGP;9l0d%)_(_TcJ4I<4-B_DKwoDc03EqprW7<(P)#@;P_db+#ebB(tydgAiz>h}o z`@CUN7PKz&DjYiIk5`8);Ht!VRHb_f2{`SG70C>2vZ4$-nLeW#KSkmGBmQ%jRADPE z_L1**UC9Hf2%38B0sP2HrQf!N{@48ci&Jpp|H%A&!T+A0|Kqp(J}4IgQ%)OvXANZ?HYFDY)X~1GrVIz#Dt*Oss`h&?irG!RzcS{O46ZpU3l)W!M~k z4)Plj+wcSmJ7b~jlL*I0sh7!)U4ed_38ksWPtYg)%)=|Ky)?kq9vtV^vWu^d_$92orp;EG2XtYu? zR-NR8<44`_rp9*Azq1ktH+NDH;gG60516wr9=#haAq%6bX86X7 zpGVOt7hlmo=AxVp^BlMqkvr+Ajw)8u9;G|yw9vJSt4U_845u`N|IYFgzq&UI(`O>&{2j_(#^I=m z(GWPegcfWICiB|wVc(deSiX(lspw%z(UpVPt{{r$_6yVRD^kJeo+lRYj;FPL8nok= zJDxY|0Bs;4v?NspKlic1``sL|oS-M(W?4!OHnh-zsZ$6yMVO|Z`$N+_hoJpW7c8*q zg0j#)GBMYl{!Hqok~%{)?~FYDt&vGjF^z29&;78wKNAn_5~7mRcTgdV$!y`Lco5|G z84a%}Q>T_&c=6unY@z%|uyou{pGJG*30?@le0T<*s$D~8&gXYh^j+|++*rD^BOl%1 z^CXU~Jp<4BG_d8kBz}JUF*^2HiFd61Dya0Hqgj{vSF%gr<8!|)(Hki}yj-P#Ouy4d zU3AZ($%gZ3&O|ZJ+T}~=ngjLpeOMHJ!_Vpq2w$a66^X1%of$E`Cyb{lnxmcRO33+6 zB)(PT2FI!fNlMpI+W4V>{9Lbw=eRty;#G{R0?;)^SiU2x>DLP!m) z2LB8(swJ#PYmcfz5T7euq&I_pHx{dz)h@x=@?balvEd3`Xc!OLyeas%STvQYc433( z<)Gx}mh|Z|70!tq0g@vTg=h7S5z*!R`C4-UF1VG511kDiTU%>9hCJPW(jHS!|+tiXo)KIuPv=4uBNj)oAONLb}?(08h&(gU^sD-w?4k7S|1)CvFX=;Bs^T z?$LP*uM($IHgywDrv>~q$MC$Rld-y;0e+=sPjmQPafSZ*NYfkWk>BkkLYM=BTeidB zCDV!dbs>(uLID29Z~4b>`A^`t82@sBX~!m@sJJ)uz~uY1S4of>Ki@{e$^>|SHq2&^ z?)-)W+eb;dt|WfivW<)^XvJ@CyugAd*MjOMQ+%XBkX|)i%Vr5UVjnRM$x_?{B7f5O zcbd*PPd67_cqUk2Gzg0My)XY!z{lj~QN(1^0S zf}Dl?EY%k8Q~0QJ2eurHg6sM%r>B0)C@^mNqsVd7`ub+(r zQVJ-&D~p5f1khR0E1+X*99w!qi#+{mLOVMG@zLpK_?plXFqQD8!c(V0%H}9)-qeEo z=JUIl^F~Os-5sK_HH(hJRQ!JI2Fc_4VZ{8PN&H@fW$6*na{oXp1JZHR;aybkjvH3B z?Iixcwqa9dCb<;917Fv3CkraVao0N|UJ!>rtIvH1mx|i4qo)iG{SZXIA5KJ$2W{!b z!-Z5>aTHY^6yUTjHpjO!`tf4ctVzBMyT>AIUCF~+%ig_0vqF;Hv5aCu$ ztqbefnv|35=UhJPdo%}6atnmUx|dWKTEI3hj1S&e)TkiZZVn3&tLq7%yELp?G@SKKv9maRwbUWA_{wE_~D-GsU+d} zYb=mB1Am(kNi{_f68)lqi|5wE#W4|@6{<*i7saXR#93G`8mR00lUP4@K3@JQ0+;CR z2G7`AG-!(<$$ct{yIh4h5^M7DiG*Y#DS48VtZv5H{QbZnKfC^W!$R$301YzH1}xOaBZ_;turD>KHiZ{v7Uc zR6)+*4Am?6Lb8t{ol~UW`F^B*dxLRgeFd?8k;rwI z-bV#IM6l*PE6k^wuun7+adUSN{_^uFg2CmaVU`%Sc`ShC18U&r`TbbSI}vNzuECrs zZ>pzG|4Nf#Is2}kY`DY!79KlJ z0~YTgQe!nB`qPpItWU(8pc1^h|12({@r2tW$mva0h1^%)D7vr|J;TRn_|DtZ-$)bL zYrdd0T|04FcOw3@@Fj%YVM&(NU2rd)j&%%omn>D?x{IEBx<**^f1<4SxN zO*z-Vs0G{yf=OPK9@sc6qyC4bX;=Cwil=>tgN;FW!(b=P61z(W#k=XEv{rKByDf&a z3poD5FkWpWPeUV`@zb~b3{8j{_{JCD$B8-A`_)^lXnF461zp%B4D7ctxBl2VKGNZW;EL)TXk>FG2Z@F%-M(EHtkCPAjau z=;+*9`m09}RvwGz--90j&4se~_oX7*qS_CRDiPpVIsiBPtw>y)8s{^&nl0dH(HMn3 z6189j6#89e8=Cmt=0DPS#q|uxjx;6(S@y(1H3eVg=W1=Tlc=G+EbUubi|5~ag_OTv z!E61-X^28C4zsyV3U_DXu$?NT>6;36QRFjsthA_ks4=d8c!wR7Pr|m&@^pP<7@es6 zj_iwghTnE?z`}onX+d!%wGzri>kl@=scOPlO&`UB3V(5bY8TWEl|pdxE}S@2L_cO| zVXq%EaYViko%U)j{qb-jj(;bQ*6$dg^HLAcnd&24rG~vUZMPtm{q>m2EB*!JOaq*> zcoFOjEn;$+N2n$58cFlNN#A*2CR^^_w$KkurI)se(I=_WoWY}tG*&MK7rYba?~_&% z_||Bd`T93e;l$DK%vmt#sz(gGtkCS)t#I1Y2R6Uzqki{7=+Kcg=sNHNtXNIraQ!zK z*Q$pXlTOkYe*fQOYz%)97^2G+3h@fD`FOK?4{48Nai>Bj8_^Jh`(1b7x`pM$t0D`J z+zkPLlS@=4Y&jmi6U3f=?T3xd7(;eeI92#1j(Y`4Nd0Xwa(j;-sfY>3KBfY=&?OGb zwcesx4)dtoiBK9OEJSlhba-O>-?MeHHCT%I0pspsJi~+2;kCIjo%44eco!8>j{R}^ ztvH%WbUk3IdK0l^0K&_v#W>OFO=N+cD}|vL_{`_zd@A$6{rtU$xJW8Kb(D}yey;!i z*I}Y?`V;iG>fvthF>d{`Ch#~CO`j_+qXkKy@bp44`mj6#*SuYTPqa(pc>~*Uc6%m` zy6lYw6qPs?U$f|mo!Rhq!)iQDsskrHC3Mnkgr6U{iYG}{vaJ)Y(HPgSh>2sXTaJkGjYI0FA}1A6no}9r|jvM#6c#4 zT6%rQs(CiFp1h+I=JPY7CwGE$(*XJQOoXQFZzhrlcOt7>X_R$c3Kwc?;Dm@PXKKGN zPxZoeypTq)s}q~RHvScDO-Q2RvV+8`$CuuidWP=)6h>|ftpewi?R3+QO02F^57!D? zV3DRdhzczsx#j)jtC}C<;^2=DM_S|E{JV;VE%8_*EQ>@e;o@VpwIt^2WJYpk92WY@ zzt7jrCo^rM;AF~n%>KxOD=rE6v9<%<(Z3CU;j^M_O>hw92~T(f`Q&-28vf&kvsYZ+~l(Z<@>hh`)5<|KDGlrt{zaR(B0=>HqusR{u{1 z-T$)J&gScD{||re*+))FeY;uHTf>+=P+|Hu0O^8f$x|Nja3|8s@ntU4h$ zzbvrEe{-AVUh@v*6|4pspGVP=jTUeu;WwJ4t^pu*9$oCa$7b4GKqCw6k+s1|b~fKx zDg2x7u^Qm}B)TI}uvZ`V=99fl&=|QfpDMtcoYSpO?dfL;6JR&{W{vI>o)UF$q~LwFMueOf(pFoICJo zBP1{Qhzi4hGxs!Ebd$53xv5*h)qfyDe6_}rU*SdeLu>@9u)oahIwA(Xdtag!_gSRH z&XTNY@ny|w&a%NhI^gR0p8aw;iiNe4q3iSq^nsj1ZwxBg2uD75jn8bKt>T9sH9LS> z%|bGNnh>iUGgKom^L%v-zrU7w$Q|7uWngpdL!|w(1Zj9(MP}Nw(ao2k%#8Lvl;Qgr z=?ttz62n`NK-}*tryuj-*r&~Gu!92lj<;~<*C>#-W4F;^e#h^efg1cd`xDL7o`?eM zJCK0K1UPZ@L~Fp_4yAd^8iIR8|QIHA1+5TB$AQ;%)J)v z{!;L3KoNO_%z)J=H-pA+0ocV+gcu(_|8dL@2^-GrslybEs`;m7ll zr|_j3#yt&Lj`W}(dU6oQK4NBfW};M&4(5S%J_>3dM5`ocL8E3AO4yf<+>39c*(zm7 zf)#?O&(X+hlQ_gZEn`kf+(B|CuTfKX3F>%Xg@!FFk%OoSoEg54n#3}Yp`{tKO)&wj zm3of?>T}RqvyDjVRxsaHcn6K=@sPi*j78rX;5*SQVB=;6%?o+V2#H-q+h>X}sYL_a z`E$3jc`rnn#$-JZdw2&qjLD+AeX|fp))bxoqlUid&V$-zr_iulDYM8ik(+s57N#D* zg+5G?g~1Lt1AN;>z_O-pCwqn+mfuol%=bC(`?J6=^5@LV6}LaJycEm6xtW2vMXx>j86Q>RRSF z*^FwQ2ctJ-wW#xfH4@nQ-QwffQY7Lb1pSMCpv<;XQdh$nwi*Kx-Asae08E zqleMd7wM?^%{6paE(i$^KS2xb%mUTNV~n$l8Tu~r7#UfP@pC@QP|D~`)YJbPWv5o5 zT{C^ry*Z1Jui;!&7quQq4*o&ZdNz!Ut$@sV-bmq8Au2hg0AD6)!4&;^RGl>$J?-V7 zVi#F>k`a#lr|DY!%m4q&|NkfC|AX($g;8aKo~DPQu@@4=HdhID_QWBXm@V*D7okGW zVy@FQSLX8{4vdV4LUH0I)_p-Tn*DbX94PZ50ymFTXX<#>=tx{2b5X~mn1n|3J8*9ooyR-eL6!|j5f=TzPOvS6Fjg(g_vXGP|*Xvo+MG6cHV z?4_baf5?Z=6-{F5oDJbXrzvY6+=l#(PB0pcW2|$}au5~FL`62A*p#sW_G9>CHf_fR zws|q%?^1q;5lYe^d%yIcHFI^@tQb+GV8}qgQ*|^isgavvH^i=&U5~mfT-i!<9S{*1 zV@7;UQQ?c<$a1KCRj$<1<4uC!et%d>lBq38aZE219A%|3jdYiOIsgdP}277}~W z_D)%-;e15nTyvzCJ{M+BC_{T1HY48wV-zD{#nv7F%|!f-V(8IDsI0IK`P>4;%iLZw z%y*=zeO6|l`b`GeXBN`M^I*AF(alq5uzu3wmj@*2IhIInJF?ekoMHXnn!&1 zPFUnU)=D}b8R-ava>Q%2==mQswIl|;P3}Ovy&o)Op7kO1;tKxy8&P!FR!~y^%gpH4 zhood%bTd1GS!EZBy7!P8p$|RiIwYadHGybY)F0Kpm;|AvMab!;EAvs}7INcIUCWW1DVLQL5a)rI&0Y{o{qGgwSftW zKsRYO!URK&EuXP_Cl@`B)`$ALS^RHY2HIeM1*zQ_L<^^zpv+4~jOWp1%o}%o;0EV1 z*JZDvAN)C>Z|`?z`37Yu$&rV3k+;y@SP|gZi=%yQmr%qtPo#W*0>~M!g>)qW^a6^J z+8t}8x>gQKwkksWq^YP>qYeGFJ;!cc9EGBi8qm;TzK_`Od9~RS5l|Z_M>#wBdF2b` zsPKar2vk%ub@m}l%zzAB5qN;27PH86paqSce#Y$&--4EOxG~@6B%lpDRN24BY>;=V zE^riVY6=~t`0kntI@{)@i?YoT>bNNY*GspfS&z;jCTt?;x4l8(r=|Y+Th0FE|Nj&6 z|4Ygg(53KTC^fc#`RBw~kHxYm_ZHv5?I?u4tPErYjaZf|`HGwOi85Bnv8-m^0=D-B zK}UsZQR4b`cC&pu>iS^Mj;)Y`770gwF6$Qi{xxDbrSrIXr}s={tw%0tT06JUh*m^~-GnO(wLfdY4=BhQ8grsA3$NaN+`UMXd(iX++42w^xP zF^4JsIGb1x9U+>q3W_IAhMt)9$RYI3>BAabdDu3|j{jVib0>NSl4X0HLG)Q3sy6(J_7ydu@y?YfCu2IwnRyQd zD=&tIzy7H6ZxS0iorjutixAs-DF}F3j8^sP!Ri(LNZo9VJ*%9-?Bc#*ql?BFtHgZv z*PaVpo>7&BQ14eXF#jn^{-ug45?3O@B$nmrKeXI!nt(iPve1+(I%Ipsc}7J1CwHdV zDOS&cg8E&LQHfqXv%35T$|#6LE*%p=WA1v`P;mo|-Q~OC%T1Ahx_8Y=VFrC^Y(nY9 zw~)4w5InjdjpY1ZG24$hGfLVXD5N_C@$&&p{$yozra%t%IZkGl79T__oI6l=iYUs@ z5P`fUy$t$LisaprP?|m8N3(JXI%cen^gg>4+Dz~% z&qdSHTTzd>0b3WG$^4daLOym?$n>oTcbd~iqEr~kU!yM~CCjC# zDpCTD%rb(z#v<(6?_Q|m`ZT09*pCbs&Vvib5EQ>|Kg#QIf)8yxBo+S*&21S# z=^IqgpLln^uegR;`28z-?6D5LpnQ+i#&*RD1xLYFI!QM~byP@g40m$*eKDCJ#yDY-_BH`Fp1m$jwExBB`pW#Elq5et0}y*&t$JeE&!=EUx?s4 zI1Gw(kYh|3dA2MNCA^SD&tW%=m{a0yVFSzZg<#C4o16d37aq^thWxUkN$s{bXba02408Q1yH@+e ztHa?LTI@IV4w4n83zKKn5jiCugj)md`RGeLwi~yz8!LaB^qO z+6TW;XS)F1y!bsrN6T35&^t!pgQjJ(oizEox|%Ez&P9P{m&lTa9MU$#_qQgr64XwJ z!%h?A-}49g^JfF8Ogq>2vp7);i${KI-y+xWNgz2R66WpgW?J;>(1~kT(W1gFXe9qJ zITGrG{(Sai&t04fsqw{Jw>>fF^si?u=k_%VkD4ZOsz4l_x9dVr_7}2pZ}TnR#a0nL z!&oGDq?vIwsb{m)ZP>$Rf@JK*AMWaE9$Oh*hZLKl(9Mk=;QQ1D8h4)Net9#7OjGtU zcb!(C$owVjiR;giY3xa?2Zj1YSLecQjX+-nP zL=tWqO*lV~)wsqPLf`=lSW+bfPAQ)m&#V5-xwi?3G_7PKV~vQ!Ll4w?cpH&_0pR~@ zFP{|t1KA}O@>%RB$%4qs=x#y@5%76yZlsa}Z+0b+S!^5`zm-q$tyw5=#W}KZml&M5 zwh*do)WPQgmxNSU!=A+!ka6HWD`dC}-Zrj94U1Gz;nfCatzHAE94;f0KC2=7xhrU> zOThdgN%l$aAlrA<4vjXQV|C?3vQqb+aimW&t&K}wD0+F@}r0vok5J;Q} zPDkw-aTP@Z$9&L~i9KB7JRK-%c!7>x2_cR58_~LsJW{ZICsKL*6@BZc>^R@AuAJ4z z?mVnPvJTEBPHSx7?)nd??&d`F=1OP zBQbZJMB=={7+dj|XvrE!Q1MrXJ=a#UHP@Zs^yf=P<1Jo2iDWsEUgVQwKFI8FBibplD0^4|jc*HPf7n^El@nD!C`FNU`|h#~)XzZ4 zlkSqt?}E%UxJnE&+(BXKbdu4(6^`&-q5dxgE#D4DA^YisJhQxq3 zC0k}fcFYfE>&Q*wTpmI4qZ#&ytUa@(@H%t2+Za5Tx?5iG)<@f~jUo3r*O9T=3+{?v zW@uW-ZE{@aF`N4G4OdcQ7HMnH2H}uIR?c8LD2;yO3J9NMCLjMq+Jnqc>EVs2ZJ{ZM zIIaNk`bEqSt|7^pzZ~)|HlYV?FVG{q4sK`9ZSwJC59vhj$k_BSw0WL1Y#A&zDe?vc(4kk3sL}Zba+mwf z?)OP&wOU3<{*-1UTUg1WZ5Ln_ClmFp%4c-i&#`W2&oe>WHlpy%B&4>M2w${gX(l6_>rVV#JAl+ z9_7l!Qac_Qe|&>t7hFcNj#toJ#WU=B{@tiWxF89;eFb^T2(v-^Q&?x6QP#6c5<=T; zks&`{!AWNzQh6p@SbY|}qUN(n-r^``$Ol4$i%5-H0{eqO#O7BE+qGE{+Ek8$USt@z z#a@SKN&H32S7EgCx-c~H&+0-px*)g16ih0k$ei0&Ag~r{KITn>BQ6mL9}xnvD`(hD zr^o1)>5wJI-YrJMN*}pNcBRP6)(tJxyUD(7KShF9myqDo^~~18 zN+2URfo+}WZ|V5EnCzQY#SOoCnRqjsNTm2GGC|cAFz+e4TAxX-6zISaX%RBPuaD)$ zMZnu)dD0~F0A((ag?%s6n8oYF*h!=Mu&PN6O6I;}3O6{Csr7rp+wgJ^sI^+WK2aB(S+*)G*X&VdRaDppoJ9`R}x+n90Z_D?@ zE=N$X@Snf+pTG4#!ry9o;tZ-UWnLV@tAKC^NQm0{r^@p6Sk4 zA!=JB;Ff3zo1HMo#7OQ#Q$#diYmy>JOd4Q0^JU0|Vr?LbvEa$)j(w7Sg(jmY(lPrQ z3f_MThGlHXSH5TU?CmQ|tZWZ*Ymx%HWeRY-&jjr=zG1mxQ2_Eed>$fC$|4WJTWD!% zJ=?`KV>zR2jlj#{>b%wU#DAG2cj$I7`#LIvWJ+8nrQWHeP^6P9Z+0K;5P!f*-s67{ z`1^R_;#x8{Igwnya))(!)5ETl3}-?_JdpOn#bo80T`-|KfX#!mB=YxK=sjqQTIQL; zqgPTS+$NAcaFIir!Y?7!ub)tu;3Lv-xDq)nS;|d%C5`gXG1#;+i1b*T2l444q(zEh z$86-$iXb&~`QZn$?@=PNZo)Mt-(7-qshNO%QXt!SDx2|e_9L4H{K=Pr4Af$Hk;vUG zL&xuRAoSIqi4i)odQ2X%6fD?@@{{P437>a_)|p07ilYzz$;>|kY2y9TuN;>Dd=I+&-I$|RkeyxR^t=tAJoh*F0s1jm_=fY$*05;F^ zhlAJS=+Vq%O<{Ezx#4Jpbx{AiiZDu3OwsRCq6UZbq>84 zP7gQwQ?20^sg6(*eN7$rIr(mOZr(;3_`^IXym6Dv9$!q$)_Q|a#lGo${`>>CtcL$) z|6Ja?e~$iE^tYnF75%N~Z$*DA`diW8ivCvgx1zrl{jKP4MSm;$ThZT&{#Nw2qQ4dW zt>|w>f9wCl-}-Ode;(7je~$SrnBOuG@>|f~ivCvgx1zrl{jKP4wHV%P2-IVK%RtC) zL4PayThZT&{#Nw2qQCXWvGvr3`7I3tA-@Ivt>|w>e=GW1(ck)_{?>nE|6I|#e~$Sr znBOuG@>|f~ivCvgx1zrl{jKP4oilGwoHypT42b*|^tYnF75%N~Z$*DA`dh!t4ro|{ z`7Hw?zXkoR=x;@TEBaf}-}<8dR%!d^Ym8Q#$ztx^KbH^VlViUBByFJ^!!Uf;CJyFA zrXC*Lp{Qw@cnQhtQ~%PZ+Kg6{KEsrww{vHqGZ|*h&^!xbU-Hjdq&g->u9hoeqgC?A z=t#LLS{)O~{j652l?oMq6s=OKVpV*Vhp%0{7sl1ZcWvSzzS!P6bM8!hr+o;+@O=jV z3qf09OEp_`HoGw^%V^cvt$LGL$DMwP-o9d`F56^Zsk3C+O_@DqS&eB{qm6%`Blx1H zW~FC{dG%fG>j!ZRzP}}4Ww{)1t(*tH+Y>_D>t>DY>?2F>U;^$d;a5DN$|GX>u{m_rxaX zpVk&ExwaTFtu~-we*d)5D#lwz8z0bEFH z%au&y=5OJA?dtm$(iY@Nrm4j=SG#ut_G%qFh6aY9^G zrF2}M>BM!&32|+8(s6y}6W1{(#Em{L9oJ_&i`ttGX`*_c!YY z!4k@K958+lq~Q8r9BLEY-NT<#f(B3N<&AoN)bn_N`!~e5N8CB^je8xpreK(wr~LP4 Ui!5#85Ps>wjN-od=RbSypMgKTs{jB1 literal 0 HcmV?d00001 diff --git a/fme/core/distributed/model_torch_distributed.py b/fme/core/distributed/model_torch_distributed.py index 731cdf2dc..e2695c2c1 100644 --- a/fme/core/distributed/model_torch_distributed.py +++ b/fme/core/distributed/model_torch_distributed.py @@ -25,6 +25,7 @@ import torch.distributed import torch.nn as nn import torch_harmonics.distributed as thd +from torch.amp import custom_bwd, custom_fwd from torch.nn import SyncBatchNorm from torch.nn.parallel import DistributedDataParallel @@ -42,6 +43,35 @@ T = TypeVar("T") +class _AutogradAllReduce(torch.autograd.Function): + """Autograd-aware all-reduce (sum) for spatial parallelism. + Forward: all-reduce (sum) the input across the given process group. + Backward: identity — gradients pass through without communication. + This makes ``spatial_reduce_sum`` differentiable so that gradients + flow correctly through the loss computation path:: + AreaWeightedMSELoss → area_weighted_mean → weighted_mean + → spatial_reduce_sum (uses this function) + Without this, the raw ``torch.distributed.all_reduce`` would break + the autograd graph because it is an in-place, non-differentiable op. + """ + + @staticmethod + @custom_fwd(device_type="cuda") + def forward( + ctx, + input: torch.Tensor, + group: torch.distributed.ProcessGroup, + ) -> torch.Tensor: + output = input.clone() + torch.distributed.all_reduce(output, group=group) + return output + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, grad_output: torch.Tensor): + return grad_output.clone(), None + + class ModelTorchDistributed(DistributedBackend): """Distributed backend with spatial model parallelism. @@ -307,23 +337,65 @@ def _device_ids(self) -> list[int] | None: def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module: """Wrap with DDP over the **data** process group. - For now, we assume spatial communication is expected to be handled - inside the model layers themselves. If we need to change course, we - can revisit... + Spatial model parallelism is handled by: + - Forward: communication inside model layers (distributed SHT/iSHT) + - Backward: gradient hooks registered here that all-reduce across + spatial ranks, so every rank sees the global-mean gradient. + + ``broadcast_buffers=False`` is required because the SHT/iSHT layers + store precomputed Legendre polynomial buffers. DDP's default + buffer broadcast modifies these in-place between forward calls, + which breaks autograd's tensor-version tracking. """ if any(p.requires_grad for p in module.parameters()): if using_gpu(): output_device = [self._device_id] else: output_device = None - return DistributedDataParallel( + wrapped = DistributedDataParallel( SyncBatchNorm.convert_sync_batchnorm(module), device_ids=self._device_ids, output_device=output_device, process_group=self._data_group, + broadcast_buffers=False, ) + self._register_spatial_grad_hooks(wrapped) + return wrapped return DummyWrapper(module) + def _register_spatial_grad_hooks(self, module: torch.nn.Module) -> None: + """All-reduce gradients across spatial ranks after each backward. + + Each spatial rank only sees its local slice of the input, so its + gradient is a partial sum. This hook sums those partials so + that every rank applies the same weight update. + + The hook fires via ``register_hook`` on each parameter, which is + invoked with the per-backward gradient tensor before it is + accumulated into ``.grad`` and before DDP's data-parallel + all-reduce. The two reductions commute (orthogonal groups), so + ordering does not matter. + """ + if self._h_size <= 1 and self._w_size <= 1: + return + spatial_group = self._spatial_group + + def _hook(grad: torch.Tensor) -> torch.Tensor: + if grad is None: + return grad + + reduced = grad.contiguous().clone() + torch.distributed.all_reduce(reduced, group=spatial_group) + + # If we want mean gradient instead of sum, we want: + # reduced /= (self._h_size * self._w_size) + + return reduced + + for p in module.parameters(): + if p.requires_grad: + p.register_hook(_hook) + def barrier(self): """Global barrier across all ranks.""" logger.debug("Barrier on rank %d", self._rank) @@ -331,7 +403,7 @@ def barrier(self): def spatial_reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor: if self._h_size > 1 or self._w_size > 1: - torch.distributed.all_reduce(tensor, group=self._spatial_group) + return _AutogradAllReduce.apply(tensor, self._spatial_group) return tensor def weighted_mean( From 2f85d7a5f32c022dc744e6aa1d2cce66f2b6a57a Mon Sep 17 00:00:00 2001 From: peterdschwartz Date: Tue, 10 Mar 2026 16:20:26 -0400 Subject: [PATCH 3/4] Implement an autograd aware all_reduce op that doesn't double count in backwards step Add test that verifies consistency between NonDistribute and TorchModelDistributed for loss and gradient calculation using simple SHT/iSHT transforms --- .../distributed/model_torch_distributed.py | 20 +- .../parallel_tests/test_backward_step.py | 211 +++++++++--------- .../distributed/parallel_tests/test_step.py | 147 ++++++++++++ .../testdata/backward_step_baseline.pt | Bin 0 -> 18461 bytes scripts/testing/test_spatial.sh | 27 ++- 5 files changed, 287 insertions(+), 118 deletions(-) create mode 100644 fme/core/distributed/parallel_tests/testdata/backward_step_baseline.pt diff --git a/fme/core/distributed/model_torch_distributed.py b/fme/core/distributed/model_torch_distributed.py index 7f84e96b4..2dfe65371 100644 --- a/fme/core/distributed/model_torch_distributed.py +++ b/fme/core/distributed/model_torch_distributed.py @@ -23,7 +23,7 @@ import torch import torch.distributed -import torch.distributed.nn.functional as dist_nn_f +import torch.distributed as pt_dist import torch.nn as nn import torch_harmonics.distributed as thd from torch.nn import SyncBatchNorm @@ -43,6 +43,21 @@ T = TypeVar("T") +class SpatialReplicatedSum(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, group): + ctx.group = group + y = x.clone() + pt_dist.all_reduce(y, op=pt_dist.ReduceOp.SUM, group=group) + return y + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + # Important: do NOT all-reduce again here. + # The forward result is a replicated view of one logical global value. + return grad_output, None + + class ModelTorchDistributed(DistributedBackend): """Distributed backend with spatial model parallelism. @@ -332,7 +347,7 @@ def barrier(self): def spatial_reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor: if self._h_size > 1 or self._w_size > 1: - return dist_nn_f.all_reduce(tensor, group=self._spatial_group) + return SpatialReplicatedSum.apply(tensor, self._spatial_group) return tensor def weighted_mean( @@ -342,6 +357,7 @@ def weighted_mean( dim: tuple[int, ...], keepdim: bool = False, ) -> torch.Tensor: + from fme.core.metrics import weighted_sum local_weighted_sum = weighted_sum(data, weights, dim=dim, keepdim=keepdim) diff --git a/fme/core/distributed/parallel_tests/test_backward_step.py b/fme/core/distributed/parallel_tests/test_backward_step.py index dd8467128..66beeff55 100644 --- a/fme/core/distributed/parallel_tests/test_backward_step.py +++ b/fme/core/distributed/parallel_tests/test_backward_step.py @@ -1,140 +1,135 @@ +import pathlib + import numpy as np import pytest import torch -from torch import nn import fme from fme.core.distributed.distributed import Distributed -from fme.core.distributed.model_torch_distributed import ModelTorchDistributed from fme.core.gridded_ops import LatLonOperations -from fme.core.optimization import OptimizationConfig from fme.core.typing_ import TensorDict +DATA_DIR = pathlib.Path(__file__).parent / "testdata" +BASELINE_FILE = DATA_DIR / "backward_step_baseline.pt" + -class TinyConvNet(nn.Module): +def _run_forward_backward( + img_shape: tuple[int, int], +) -> tuple[torch.Tensor, torch.Tensor]: """ - Very small conv net that operates on [batch, channels, nlat, nlon]. - This is just to ensure gradients propagate through a nontrivial model. + Run a single forward + backward step and return: + - loss on this rank + - global gradient """ + dist = Distributed.get_instance() - def __init__(self, n_channels: int = 2): - super().__init__() - self.conv1 = nn.Conv2d(n_channels, 4, kernel_size=3, padding=1) - self.act = nn.GELU() - self.conv2 = nn.Conv2d(4, n_channels, kernel_size=3, padding=1) + batch_size = 4 + n_channels = 2 + nlat, nlon = img_shape - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.conv2(self.act(self.conv1(x))) + # Build 2D weights with correct spatial shape + lat = torch.linspace(-np.pi / 2, np.pi / 2, nlat, device=fme.get_device()) + area_weights_lat = torch.cos(lat).clamp_min(1e-3) # (nlat,) + area_weights_global = area_weights_lat.unsqueeze(-1).repeat(1, nlon) # (nlat, nlon) + global_ops = LatLonOperations(area_weights=area_weights_global) -def _build_latlon_ops(img_shape: tuple[int, int]) -> LatLonOperations: - """ - Build LatLonOperations with simple area weights. - In spatial-parallel mode, this will slice per-rank tiles internally. - """ - nlat, nlon = img_shape - # Use cos(lat) weights (approx) just to be realistic; could also use ones. - lat = torch.linspace(-np.pi / 2, np.pi / 2, nlat, device="cpu") - area_weights = torch.cos(lat).clamp_min(1e-3).unsqueeze(-1).expand(nlat, nlon) - return LatLonOperations(area_weights=area_weights) + # Global tensors + torch.manual_seed(0) + x_global = torch.randn( + batch_size, n_channels, nlat, nlon, device=fme.get_device(), requires_grad=True + ) + y_global = torch.randn_like(x_global) + global_inputs: TensorDict = { + "x": x_global, + "y": y_global, + } + local_inputs = dist.scatter_spatial(global_inputs, img_shape=(nlat, nlon)) -def _build_model_and_optimizer( - img_shape: tuple[int, int], -) -> tuple[nn.Module, torch.optim.Optimizer, LatLonOperations]: - """ - Build a DDP-wrapped TinyConvNet under ModelTorchDistributed and - a simple optimizer. Also returns LatLonOperations for computing a global loss. - """ - dist = Distributed.get_instance() - assert isinstance(dist._distributed, ModelTorchDistributed) - - model = TinyConvNet(n_channels=2).to(fme.get_device()) - # Wrap with DDP over the data group only; spatial model parallelism is - # handled by the model/layers and the backend. - wrapped_model = dist._distributed.wrap_module(model) - - # Simple Adam optimizer via OptimizationConfig, to go through the same codepath - # as real training. - opt_config = OptimizationConfig( - optimizer_type="Adam", - lr=1e-3, - enable_automatic_mixed_precision=False, - ) - optimization = opt_config.build( - modules=torch.nn.ModuleList([wrapped_model]), - max_epochs=1, - ) + x_local = local_inputs["x"] + y_local = local_inputs["y"] + x_local.retain_grad() + + sht = global_ops.get_real_sht().to(fme.get_device()) + isht = global_ops.get_real_isht().to(fme.get_device()) + # Forward: x -> sht -> isht -> y_pred + y_hat_local = sht(x_local) + y_pred_local = isht(y_hat_local) + + mse = (y_pred_local - y_local) ** 2 + # Global, area-weighted MSE over spatial dims via LatLonOperations + mse_spatial = global_ops.area_weighted_mean(mse) + loss = mse_spatial.mean() - gridded_ops = _build_latlon_ops(img_shape) - return wrapped_model, optimization, gridded_ops + loss.backward() + # Gather grad_x back to global grid + grad_local = x_local.grad.detach() + grad_global_dict = dist.gather_spatial({"x": grad_local}, img_shape=img_shape) + grad_x_global = grad_global_dict["x"] + + return loss.detach().cpu(), grad_x_global.cpu() @pytest.mark.parametrize("img_shape", [(16, 32)]) @pytest.mark.parallel def test_spatial_parallel_backward_step(img_shape): """ - Test: run forward + backward + optimizer step under + Test: run forward + backward under ModelTorchDistributed with spatial parallelism. Asserts: - - Loss is finite. - - All data-parallel ranks see the same loss. - - Parameter gradients are finite and data-parallel-consistent. + - Loss is same with sp decomp compared with NonDistributed baseline + - Gradient is element-wise same with sp decomp compared with NonDistributed baseline """ dist = Distributed.get_instance() - if not isinstance(dist._distributed, ModelTorchDistributed): - pytest.skip("ModelTorchDistributed backend is required for this test") - torch.manual_seed(0) - model, optimization, gridded_ops = _build_model_and_optimizer(img_shape) - - batch_size = 4 - n_channels = 2 - nlat, nlon = img_shape - - # Global tensors - x_global = torch.randn(batch_size, n_channels, nlat, nlon, device=fme.get_device()) - y_global = torch.randn_like(x_global) - - global_inputs: TensorDict = {"x": x_global, "y": y_global} - local_inputs = dist.scatter_spatial(global_inputs, img_shape=(nlat, nlon)) - - x_local = local_inputs["x"] - y_local = local_inputs["y"] - - # Forward pass + loss - model.train() - optimization.optimizer.zero_grad() - - with optimization.autocast(): - y_pred_local = model(x_local) - - # Compute a global, area-weighted MSE over [batch, channels, lat, lon], - mse = (y_pred_local - y_local) ** 2 - mse_spatial = gridded_ops.area_weighted_mean(mse) - loss = mse_spatial.mean() - - # Backward + optimizer step - optimization.accumulate_loss(loss) - loss_before_step = optimization.get_accumulated_loss().detach().clone() - optimization.step_weights() - - # 1) Loss finite and the same on all data-parallel ranks. - assert torch.isfinite(loss_before_step), "Loss is not finite on this rank" - - # Reduce mean loss across data group and broadcast to root for inspection. - # ModelTorchDistributed.reduce_mean reduces over data group only. - loss_reduced = dist.reduce_mean(loss_before_step.detach().clone()) - if dist.is_root(): - assert torch.isfinite(loss_reduced), "Reduced loss is not finite" - - # 2) Gradients finite and consistent across data-parallel ranks. - # For a DDP-wrapped model, parameters are identical across data group, - # so their gradients should also be identical after backward. - for param in model.parameters(): - if not param.requires_grad: - continue - if param.grad is not None: - assert torch.isfinite(param.grad).all(), "Non-finite gradient detected" + # Run forwards/backwards + loss, grad = _run_forward_backward(img_shape) + + # Only root does I/O + if not dist.is_root(): + return + + if not BASELINE_FILE.exists(): + # Baseline generation mode: expect non-distributed backend here. + # Save loss and grads for later regression. + BASELINE_FILE.parent.mkdir(parents=True, exist_ok=True) + torch.save( + { + "img_shape": img_shape, + "loss": loss, + "grad": grad, + }, + BASELINE_FILE, + ) + return + + # Regression mode: compare against existing baseline. + baseline = torch.load(BASELINE_FILE, map_location="cpu") + assert tuple(baseline["img_shape"]) == tuple(img_shape) + + baseline_loss = baseline["loss"].item() + baseline_grad = baseline["grad"] + + # 1) Loss finite and close to baseline. + assert torch.isfinite(loss), "Loss is not finite on this rank" + + # Compare loss (scalar) with a small relative tolerance + actual_loss = loss.item() + rel_loss = abs(actual_loss - baseline_loss) / max(abs(baseline_loss), 1e-12) + assert rel_loss < 1e-6, ( + f"Loss deviates from baseline: " + f"actual={actual_loss:.8f}, expected={baseline_loss:.8f}, rel_diff={rel_loss:.3e}" + ) + max_rel = ( + ((grad - baseline_grad).abs() / baseline_grad.abs().clamp_min(1e-12)) + .max() + .item() + ) + assert torch.allclose(grad, baseline_grad, rtol=1e-6, atol=1e-7), ( + f"grad_x differs from baseline: " + f"max_abs={(grad - baseline_grad).abs().max().item():.3e}, " + f"max_rel={max_rel:.3e}" + ) diff --git a/fme/core/distributed/parallel_tests/test_step.py b/fme/core/distributed/parallel_tests/test_step.py index 6800d0b83..b175c4499 100644 --- a/fme/core/distributed/parallel_tests/test_step.py +++ b/fme/core/distributed/parallel_tests/test_step.py @@ -30,6 +30,7 @@ from fme.core.distributed.non_distributed import DummyWrapper from fme.core.labels import BatchLabels from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig +from fme.core.optimization import Optimization, OptimizationConfig, SchedulerConfig from fme.core.registry import ModuleSelector from fme.core.step.args import StepArgs from fme.core.step.multi_call import MultiCallConfig, MultiCallStepConfig @@ -467,3 +468,149 @@ def test_step_regression( output = dist.gather_spatial(output, img_shape) cache_step_output(output, DATA_DIR / f"{case_name}_output.pt") + +def _run_step_optimization_backward( + img_shape: tuple[int, int], + n_samples: int, +) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """ + Run a single forward + backward through a Step using Optimization, + and return: + - scalar loss on this rank + - gradients for all parameters (CPU tensors) + """ + device = fme.get_device() + dist = Distributed.get_instance() + + selector = get_single_module_noise_conditioned_selector(None) + step = get_step(selector, img_shape) + + modules = nn.ModuleList(step.modules) + opt_config = OptimizationConfig( + optimizer_type="Adam", + lr=1e-3, + scheduler=SchedulerConfig(), + enable_automatic_mixed_precision=False, + ) + optimization = opt_config.build(modules, max_epochs=1) + optimization.set_mode(modules) + + # Global inputs; Distributed backend decides what scatter_spatial does + input_data: TensorDict = get_tensor_dict(step.input_names, img_shape, n_samples) + next_input_data: TensorDict = get_tensor_dict( + step.next_step_input_names, img_shape, n_samples + ) + + # Use the same scatter pattern as the step tests; in serial this is a no-op + input_data = dist.scatter_spatial(input_data, img_shape) + next_input_data = dist.scatter_spatial(next_input_data, img_shape) + + # Forward + out = step.step( + args=StepArgs( + input=input_data, + next_step_input_data=next_input_data, + labels=None, + ), + wrapper=lambda x: x, + ) + + # Use the real training loss from the step output + if isinstance(out, dict) and "loss" in out: + loss = out["loss"] + else: + raise RuntimeError( + "Step output does not contain 'loss'; " + "wire this test to the real training loss output." + ) + + # Route loss through Optimization, but only run backward (no step/zero yet) + optimization.accumulate_loss(loss) + total_loss = optimization.get_accumulated_loss() + optimization._backward(total_loss) + + # Collect parameter grads + grads: dict[str, torch.Tensor] = {} + for name, p in step.named_parameters(): + if p.grad is not None: + grads[name] = p.grad.detach().cpu().clone() + + return total_loss.detach().cpu(), grads + + +@pytest.mark.parallel +def test_step_optimization_backward_matches_baseline(): + """ + Regression test for Step+Optimization backward under spatial parallelism. + + Uses the same forward/Optimization path in both serial and spatial + backends; serial run is used to create the baseline. Spatial backends + must reproduce the baseline loss and parameter gradients element-wise. + """ + DATA_DIR = pathlib.Path(__file__).parent / "testdata" + BASELINE_FILE = DATA_DIR / "backward_with_opt_baseline.pt" + + dist = Distributed.get_instance() + torch.manual_seed(0) + + img_shape = (20, 40) + n_samples = 2 + + loss, grads = _run_step_optimization_backward(img_shape, n_samples) + + # Only root rank writes/compares baseline + if not dist.is_root(): + return + + if not BASELINE_FILE.exists(): + BASELINE_FILE.parent.mkdir(parents=True, exist_ok=True) + torch.save( + { + "img_shape": img_shape, + "n_samples": n_samples, + "loss": loss, + "grads": grads, + }, + BASELINE_FILE, + ) + raise AssertionError( + f"Baseline created at {BASELINE_FILE}. " + "Re-run the test to perform regression check." + ) + + baseline = torch.load(BASELINE_FILE, map_location="cpu") + assert tuple(baseline["img_shape"]) == tuple(img_shape) + assert baseline["n_samples"] == n_samples + + baseline_loss = baseline["loss"] + baseline_grads: dict[str, torch.Tensor] = baseline["grads"] + + # 1) Loss finite and close to baseline. + assert torch.isfinite(loss), "Loss is not finite on this rank" + + actual_loss = loss.item() + expected_loss = baseline_loss.item() + rel_loss = abs(actual_loss - expected_loss) / max(abs(expected_loss), 1e-12) + assert rel_loss < 1e-6, ( + f"Loss deviates from baseline: " + f"actual={actual_loss:.8e}, expected={expected_loss:.8e}, " + f"rel_diff={rel_loss:.3e}" + ) + + # 2) Parameter gradients match baseline element-wise. + # Require same parameter set as baseline + assert set(grads.keys()) == set( + baseline_grads.keys() + ), "Parameter set changed since baseline generation" + + for name in sorted(grads.keys()): + g = grads[name] + g_ref = baseline_grads[name] + assert g.shape == g_ref.shape, f"Shape mismatch for grad '{name}'" + diff = (g - g_ref).abs() + max_abs = diff.max().item() + max_rel = (diff / g_ref.abs().clamp_min(1e-12)).max().item() + assert torch.allclose(g, g_ref, rtol=1e-6, atol=1e-7), ( + f"Gradient for '{name}' deviates from baseline: " + f"max_abs={max_abs:.3e}, max_rel={max_rel:.3e}" + ) diff --git a/fme/core/distributed/parallel_tests/testdata/backward_step_baseline.pt b/fme/core/distributed/parallel_tests/testdata/backward_step_baseline.pt new file mode 100644 index 0000000000000000000000000000000000000000..b876c651eb6a7f46359e60f978ad8fddc8425657 GIT binary patch literal 18461 zcmbrldo)#F_&&OGPe^i0DoGM0*S+SPE=iJ7l2nr1l|&&)cOv(@E|P?VkQDZsa}tsy zNkURdl1jP}NvMv``JMCE8Q(L0XPh<0JN6iR&pn@KuXnFC-}OB2w6^3I;BZ7lIRB4h z0cSPG)x~Y!VHbaQr-K0=zD}+#2R*#^?)O;X?h@dl>ATO{+;kdK z3VyH89;btzF1{Xq(<~(|mBRh_?53!3ybl~a=*RCCaKPWqQ`6~Cz+UfzqE7xEu7~z| zyE_GV>_2$G-ziW>)K5T7%}#h~{K2V_E_*!u1pn7pM9sYqxCCtZUvG-~3E54X8l&YW zY{x(K>gIdMPsDOsfZz1ZmYk(aquksMczb)e1?)Yr|Dfn5e|Ham4|lV@ZULfxqDyW3 zW(4?&+4zb7pRD%yySV#Fcus+O3VQypNwxiEdWvl1NCx;xPR*ufDPYO>-|qZ&+tn*YX2kr|8>%xT)g-0+3(}AKfv5n z<$sF-SMNUpD-Zj9;+q2hFaBo&N*t~KiGb^o0FMLz<;&c3+5Z;q|NXD^-Ww3$?eQOn zGe+;`2uvp7`3?nE=10M@&AUd{r`%t zZEm{!{|maky}QW>JD8({`SMk9zla!KkXFKW&Unnbd}D(9VOcD56cjMYsu|2dKX1lZ zN|?#V`dsJjg=I<)E6U`9`}NJ}UT*4Vey&6Ic^*(W!oEppW6vM?%jP_N$v*8dXC1?& znElIsGQzXm;|E9rU=y6EMhQ9K>HUUE)UI;oMJP{JkCY zF*^e#LN`JCr`Le8kcb=^iQ|CnO{byKd12mlH$E!<2hqS!6D0HG7e+DQB{X*Z0Qv30VfD2% z5SVcv3kb!L81^dJvppY|v~Gf5&RD}yO#>jtJI5G*d<%|SDnKPRlotugiFb}X-PZq> zC)ugniL>by z5GEyx6y$lZ=F~x~etQ+Hj30&rit(VW;V}_5v!QBL88q<&W(zjFgatSFpoIQulppqh ztdsgelCo=&d*EWqDgH>$XD5*U$9u7S#4Kpxmq`MhE$Hd&P=Zgs#Ze#5L5{u%w{^gP zO58aNz2z;C*iaGEx3C^1Y&Cr0#+_(y-_nkwzs;iM%$W}7FCyVCkRU;X>3y6R97pQT0oD_>c zhbp(~kmK7SGBzPXl=7Z}#zqzT>|qJrR~U)j)?7z--VKs1D_5e=$*m;6`U3j%!-l2= zG{IPNAE@jULd_llG=KdZB6uQ_X2kv@52D6MM`H=?`)i2&U;2=q)-G7&`w1RDQ%Kpf z@~AsCp5BfyAp67rk>l@tNY~LOv}40KO%Ic;kXTqq=e-;wrfoUIv+q8f)ig#Xdpqcw zzrFB4a~F+K5T}l>&Eb~S*NKT;KGEeRkOtosa87F&()A2OHOU`A&^Agkm6O1O?^E{% z^{d3^h6Vay=|D;suBRW;e367!1g#7@fmF6A8%SF@qirprXg-ocUDE`Ky~;eQxp@NJ z8JliUdEK8RFgs~|#6?;>v=CkLo`cM7V_;F;J6Ke89-3@>$nJk}7xsy%pxsLspxL@M zs4hMn*c*3w(KxDQc+IV{z za;ZFouH48(dWYuX`AW5?Kn8ihzCvoJR3~=#9E^N5RLC078 zhC0)l;p3qq=>Nh1Oy~yOpRLQB?T`iOcb9{- zzSTfIkLY?<;gSW*y{F9H zb*DUmeTn7mfihgfv;^)9|JU5GHir59FPU-keaMWif5Y6~MVYAOnv8>16K_>?1$$Lh z9FHr{#R6%9_^(GEOXpqSJ;>B%{!J@q<~hD#YQNlNlADh+eamlhm$|LuX^QZ(qp1to zt=Bc!Kqts%7XIaVT#aDUdb?PMRxvzsNF4V{wXt?bmDw}mVvM2wLuTl>82EW*HV_*a zWx5WgGQYddaie$6U^82i*p9GL_Fq;JE6+x-(wCK36K+@eT%j*qi5VeG-SRRo-kJRs{w`lu%Su?PG&ZO$hQ&fypJO|X;_Jrt!u$v`?45b5eN3vRk1bjA71}r z6&ZhAOAPMt(IyLb(iB_^x2{%4No}X#9{ika@KvTCavzh!oe9KrS`Nu>nMFh#B;epu z1!yhym=USC3AUe@3nR-pq%Y(Rsj59ddS#YFQSMJ5Wu;D>gZ^Q!hKFFbybLrus}AcP zghKsO2<8@-lMP4o;c}T_cvMvsCeOG_6gS-?4<4Gs?j9vH*Df8tTrEj&?B7Hmc12@V zPbb(i?nS9x=g1J*Rhv5O4Cp86h6eB4K$?(Lx>o-7JhJdYZZeo>z6Wt8Tai-Jy_ zhLT%$(-7Vh`r?!x3W^|j_pTz!$NfR18Y50?0rO#ZH(KKehV^Z~sPD`DoE5M-hZ~pF@?8LWnTTf(Cc@AccaZbnwG> zm@ZL3uP~w1=;j>c7&c73tS`{#i!}_S!!^;DbV2HS9il-`29xdAkk`BWaKVyBzL>k>%nnxYG-^2c2Jk+owvSa9Kria`M45VUbOCTpv8~3D%kQ0S73`~sAGu9(_(W9slYQttw zlMfG>)!X~YwBykzxgDX&m$DTnA6=kfTKu%?Wid6+*^K_!nJPmpaqY2KP#of5sjT(^%xt1nK?B zCL)lT0p*%5kb)5*B+bNjrOQV!SEby z+H%8}q;y;W7pmWbxyz403kQn7)y$w8IV15=1 z5jlM*jqEQkf-T}oP;rJD)jU#7O1FnWYtj0tbzc%bnzb7_zUQLTeLA$)F@tI?4kgC5 zcBqs&hRinf!n)=2q5RsJNa0ZcYW{=aN-JCP=|=$!yZs4%`@59#-HxOIZT>{={axIu zJ%iditHS(?a;V35T$kM~May=j7!>{6nQu5Zb+YqIdH zZwpk7Hv+0%`^iX0Cv!Ek5C%V&gJ~}>k$mz@LH zrIZ-`Z!<~V6mIS^`h2a}svq=9gO%kf8Kn|RHfO}I8;JAo$ zcntHRgl-8Hc!+YHqSI{95eF&Ty{1d@kc?ECGg6 z^uhkwE5ZFyRUnuBj(Opv#i+SH<{6J?vGB!vc2?#z{9u~^1|b}r=sv*OatNEY(VQ(R zk>=v!B<4WF7_)s;8mNef0d=V+#$}ojQ?_ymFXx9pyE?Rv-4B1T0sA<3Tk=TBv#~8H)|CvkGsqC zu>K-FT$Cb)O-zom^EwtYk`po@ali*m&y53xKO@0Pbz30s^N|VMYsqa>h+wz$KV_?K zSFtNiHn4Tavh~0JSi>}_g)?3s4B_#Q958E3K1M!&urgDJ4YOO=*pD`#CgcZ5n>2z8 zwuM6TSr_4no(MR&Q5vT1H)N}lCh!j*De}pUV9R%r>^2o`!1c)k*yJ8~@HrB6rd6@{ z%`%b~okPw^JSAy+?vfjyyvT@*K8r0kK-s%j;fm|&kZ)oOQPp)Jk{cJ1274*epp=XU zt<~W7M|p68tt<5OO99qyeV{oh1ePwk4m*FJg^ZvBd~X;7y-Wm=8P5WpU1E<+9;>6* z9(gb^Sr7PLoWUIEGKH$(BJ6Y@2hvhJ!9A z5D;IF-Qc!(90<=4hq0Ch$lAplX+1QC#&^#WZ?ukdzE39cKV)fSz68Cgqe^1s<)PAm z84Qlx#Z{dd4aR=_hEa>Bp*p`V&^hxeD|M<0{&tv&ezvNUGsKiOHdT|ssD5^#UME&g znogYi%E^NT3u*j8BlLdB0yO?_9&j!lhP4mX5tuvx{g-T`KDR1puD>chTbl1x|sz_2D2S~RQzO{@B$Qp;wvNwEjo_VkhI zQqIsxpb2gE3r3H(Q4D1{b5Z!qw28lK4qANG&QCAr*H5rKjBT@Ev zTUnv{)U*i&_COZU;CqCfu3sFeA{R}o-;z|ug!_i;K2Ppi` z8Fbp0V{qq#C!HJ#BkdRY44fsbk$nCcyzb*IG&naOhNT}Nld^i)?VKI`<|0U+ZESa(@}w&Ju3CR3T3}3(BYwA__JC7b?)+}(VIqyo=_YeT;fEV zN{68D4nCl#-bdbEdH`i!*O6LvM03PF;QLz-5HEQ-+PyE9q`nU(l7TX)OmZ1|*1wwi zR#sD6mv{JFK^!VOy%2+fWWwE^4I8beQIP$E#9A9cv;9}$r>ik&?YYA=EASFE+q0Yw zy(+`ZPzI7#`~_zQzbEp&sd#STQChP)pKdNwpx2KlQPIl_sHYf0iIMV%md`~}dlFGg zcqQpxe}OK$>IU12FCejUfF$@=pkq7gk?Ed3Vi7$|BUdXUCteg15nB!21uMv;ZVx?O ze}Ps#8X=;#negbgA3(R}JnG>WGT@8^!O_C~baclOD!u0w`Rm&OMSCUCS|CFIJo`l? zPwWGkZ8rE~@C7)c{QwFxOK5i;OTWJHq$}>)p=||aC@ebz_vJ67buFTF=}%b-&IObE z{i1Xnh`>Y_E*j$nQ6uRq)VkFXy6@Ww3r?kw_SK)Dj+YZsQcnk9X)7VqqKI>xE-q?_ zfs^fJ;IyqG?Y?G0m6VG}|6Y4sk{AmRx(1Ig)k8}I6_A9=MD(h~8A`*$r?yNsYW_b}{8G=~}B zG59;n01Vb01Ho}sz;tFBz})YQdu9StvYTac(kmG2oD)p&#qiR`mznIFP-*P;cLkm- zorMd2?_p!ZqnVBq8^QFN%OE}x1LM!dK>vIQ0IDCEOClRt>$pw0ZgCc_?xa{J`W9CD z8I4^F)o|ldAJ)729=9Vej%hc(#w623=CZC66HF1qpJl_~y*4jd8XA`J7_`{JAH zxj1p-Bews76HjD)BvUg}6nyz?0*Y1K!G2pWz<=8U1XqnRCG*U=?YNUQ%(TYm*M;KY z@L;?(!W47Amb16pK65wz69ZqGDz_sU?kOtJEw(_MPo0>V*W|exVM>%B)-P47KA`QoB)LsEb+zd ziKL+|iTo6aAc@zP60_?!S>fy~9RFRJG&F5v1TAD>=3jH>`0!=iQ7MFjCKbq*OL;^I z2a@N(*_ceNex0<+(HzWdAKP;VztnVm^iT9v~K)ob86hp$BOP6pYv?-?;4PNlD} zhf|YS4#>Ki0?*I83G%r7Xsme=+Tdpm!*$k>C-EY5&6XcTCHNg4ty3nE@$N)(!)GYy zH3%&;?i1kqgZzl~Ba3u{$zjQ_xM!CP5<3(NT9Wi}&Yd*$qWTq->YGpAQ(**-=HQys z!KD0n4$!yXf^Z zh2#g_gzbgJ2){cI2BqIao=XSdmF@S4(!N5lsBZ#(xpRiT$V!2G4o=lU^B>@PgFXnN zT!G?=D0ss80!hl)K$i)B18X$iK{Jb^bVGmuUA`cLNRM8pc`Lq>@UwMj!`dBa(|vFH zwmk*eHQz#iEroh1bO|-&cY=TtS2|jA6Gfya zq4CphWHk6WDs*i}IXN-JL_@+L-)1f9cV9p!jBg_qaR@(@G1PL0GUb1Lo|bJ2r`8gI z^eK*^vAKt7ZpSfXb87`caD{4^qAxd1$@=PHOei7}+oFqpnuU z=(m>HQ! zP1x2T1)9gQv|bp~&PoyTBW6F!6bnN$r6ddj_S_+|cMa&Wftj>y%0sC!bH|D^PC&2k zE>N>{27O%q5{(4UBAjIrAXxh!&93L8LUc9>Ke2`8FE^&ue0{WFP8=0;iJ~nYJZhoN zP*1A_l5v_TJQ0+IzfO9E?{Sa$NqRTd+lUZ|7;{8`hAoME?v2H-J79MmicSglBEi0Qz*m=foA8uaqX7~0t!4_h;{VXxF(dTjOHDPJoC<*t>bH7n22b}>uhKXD3i_DT=~ z=WuenN1rBoNmBcJ)6wM4K{DppLdwPpp`zMS8b0ZdIMbga+gx*cXW}FEkdUbmY#5@M zEjmzXs{;6_6HWWhe+0pMxoAUl9mzSoj_!Q?kXEI2cJRk%sgr@tmlbPb72Izv#)H?s9w0<`L~L(cdmd);u1v|a;5sPs1Z69I_0>N)uR zyYN(vu8Zt4X2Tt;8%h4#-MB30GKumJp}LnDD*PvbWZwSBs>^Dj`jd0emx~;ldd`@7 z##SwrG5u_m#zwY`7sbs|&t*gnML_2-6`(Jy1lGHMVeF%I znfuu~?DZ#q+1p!nv56+clfAR>^ZA!qf!Br12fqu9x4IP*yDgNvY_goE9k!pn7+S^_ zGH+NZ)$i>5!!KFmf0tNqpU=GX;zq9R&lpBx+AT)1FP<3~$l@A5ux3w={AN`itiXoS zh8Pzs;%MPg)^j|oypBy_@?<#R#xH5W_g4(u>HfkPmew-5pqTm7Sj5DwsRPA_`+!u> z8DLiai<$H>0~}NZTGL*FK;a!Ab)O9e_vMJ>6*Dr@Zb<%E%8`2QD(v{<6Z`UJCzE@? z5!}f_z+~Ef?x>+DKAL(Na|GOauLI7Oa5s0tAd8A&=nCSTZ z!-Fq3@^Dr4L}z? z1Rs3RBt0=9r1J7Al30D2YpoLp$5XCBk1|8}=bQ#ST&WG$x7p%aa*ND&>mcUyFOkai zO9?hlV9Rn!fKu2OhHueT?C>-f`)}Vt9Ah4n*UTMaE^rj4d+$IUbH2bL=tN>aqs$84IH;qb3QqCgcaqSPW3TeY(!CZ9K_B3=>Jk3PwZ$rz@ zE{405%gOr%b+CT2oCthg22M3*(2T2~V(|D~s*x57!zQ?huhIg!nNCBF5i@A>Ms*^( zdKTjS4uxWmuTnm`hswTZ;F_X!$nAI?0+n;Iq;W1Dd-a(amW!uin--F;O{O%=y^u!t zpC|dNW+1($YPbl_M3IkYg27c2Fsbb?+Py{EA0ML|Sks4NmyTet#ZM)fOewqd6y$ z*Be8Fol3W zzZpqE*%$Ci)tA~DZy;s3mFz2iLro@M(SVipQ~erc=tXf6lI$!3ca)zKkK969pdAL) zZMvsC2W47vq@0-eUg8~oFG7n47gcaVN*UiNzr{ahKGj*VkyKvl$EM?tQ1bDIWO~d$ zqPFK7Ii1}`K8u>t=39l(nk+?)AFGktj0;pL*`M*}pN>vT4p93Vd)jQshZ^Q}pg;BR zU=-_!;$LY{`Ig(1k8hSi-=<*@;L8siV>Y7z(K0;QKSH~Xi&uOz5u*WEr$T)BE^573 znG~Zsf(%xL)-I9cAm3)z{PX}!8xd2x_*M>KGd8EgkLv19o;p#7EKxSbc zwM{riJ(HYZtDY!bXz_}C7mI^)>o(I_!V77c`ch|5cpU>5$TwyLP3OWd{xmY6k)}OGkJ^{}Uoklp3fN9)&izKgHOVrd4BOCKuNQg5D z3p}HMK~o{jcz+wWpYA19dyhcwBo{X2N7CZ`iFEK}1$n&u0Xh9W58oCpf|(ksRMGkd zJ=s}G&RYG2Uv$!7)h0fecPba`-J#da*P>jL5HDe448Ty2Cw&C1^$7j zuw(xma<5Q^T_Gv@p&auQLsa>lseaBYkWt&yFsOW!*p~TlhMJ z^`73vn@oSi{r2ZN^X2;t@Yg{b@V|qgEMEz%oNQ-`JS3U+R2wzwTzl9$Vw_)rUriF$nZ7k`i#WEc-xU*{>Nj1MjI;Yi``EfBffp3D5j_2~73ALE1O{)O z0~I&LIDQ$IM-6~gx*RfpElcd0Q_)EB3Y71a59LD7&~rkn z2t;1OdvhA;%Y%wY)dC|i=QKQaS%#+XjD!8NRd7%H5z;sI91g!6g|^n;NY?xduzuYN z+BK7};f_3gt%PMah3ydzLk;|;xeFd0Vct%P=GvtZ(|FRFLy zq5}eVK@e{)QV#q=YJCUkCxr#HMs_y()K1B_t>2LQurVD9ol7sK+Mpk$J5lbGempGj zo3!kIfXDx;(;}UB;Klto8auKV=4^6DXO?cJcx({(Sbc{kW--Ji_&e0vGKPB`MJnFT z>m`m^S#+@d7b|vnC%q8aNfjldQSJLs7o5v|-;e_-y3~RO=;0;%g5e z{_~p=xUvNQXlsC;pR-W<6?4=#z8~?zXVWcH+I<9TK^5%Ipv>EzD0duD>*gf%p?VBj z?T&^&rnQrP`$T$;%Y*gm@}%yiC&|oDp?L{Tmlci(I5>{ZUiewMLL~h9Z~- zt+!yIO90I?RiRyzUg&nh1bMsa29XqzH^{LRrYcV?k*~&2LwVq|kB=7X8_bnpj=bd+H5z5?+Z0yO*E_ zix8N5LXi$1IY!3?Dv3z^BJyU*VmfHk1FB#;>PR6lRw)4+x6Od3XD)@MibXKkbvnxV zu?Swf5BJAI6tEmLLbsk1%n?BXTi)8Liu&#gN^dkVHyfV&dE4An(=(Xn(mL9*pY*GX7t2umV2`I^BRrV|Ed^ zrj^v>w~?29+QeDoE<^^=M7PhwKVpD@AMq2OIJbFbEw@j;I_Z3(0jEZ(Mt+j$+aghm-kZ}X5o?hV9 zaRb%_rLiZ1yV+QeS@^&<4J6S$h|=yjH7+BU&k1`+_jtOTgsW`VsTCa~g3HFB^>jqJYPkJ}$=u?uv+ zf*@yG_%CG_{2jIcJ`gMddS}(9y#6*uv8b8ZG$ahB1xkU;gNvE3pDl3K*#SJ9B}b;& z@{zXg?b!6DAqdK81!vn0Z`I96Lvo3LnnD3;TwN`sP?a(<;k@J^`>3$OVbKyec6NYySt5? z|MG$)z8@ylgMbtWmcrQTB`9B!A7!p`fEzsfK;=Dgh$RehWk?IT{^9}gx+ejDPRykq z>9?rRl9SZ8#}yX0O!)v0dr7-*Gx%-u1=4aCalnD<*EkFsI|e0J%7V&IdU|GtW(@Fdd4UoyH0o0Kz1wG6B>7C#jnB3t2o3nnC*0$^^Pv$2O zU+@VQ4PJmlQ~k+}X+cPNqXDT~TMCdv9kkUjqxo9HxcrZJxCXA{Y>sHQKKTijbXNlB)TVc8Rc;wLR%q8wBb@V zI(yMk2^(Jy zvt=4dPx}wpwJ#U7%H4!BeRrdhLw8})ijB0dq6PoDa*pU7y^gp`FQBqB6*S0XCyrE5 zAj;NiXtYKgCViHndioTL{nMwZE4k3;u^SqB;6SVQCDDZ0?a<-aGomI{NJQ?Qgp#U} zL}AHRv`(-ey%auytitBf0-Y1ox@j{T{&g3%xg}lU=CTCUw^gB{wtK|khYczkolc|{ z@S{i?4gtFpl?g9LT1lBxHxFlgcybL2k)1<(IO>$UFd3;HS^*!gwxm68rSR*fR;byy z5v`3EM^CMCq1I#;sNHCUJcT&eR1S94pu*B*i$(xRe#lBm3_oM=C37TEuGTVeY@99PZzb@hnnuAm*+Lm;O zuw?w%GTJqbHh8P|J}zKj))~iI$` znq|nbI-2@-r%}&G4^hrQG<5#F0!1I)4drf5Qpc@h#Bc`(YIwh*g;6Qwb=wo<> zs;{7*tBX+cwwVT}?INgNzY9&<^Aj#`4nZQpNoa$=0>~{0gj-H|l9OB|s&*p*lsNJu zHOEY{IW>>W=W@}IvL{N{zkwt}bI7}yL9}tIuX4fXS=3hA5|P!1;F#4;yzRm*q%yDOyCNx%>!>{B9a@~1gVO(o zq0h?FsIq4htu0?qRrZ{Q-P-qvp|uPRUZ8^lCzhhr2rW2jtdEj;TBt+Tou)Wi(8#V$ zsD6S>+K-fQe-(6& z*oWpShN3B76DW3-f?thRsLNdq72nJR7EQN^+TTA_e^mIVtSjpK{=L+e6SY zro}wZ;U`VEr`DQmKcYBiF*Y@-1nz%0Q0HX_D6y;r$Cq}3{d(U(slnYT{ZtjyxEHgg zbA9nqS0%jY&>Ud5?*}-zR07^SQVGgD#6jj<1IDEKWx4q*cicBuiim}olZ9IrlQ|pW zu#ZMO(__5=RQCU7EkEgz;LFA2^13%9dBY>pYgNkIk($K3Hc$cI4tN9EBe9@vUkZ?a z6%8^k?FL$R=7E_KQOr?u9&dE_L-xDJ99+9l3!k>sz-Om=1I#p@vrEpMX0gKyUU0f( znMsBQ*OLlzO%i{WjTw6JtXI6}>DU;t$F0w?cG0g`Ff13$Q7 znKZ%^lM7^VC1tVL2UdMoANz4%E?Xf}!PAVhW86OR zfuXwyNT*u?spk+#mVah!U*|H>CCVW8Niyh2js+#-zqn~txj0@^g8cKIs{NF5c=5Am z!TV1Rz?O{;@YshZAnlAXR+_b(*jBiZ{#ZzQX*@nTHx$TTQG(sq)!0OybdvbHdlB(}%%4sW2qAwv}hO{>Nq24@lbhzPPz(406+CjxUOe5w`|qrJA* zNY$ZfbkkcInyTN&<{3*M&%fSC)tH6tzy1+#-B&=ieJ9ek7=`&(SD@r|5!88J3jA3+ zogR9&o+!mt@^Vs)vDc?r$WC_FR4>vqvevPhMElo6k-8x)0yQWj_8B^Iuc7$YT~HV< zAQt!taQ-Ytw_V#!@y7~iGB*>XH;#anA9=`m28PXgxrF0;j(fsJgL>UJCTlhRVFnasG2 zwEM~<*mTbb@gkh4=hXhVlMmw{*KiMcPwtQ~@&?!N)sV5w5PHM78`iuzLOpZeQ{%|# zDDX=P>egP0YS;~Q&ix9qY{ZO~-4(4EOf!LBMb4m$*FuoOlzy9-^qBBw`9rtb-5yvP!OhFPVcZ% zyfYShd=QN#j>8M@o)M>sOyoCU-nwC4GBq<8BPvRdOo#bwLMDX%Qra;=&=kIE3&oBzVL zKSnzYpTgL%Bgo9Sg_`x;q(R4j&>KR(=?@WMqH*d8yUQB-@y>j*w!Rb1 zR``W7?nRO}pS5YA&`TW2dqZ>z2Vmata)Z_1qhQJH#nj>fBclnXl6AEV{wMl>#D3zc4D zi+;rKp=$Ym=yRn1Y-90 zh5TE%hYn_qlf^X>Bs2IhsW-SmJT)3gr${T1$&*K!@8n>5cQuwZ5k+rY-@=im5U?g< zCK1_}hNp9$;*!7mWY_o^EE-V+jy_Za{`^Y_b|Iuu_6fQCD3z3LSxuC6f-?<-E#ofC0OJ7|6^$tJ1s-r@03ty+$E>JY$t~h&YMyTHds@w6%IP~6 zr-OZfdt;g`|Mpi}PUWz(tYi3SvF*F3g~QUD+bi}S1Mb~o-uc$>;SRyZ={pKF8+L5^ zueZZmck=eT9lYjEUwJJy3*=bzR?M)lQ(kVd<I1-cSq7Z&Fu|xZRQPAv@AYrS6V#a>bFq4*luy}tFMJJ z>v8iPVUM@f8A|RbudCXT8#H|f=$g%K~ zQ;<*akOxVi00QTzpc{sK9)}u=L6d+-bl@-y?FbNbGm($j&_pp)5aNk3l2v@&fP#a3OP}u8a0K0QNSZnP~#Rkb;_Z* zO27rDtDt!T-7MrZq>EzKB;00UPu&6DY-~DEBjlJ7hQZGAfzi$2gv!8R;R;a<)CW2z Zhheungbz9Og%xz}2?GZZg49FQ0suW!0^I-r literal 0 HcmV?d00001 diff --git a/scripts/testing/test_spatial.sh b/scripts/testing/test_spatial.sh index 676fec76f..e87114805 100755 --- a/scripts/testing/test_spatial.sh +++ b/scripts/testing/test_spatial.sh @@ -1,17 +1,28 @@ #!/usr/bin/env bash set -euo pipefail +set -x + +H_ARG=${1:-} +W_ARG=${2:-} + +H=${H_ARG:-${FME_DISTRIBUTED_H:-2}} +W=${W_ARG:-${FME_DISTRIBUTED_W:-2}} -H=${FME_DISTRIBUTED_H:-2} -W=${FME_DISTRIBUTED_W:-2} NP=$((H * W)) +dir=fme/core/distributed/parallel_tests +tests=test_backward_step.py::test_spatial_parallel_backward_step +pytest_cmd="pytest -s $dir/$tests" +file="testdata/backward_step_baseline.pt" + +if [ -f "$dir/$file" ]; then + rm "$dir/$file" +fi + +$pytest_cmd + export FME_DISTRIBUTED_BACKEND=model export FME_DISTRIBUTED_H=$H export FME_DISTRIBUTED_W=$W - -# torchrun --standalone --nnodes=1 --nproc_per_node=$NP \ -# -m pytest fme/core/distributed/parallel_tests/test_spatial.py "$@" - -torchrun --standalone --nnodes=1 --nproc_per_node=$NP \ - -m pytest fme/core/distributed/parallel_tests/test_backward_step.py::test_spatial_parallel_backward_step "$@" +torchrun --standalone --nnodes=1 --nproc-per-node=$NP -m $pytest_cmd From 85bf4ec3b014a20d27838648ba0c102088d1b494 Mon Sep 17 00:00:00 2001 From: peterdschwartz Date: Fri, 20 Mar 2026 10:42:34 -0400 Subject: [PATCH 4/4] add tests explicitly comparing gradients for backward step --- .../distributed/parallel_tests/test_step.py | 221 ++++++++++++------ scripts/testing/test_spatial.sh | 4 + 2 files changed, 159 insertions(+), 66 deletions(-) diff --git a/fme/core/distributed/parallel_tests/test_step.py b/fme/core/distributed/parallel_tests/test_step.py index b175c4499..a052e2a43 100644 --- a/fme/core/distributed/parallel_tests/test_step.py +++ b/fme/core/distributed/parallel_tests/test_step.py @@ -18,10 +18,18 @@ import numpy as np import pytest import torch +import xarray as xr from torch import nn import fme +from fme.ace.data_loading.batch_data import BatchData from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNOBuilder +from fme.ace.stepper.single_module import ( + StepperConfig, + TrainOutput, + TrainStepper, + TrainStepperConfig, +) from fme.ace.testing.fv3gfs_data import get_scalar_dataset from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig, EnergyBudgetConfig @@ -29,13 +37,20 @@ from fme.core.distributed.distributed import Distributed from fme.core.distributed.non_distributed import DummyWrapper from fme.core.labels import BatchLabels +from fme.core.loss import StepLossConfig from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig from fme.core.optimization import Optimization, OptimizationConfig, SchedulerConfig from fme.core.registry import ModuleSelector +from fme.core.step import SingleModuleStepConfig, StepSelector from fme.core.step.args import StepArgs from fme.core.step.multi_call import MultiCallConfig, MultiCallStepConfig from fme.core.step.secondary_decoder import SecondaryDecoderConfig -from fme.core.step.single_module import SingleModuleStepConfig + +# from fme.core.step.single_module import ( +# SingleModuleStepConfig, +# TrainOutput, +# TrainStepper, +# ) from fme.core.step.step import StepABC, StepSelector from fme.core.typing_ import TensorDict @@ -44,6 +59,32 @@ DATA_DIR = pathlib.Path(__file__).parent / "testdata" +def get_dataset_info( + img_shape=(5, 5), +) -> DatasetInfo: + horizontal_coordinate = LatLonCoordinates( + lat=torch.zeros(img_shape[-2]), + lon=torch.zeros(img_shape[-1]), + ) + vertical_coordinate = HybridSigmaPressureCoordinate( + ak=torch.arange(7), bk=torch.arange(7) + ) + return DatasetInfo( + horizontal_coordinates=horizontal_coordinate, + vertical_coordinate=vertical_coordinate, + timestep=TIMESTEP, + ) + + +def _get_train_stepper( + stepper_config: StepperConfig, + dataset_info: DatasetInfo, + **train_config_kwargs, +) -> TrainStepper: + train_config = TrainStepperConfig(**train_config_kwargs) + return train_config.get_train_stepper(stepper_config, dataset_info) + + def get_network_and_loss_normalization_config( names: list[str], dir: pathlib.Path | None = None, @@ -469,96 +510,151 @@ def test_step_regression( cache_step_output(output, DATA_DIR / f"{case_name}_output.pt") -def _run_step_optimization_backward( + +def _run_stepper_backward_with_optimization( img_shape: tuple[int, int], n_samples: int, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """ - Run a single forward + backward through a Step using Optimization, - and return: - - scalar loss on this rank - - gradients for all parameters (CPU tensors) + Single forward + backward through TrainStepper using Optimization. + Returns scalar loss and per-parameter gradients (CPU tensors). """ + torch.manual_seed(0) device = fme.get_device() dist = Distributed.get_instance() - selector = get_single_module_noise_conditioned_selector(None) - step = get_step(selector, img_shape) + # Reuse the same config pattern as get_regression_stepper_and_data + in_names = ["a", "b"] + out_names = ["b", "c"] + n_forward_steps = 2 + all_names = list(set(in_names + out_names)) + + loss_cfg = StepLossConfig(type="AreaWeightedMSE") + + stepper_config = StepperConfig( + step=StepSelector( + type="single_module", + config=dataclasses.asdict( + SingleModuleStepConfig( + builder=ModuleSelector( + type="NoiseConditionedSFNO", + config=dataclasses.asdict( + NoiseConditionedSFNOBuilder( + embed_dim=16, + num_layers=2, + noise_embed_dim=16, + noise_type="isotropic", + ) + ), + ), + in_names=in_names, + out_names=out_names, + normalization=NetworkAndLossNormalizationConfig( + network=NormalizationConfig( + means={n: 0.1 for n in all_names}, + stds={n: 1.1 for n in all_names}, + ), + ), + ocean=None, + ) + ), + ), + ) + + dataset_info = get_dataset_info(img_shape=img_shape) - modules = nn.ModuleList(step.modules) + train_stepper: TrainStepper = _get_train_stepper( + stepper_config, + dataset_info, + loss=loss_cfg, + ) + + # Random data, same as regression helper + data = BatchData.new_on_device( + data={ + "a": torch.randn(n_samples, n_forward_steps + 1, *img_shape, device=device), + "b": torch.randn(n_samples, n_forward_steps + 1, *img_shape, device=device), + "c": torch.randn(n_samples, n_forward_steps + 1, *img_shape, device=device), + }, + time=xr.DataArray( + np.zeros((n_samples, n_forward_steps + 1)), + dims=["sample", "time"], + ), + labels=None, + epoch=0, + horizontal_dims=["lat", "lon"], + ) + data = data.scatter_spatial(img_shape) + + # Build Optimization on train_stepper.modules opt_config = OptimizationConfig( optimizer_type="Adam", lr=1e-3, - scheduler=SchedulerConfig(), + use_gradient_accumulation=True, enable_automatic_mixed_precision=False, ) - optimization = opt_config.build(modules, max_epochs=1) - optimization.set_mode(modules) - # Global inputs; Distributed backend decides what scatter_spatial does - input_data: TensorDict = get_tensor_dict(step.input_names, img_shape, n_samples) - next_input_data: TensorDict = get_tensor_dict( - step.next_step_input_names, img_shape, n_samples + optimization = opt_config.build(train_stepper.modules, max_epochs=1) + optimization.set_mode(train_stepper.modules) + + # Forward + backward through the *training* API + # train_output: TrainOutput = train_stepper.train_on_batch(data, optimization) + + # --- Manual version of train_on_batch, but WITHOUT step_weights() --- + train_stepper._init_for_epoch(data.epoch) + metrics: dict[str, torch.Tensor] = {} + + input_data = data.get_start(train_stepper._prognostic_names, train_stepper.n_ic_timesteps) + target_data = train_stepper._stepper.get_forward_data( + data, compute_derived_variables=False ) + data = train_stepper._stepper.forcing_deriver(data) - # Use the same scatter pattern as the step tests; in serial this is a no-op - input_data = dist.scatter_spatial(input_data, img_shape) - next_input_data = dist.scatter_spatial(next_input_data, img_shape) + optimization.set_mode(train_stepper._stepper.modules) - # Forward - out = step.step( - args=StepArgs( - input=input_data, - next_step_input_data=next_input_data, - labels=None, - ), - wrapper=lambda x: x, + output_list = train_stepper._accumulate_loss( + input_data=input_data, + data=data, + target_data=target_data, + optimization=optimization, + metrics=metrics, ) - # Use the real training loss from the step output - if isinstance(out, dict) and "loss" in out: - loss = out["loss"] - else: - raise RuntimeError( - "Step output does not contain 'loss'; " - "wire this test to the real training loss output." - ) + regularizer_loss = train_stepper._stepper.get_regularizer_loss() + if torch.any(regularizer_loss > 0): + optimization.accumulate_loss(regularizer_loss) - # Route loss through Optimization, but only run backward (no step/zero yet) - optimization.accumulate_loss(loss) - total_loss = optimization.get_accumulated_loss() - optimization._backward(total_loss) + loss = optimization.get_accumulated_loss() - # Collect parameter grads grads: dict[str, torch.Tensor] = {} - for name, p in step.named_parameters(): - if p.grad is not None: - grads[name] = p.grad.detach().cpu().clone() + for i, wrapped in enumerate(train_stepper.modules): + module = getattr(wrapped, "module", wrapped) + for name, p in module.named_parameters(): + if p.grad is not None: + grads[f"module_{i}.{name}"] = p.grad.detach().cpu().clone() + - return total_loss.detach().cpu(), grads + return loss.detach().cpu(), grads @pytest.mark.parallel -def test_step_optimization_backward_matches_baseline(): +def test_stepper_backward_with_optimization(): """ - Regression test for Step+Optimization backward under spatial parallelism. + Test compares gradients after backward step with and without spatial parallelism. - Uses the same forward/Optimization path in both serial and spatial - backends; serial run is used to create the baseline. Spatial backends - must reproduce the baseline loss and parameter gradients element-wise. + Since each rank holds the entire global model's parameters, there is no need to gather. + This test will need to be modified once spatial sharding is implemented for parameters. """ DATA_DIR = pathlib.Path(__file__).parent / "testdata" - BASELINE_FILE = DATA_DIR / "backward_with_opt_baseline.pt" - + BASELINE_FILE = DATA_DIR / "csfno_stepper_backward_with_opt_baseline.pt" dist = Distributed.get_instance() torch.manual_seed(0) img_shape = (20, 40) n_samples = 2 - loss, grads = _run_step_optimization_backward(img_shape, n_samples) + loss, grads = _run_stepper_backward_with_optimization(img_shape, n_samples) - # Only root rank writes/compares baseline if not dist.is_root(): return @@ -573,10 +669,8 @@ def test_step_optimization_backward_matches_baseline(): }, BASELINE_FILE, ) - raise AssertionError( - f"Baseline created at {BASELINE_FILE}. " - "Re-run the test to perform regression check." - ) + print("Created Baseline file") + return baseline = torch.load(BASELINE_FILE, map_location="cpu") assert tuple(baseline["img_shape"]) == tuple(img_shape) @@ -585,9 +679,8 @@ def test_step_optimization_backward_matches_baseline(): baseline_loss = baseline["loss"] baseline_grads: dict[str, torch.Tensor] = baseline["grads"] - # 1) Loss finite and close to baseline. + # Loss check assert torch.isfinite(loss), "Loss is not finite on this rank" - actual_loss = loss.item() expected_loss = baseline_loss.item() rel_loss = abs(actual_loss - expected_loss) / max(abs(expected_loss), 1e-12) @@ -597,12 +690,8 @@ def test_step_optimization_backward_matches_baseline(): f"rel_diff={rel_loss:.3e}" ) - # 2) Parameter gradients match baseline element-wise. - # Require same parameter set as baseline - assert set(grads.keys()) == set( - baseline_grads.keys() - ), "Parameter set changed since baseline generation" - + # Grad check + assert set(grads.keys()) == set(baseline_grads.keys()) for name in sorted(grads.keys()): g = grads[name] g_ref = baseline_grads[name] @@ -610,7 +699,7 @@ def test_step_optimization_backward_matches_baseline(): diff = (g - g_ref).abs() max_abs = diff.max().item() max_rel = (diff / g_ref.abs().clamp_min(1e-12)).max().item() - assert torch.allclose(g, g_ref, rtol=1e-6, atol=1e-7), ( + assert torch.allclose(g, g_ref, rtol=1e-6, atol=1e-8), ( f"Gradient for '{name}' deviates from baseline: " f"max_abs={max_abs:.3e}, max_rel={max_rel:.3e}" ) diff --git a/scripts/testing/test_spatial.sh b/scripts/testing/test_spatial.sh index e87114805..3ada4026c 100755 --- a/scripts/testing/test_spatial.sh +++ b/scripts/testing/test_spatial.sh @@ -11,9 +11,13 @@ W=${W_ARG:-${FME_DISTRIBUTED_W:-2}} NP=$((H * W)) dir=fme/core/distributed/parallel_tests +# dir=fme/ace/stepper/ tests=test_backward_step.py::test_spatial_parallel_backward_step +# tests=test_step.py::test_stepper_backward_with_optimization +# tests=test_single_module_csfno.py::test_stepper_train_on_batch_with_optimization_regression pytest_cmd="pytest -s $dir/$tests" file="testdata/backward_step_baseline.pt" +# file="testdata/csfno_stepper_backward_with_opt_baseline.pt" if [ -f "$dir/$file" ]; then rm "$dir/$file"