From 1daa10b33238d83b1208d1cf6ec077fe13c4a236 Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Wed, 3 Dec 2025 17:57:18 +0100 Subject: [PATCH 1/2] implement autoregressive condition, time_weighting, solver --- pina/condition/__init__.py | 3 + pina/condition/autoregressive_condition.py | 91 ++++++++++++++++++ pina/loss/__init__.py | 10 ++ pina/loss/time_weighting.py | 57 ++++++++++++ pina/loss/time_weighting_interface.py | 24 +++++ pina/solver/__init__.py | 5 + pina/solver/autoregressive_solver/__init__.py | 4 + .../autoregressive_solver.py | 88 ++++++++++++++++++ .../autoregressive_solver_interface.py | 93 +++++++++++++++++++ 9 files changed, 375 insertions(+) create mode 100644 pina/condition/autoregressive_condition.py create mode 100644 pina/loss/time_weighting.py create mode 100644 pina/loss/time_weighting_interface.py create mode 100644 pina/solver/autoregressive_solver/__init__.py create mode 100644 pina/solver/autoregressive_solver/autoregressive_solver.py create mode 100644 pina/solver/autoregressive_solver/autoregressive_solver_interface.py diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 4e57811fb..502c34ae9 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -15,6 +15,7 @@ "DataCondition", "GraphDataCondition", "TensorDataCondition", + "AutoregressiveCondition", ] from .condition_interface import ConditionInterface @@ -37,3 +38,5 @@ GraphDataCondition, TensorDataCondition, ) + +from .autoregressive_condition import AutoregressiveCondition diff --git a/pina/condition/autoregressive_condition.py b/pina/condition/autoregressive_condition.py new file mode 100644 index 000000000..1d48b638d --- /dev/null +++ b/pina/condition/autoregressive_condition.py @@ -0,0 +1,91 @@ +import torch +from .condition_interface import ConditionInterface +from ..loss import TimeWeightingInterface, ConstantTimeWeighting +from ..utils import check_consistency + + +class AutoregressiveCondition(ConditionInterface): + """ + A specialized condition for autoregressive tasks. + It generates input/unroll pairs from a single time-series tensor. + """ + + __slots__ = ["input", "unroll"] + + def __init__( + self, + data, + unroll_length, + num_unrolls=None, + randomize=True, + time_weighting=None, + ): + """ + Create an AutoregressiveCondition. 
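+
+        :param torch.Tensor data: Time-series tensor of shape
+            ``[n_timesteps, n_features]`` from which the windows are drawn.
+        :param int unroll_length: The length of each target window.
+        :param int num_unrolls: The number of windows to generate. If None,
+            all admissible starting indices are used. Default is None.
+        :param bool randomize: If True, starting indices are drawn at random;
+            otherwise they are evenly spaced. Default is True.
+        :param TimeWeightingInterface time_weighting: The weighting applied
+            to the per-step losses along the unroll. If None, defaults to
+            :class:`ConstantTimeWeighting`.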
+        """
+        super().__init__()
+
+        self._n_timesteps, n_features = data.shape
+        self._unroll_length = unroll_length
+        self._requested_num_unrolls = num_unrolls
+        self._randomize = randomize
+
+        # time weighting: weight the loss differently along the unroll
+        if time_weighting is None:
+            self._time_weighting = ConstantTimeWeighting()
+        else:
+            check_consistency(time_weighting, TimeWeightingInterface)
+            self._time_weighting = time_weighting
+
+        # windows creation
+        initial_data = []
+        unroll_data = []
+
+        for starting_index in self.starting_indices:
+            initial_data.append(data[starting_index])
+            target_start = starting_index + 1
+            unroll_data.append(
+                data[target_start : target_start + self._unroll_length, :]
+            )
+
+        self.input = torch.stack(initial_data)  # [num_unrolls, features]
+        self.unroll = torch.stack(
+            unroll_data
+        )  # [num_unrolls, unroll_length, features]
+
+    @property
+    def unroll_length(self):
+        return self._unroll_length
+
+    @property
+    def time_weighting(self):
+        return self._time_weighting
+
+    @property
+    def max_start_idx(self):
+        max_start_idx = self._n_timesteps - self._unroll_length
+        assert max_start_idx > 0, "Provided data sequence too short"
+        return max_start_idx
+
+    @property
+    def num_unrolls(self):
+        if self._requested_num_unrolls is None:
+            return self.max_start_idx
+        else:
+            assert (
+                self._requested_num_unrolls < self.max_start_idx
+            ), "too many samples requested"
+            return self._requested_num_unrolls
+
+    @property
+    def starting_indices(self):
+        all_starting_indices = torch.arange(self.max_start_idx)
+
+        if self._randomize:
+            perm = torch.randperm(len(all_starting_indices))
+            return all_starting_indices[perm[: self.num_unrolls]]
+        else:
+            selected_indices = torch.linspace(
+                0, len(all_starting_indices) - 1, self.num_unrolls
+            ).long()
+            return all_starting_indices[selected_indices]
diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py
index d91cf7ab0..2d8ab288e 100644
--- a/pina/loss/__init__.py
+++ b/pina/loss/__init__.py
@@ -9,6 +9,10 @@
     "NeuralTangentKernelWeighting",
     "SelfAdaptiveWeighting",
     "LinearWeighting",
+    "TimeWeightingInterface",
+    "ConstantTimeWeighting",
+    "ExponentialTimeWeighting",
+    "LinearTimeWeighting",
 ]
 
 from .loss_interface import LossInterface
@@ -19,3 +23,9 @@
 from .ntk_weighting import NeuralTangentKernelWeighting
 from .self_adaptive_weighting import SelfAdaptiveWeighting
 from .linear_weighting import LinearWeighting
+from .time_weighting_interface import TimeWeightingInterface
+from .time_weighting import (
+    ConstantTimeWeighting,
+    ExponentialTimeWeighting,
+    LinearTimeWeighting,
+)
diff --git a/pina/loss/time_weighting.py b/pina/loss/time_weighting.py
new file mode 100644
index 000000000..0b1d1ed65
--- /dev/null
+++ b/pina/loss/time_weighting.py
@@ -0,0 +1,57 @@
+"""Module for the Time Weighting."""
+
+import torch
+from .time_weighting_interface import TimeWeightingInterface
+
+
+class ConstantTimeWeighting(TimeWeightingInterface):
+    """
+    Weighting scheme that assigns equal weight to all time steps.
+    """
+
+    def __call__(self, num_steps, device):
+        return torch.ones(num_steps, device=device) / num_steps
+
+
+class ExponentialTimeWeighting(TimeWeightingInterface):
+    """
+    Weighting scheme whose weights change exponentially with time: the
+    weight at time ``t`` is proportional to ``gamma**t``, so ``gamma > 1.0``
+    gives increasing weights and ``0 < gamma < 1.0`` gives decreasing
+    weights.
+    """
+
+    def __init__(self, gamma=0.9):
+        """
+        Initialization of the :class:`ExponentialTimeWeighting` class.
+        :param float gamma: The decay factor. Default is 0.9.
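+
+        Example: with three steps and ``gamma=0.5`` the raw weights are
+        ``[1.0, 0.5, 0.25]``, which normalize to ``[4/7, 2/7, 1/7]``.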
+        """
+        self.gamma = gamma
+
+    def __call__(self, num_steps, device):
+        steps = torch.arange(num_steps, device=device, dtype=torch.float32)
+        weights = self.gamma**steps
+        return weights / weights.sum()
+
+
+class LinearTimeWeighting(TimeWeightingInterface):
+    """
+    Weighting scheme that changes linearly from a start weight to an end weight.
+    """
+
+    def __init__(self, start=0.1, end=1.0):
+        """
+        Initialization of the :class:`LinearTimeWeighting` class.
+
+        :param float start: The starting weight. Default is 0.1.
+        :param float end: The ending weight. Default is 1.0.
+        """
+        self.start = start
+        self.end = end
+
+    def __call__(self, num_steps, device):
+        if num_steps == 1:
+            return torch.ones(1, device=device)
+
+        weights = torch.linspace(self.start, self.end, num_steps, device=device)
+        return weights / weights.sum()
diff --git a/pina/loss/time_weighting_interface.py b/pina/loss/time_weighting_interface.py
new file mode 100644
index 000000000..9d9781351
--- /dev/null
+++ b/pina/loss/time_weighting_interface.py
@@ -0,0 +1,24 @@
+"""Module for the Time Weighting Interface."""
+
+from abc import ABCMeta, abstractmethod
+import torch
+
+
+class TimeWeightingInterface(metaclass=ABCMeta):
+    """
+    Abstract base class for all time weighting schemas. All time weighting
+    schemas should inherit from this class.
+    """
+
+    @abstractmethod
+    def __call__(self, num_steps, device):
+        """
+        Compute the weights for the time steps.
+
+        :param int num_steps: The number of time steps.
+        :param torch.device device: The device on which the weights should be
+            created.
+        :return: The weights for the time steps.
+        :rtype: torch.Tensor
+        """
+        pass
diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py
index 43f18078f..e7d48e2b3 100644
--- a/pina/solver/__init__.py
+++ b/pina/solver/__init__.py
@@ -18,6 +18,7 @@
     "DeepEnsembleSupervisedSolver",
     "DeepEnsemblePINN",
     "GAROM",
+    "AutoregressiveSolver",
 ]
 
 from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface
@@ -41,3 +42,7 @@
     DeepEnsemblePINN,
 )
 from .garom import GAROM
+from .autoregressive_solver import (
+    AutoregressiveSolver,
+    AutoregressiveSolverInterface,
+)
diff --git a/pina/solver/autoregressive_solver/__init__.py b/pina/solver/autoregressive_solver/__init__.py
new file mode 100644
index 000000000..9ef7c43e1
--- /dev/null
+++ b/pina/solver/autoregressive_solver/__init__.py
@@ -0,0 +1,4 @@
+__all__ = ["AutoregressiveSolver", "AutoregressiveSolverInterface"]
+
+from .autoregressive_solver import AutoregressiveSolver
+from .autoregressive_solver_interface import AutoregressiveSolverInterface
diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py
new file mode 100644
index 000000000..d0a46c310
--- /dev/null
+++ b/pina/solver/autoregressive_solver/autoregressive_solver.py
@@ -0,0 +1,88 @@
+import torch
+from torch.nn.modules.loss import _Loss
+
+from pina.utils import check_consistency
+from pina.solver.solver import SingleSolverInterface
+from pina.condition import AutoregressiveCondition
+from pina.loss import (
+    LossInterface,
+    TimeWeightingInterface,
+    ConstantTimeWeighting,
+)
+from .autoregressive_solver_interface import AutoregressiveSolverInterface
+
+
+class AutoregressiveSolver(
+    AutoregressiveSolverInterface, SingleSolverInterface
+):
+    """
+    Autoregressive Solver class.
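+
+    The model is trained to approximate the one-step map
+    ``x_{t+1} = f(x_t)`` and is applied to its own predictions recursively
+    over unroll windows during training.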
+    """
+
+    accepted_conditions_types = AutoregressiveCondition
+
+    def __init__(
+        self,
+        problem,
+        model,
+        loss=None,
+        optimizer=None,
+        scheduler=None,
+        weighting=None,
+        use_lt=False,
+    ):
+        """
+        Initialization of the :class:`AutoregressiveSolver` class.
+        """
+        super().__init__(
+            problem=problem,
+            model=model,
+            loss=loss,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            weighting=weighting,
+            use_lt=use_lt,
+        )
+
+    def loss_data(self, input, target, unroll_length, time_weighting):
+        """
+        Compute the data loss for the recursive autoregressive solver.
+        This will be applied to each condition individually.
+        """
+        steps_to_predict = unroll_length - 1
+        # weights are passed from the condition
+        weights = time_weighting(steps_to_predict, device=input.device)
+
+        total_loss = 0.0
+        current_state = input
+
+        for step in range(steps_to_predict):
+
+            predicted_next_state = self.forward(
+                current_state
+            )  # [batch_size, features]
+            actual_next_state = target[:, step, :]  # [batch_size, features]
+
+            step_loss = self.loss(predicted_next_state, actual_next_state)
+
+            total_loss += step_loss * weights[step]
+
+            current_state = predicted_next_state.detach()
+
+        return total_loss
+
+    def predict(self, initial_state, num_steps):
+        """
+        Make recursive predictions starting from an initial state.
+        """
+        self.eval()  # Set model to evaluation mode
+
+        current_state = initial_state
+        predictions = [current_state]  # Store initial state without batch dim
+        with torch.no_grad():
+            for step in range(num_steps):
+                next_state = self.forward(current_state)
+                predictions.append(next_state)  # Keep batch dim for storage
+                current_state = next_state
+
+        return torch.stack(predictions)
diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py
new file mode 100644
index 000000000..e895705fe
--- /dev/null
+++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py
@@ -0,0 +1,93 @@
+"""Module for the Autoregressive solver interface."""
+
+from abc import abstractmethod
+import torch
+from torch.nn.modules.loss import _Loss
+
+from ..solver import SolverInterface
+from ...utils import check_consistency
+from ...loss.loss_interface import LossInterface
+from ...loss import TimeWeightingInterface, ConstantTimeWeighting
+from ...condition import AutoregressiveCondition
+
+
+class AutoregressiveSolverInterface(SolverInterface):
+
+    accepted_conditions_types = AutoregressiveCondition
+
+    def __init__(self, loss=None, **kwargs):
+
+        if loss is None:
+            loss = torch.nn.MSELoss()
+
+        super().__init__(**kwargs)
+
+        check_consistency(loss, (LossInterface, _Loss), subclass=False)
+        self._loss_fn = loss
+
+    def optimization_cycle(self, batch):
+        """
+        Optimization cycle for this family of solvers.
+        Iterates over the conditions and applies the specialized loss_data function to each one.
+        """
+
+        condition_loss = {}
+        for condition_name, points in batch:
+            condition = self.problem.conditions[condition_name]
+
+            unroll_length = getattr(condition, "unroll_length", None)
+            time_weighting = getattr(condition, "time_weighting", None)
+
+            if "unroll" in points:
+                loss = self.loss_data(
+                    points["input"],
+                    points["unroll"],
+                    unroll_length,
+                    time_weighting,
+                )
+                condition_loss[condition_name] = loss
+        return condition_loss
+
+    @abstractmethod
+    def loss_data(self, input, target, unroll_length, time_weighting):
+        """
+        Computes the data loss for each condition.
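+        The returned value is expected to be a weighted sum of the
+        per-step losses along the unroll, i.e.
+        ``sum_t weights[t] * loss(prediction_t, target_t)``.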
+ N.B.: unroll_length and time_weighting are attributes of the condition. + + :param torch.Tensor input: Initial states. + :param torch.Tensor target: Target sequences. + :param int unroll_length: The number of steps to unroll (attribute of the condition). + :param TimeWeightingInterface time_weighting: The time weighting strategy (attribute of the condition). + :return: The average loss over all unroll steps. + :rtype: torch.Tensor + """ + pass + + @abstractmethod + def predict(self, initial_state, num_steps): + """ + Make recursive predictions starting from an initial state. + + :param torch.Tensor initial_state: Initial state tensor. + :param int num_steps: Number of steps to predict ahead. + :return: Tensor of predictions. + :rtype: torch.Tensor + """ + pass + + @property + def loss(self): + """ + The loss function to be minimized. + + :return: The loss function to be minimized. + :rtype: torch.nn.Module + """ + return self._loss_fn + + @property + def time_weighting(self): + """ + The time weighting strategy. + """ + return self._time_weighting From 1bb9d468645f1b040ab136d6eb4e843a26c55f88 Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Tue, 16 Dec 2025 18:12:41 +0100 Subject: [PATCH 2/2] implement everything into solver --- autoregressive_prova_generic_condition.py | 149 ++++++++++++++++++ pina/condition/__init__.py | 3 - pina/condition/autoregressive_condition.py | 91 ----------- pina/loss/__init__.py | 10 -- pina/loss/time_weighting.py | 57 ------- pina/loss/time_weighting_interface.py | 24 --- .../autoregressive_solver.py | 142 +++++++++++++---- .../autoregressive_solver_interface.py | 54 +++---- 8 files changed, 285 insertions(+), 245 deletions(-) create mode 100644 autoregressive_prova_generic_condition.py delete mode 100644 pina/condition/autoregressive_condition.py delete mode 100644 pina/loss/time_weighting.py delete mode 100644 pina/loss/time_weighting_interface.py diff --git a/autoregressive_prova_generic_condition.py b/autoregressive_prova_generic_condition.py new file mode 100644 index 000000000..3c0796bbc --- /dev/null +++ b/autoregressive_prova_generic_condition.py @@ -0,0 +1,149 @@ +import torch +import matplotlib.pyplot as plt + +from pina import Trainer +from pina.optim import TorchOptimizer +from pina.problem import AbstractProblem +from pina.condition.data_condition import DataCondition +from pina.solver import AutoregressiveSolver + +NUM_TIMESTEPS = 100 +NUM_FEATURES = 15 +USE_TEST_MODEL = False + +# ============================================================================ +# DATA +# ============================================================================ + +torch.manual_seed(42) + +y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES) +y[0] = torch.rand(NUM_FEATURES) # Random initial state + +for t in range(NUM_TIMESTEPS - 1): + y[t + 1] = 0.95 * y[t] # + 0.05 * torch.sin(y[t].sum()) + +# ============================================================================ +# TRAINING +# ============================================================================ + +class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.Sequential( + torch.nn.Linear(y.shape[1], 20), + torch.nn.ReLU(), + torch.nn.Dropout(0.2), + torch.nn.Linear(20, y.shape[1]), + ) + + def forward(self, x): + return x + self.layers(x) + + +class TestModel(torch.nn.Module): + """ + Debug model that implements the EXACT transformation rule. 
+    y[t+1] = 0.95 * y[t]
+    Expected loss is zero
+    """
+
+    def __init__(self, data_series=None):
+        super().__init__()
+        self.dummy_param = torch.nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        next_state = 0.95 * x  # + 0.05 * torch.sin(x.sum(dim=1, keepdim=True))
+        return next_state + 0.0 * self.dummy_param
+
+
+class Problem(AbstractProblem):
+    output_variables = None
+    input_variables = None
+    conditions = {
+        "data_condition_0": DataCondition(input=y),
+        "data_condition_1": DataCondition(input=y),
+    }
+
+problem = Problem()
+
+# for each condition, define unroll instructions with these keys:
+# - unroll_length: length of each unroll window
+# - num_unrolls: number of unroll windows to create (if None, use all possible)
+# - randomize: whether to randomize the starting indices of the unroll windows
+# - eps: strength of the adaptive time weighting (optional)
+unroll_instructions = {
+    "data_condition_0": {
+        "unroll_length": 10,
+        "num_unrolls": 89,
+        "randomize": True,
+        "eps": 5.0
+    },
+    "data_condition_1": {
+        "unroll_length": 20,
+        "num_unrolls": 79,
+        "randomize": True,
+        "eps": 10.0
+    },
+}
+
+solver = AutoregressiveSolver(
+    unroll_instructions=unroll_instructions,
+    problem=problem,
+    model=TestModel() if USE_TEST_MODEL else SimpleModel(),
+    optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.01),
+)
+
+trainer = Trainer(
+    solver, max_epochs=2000, accelerator="cpu", enable_model_summary=False, shuffle=False
+)
+trainer.train()
+
+# ============================================================================
+# VISUALIZATION
+# ============================================================================
+
+test_start_idx = 50
+num_prediction_steps = 30
+
+initial_state = y[test_start_idx]  # Shape: [features]
+predictions = solver.predict(initial_state, num_prediction_steps)
+actual = y[test_start_idx : test_start_idx + num_prediction_steps + 1]
+
+total_mse = torch.nn.functional.mse_loss(predictions[1:], actual[1:])
+print(f"\nOverall MSE (all {num_prediction_steps} steps): {total_mse:.6f}")
+
+# visualize individual DOFs (label only the first one to avoid duplicate
+# legend entries)
+dof_to_plot = [0, 3, 6, 9, 12]
+colors = [
+    "r",
+    "g",
+    "b",
+    "c",
+    "m",
+    "y",
+    "k",
+]
+plt.figure(figsize=(10, 6))
+for dof, color in zip(dof_to_plot, colors):
+    plt.plot(
+        range(test_start_idx, test_start_idx + num_prediction_steps + 1),
+        actual[:, dof].numpy(),
+        label="Actual" if dof == dof_to_plot[0] else "_nolegend_",
+        marker="o",
+        color=color,
+        markerfacecolor="none",
+    )
+    plt.plot(
+        range(test_start_idx, test_start_idx + num_prediction_steps + 1),
+        predictions[:, dof].numpy(),
+        label="Predicted" if dof == dof_to_plot[0] else "_nolegend_",
+        marker="x",
+        color=color,
+    )
+
+plt.title(f"Autoregressive Predictions vs Actual, MSE: {total_mse:.6f}")
+plt.legend()
+plt.xlabel("Timestep")
+plt.savefig("autoregressive_predictions.png")
+plt.close()
diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py
index 502c34ae9..4e57811fb 100644
--- a/pina/condition/__init__.py
+++ b/pina/condition/__init__.py
@@ -15,7 +15,6 @@
     "DataCondition",
     "GraphDataCondition",
     "TensorDataCondition",
-    "AutoregressiveCondition",
 ]
 
 from .condition_interface import ConditionInterface
@@ -38,5 +37,3 @@
     GraphDataCondition,
     TensorDataCondition,
 )
-
-from .autoregressive_condition import AutoregressiveCondition
diff --git a/pina/condition/autoregressive_condition.py b/pina/condition/autoregressive_condition.py
deleted file mode 100644
index 1d48b638d..000000000
--- a/pina/condition/autoregressive_condition.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import torch
-from .condition_interface import ConditionInterface
-from ..loss import TimeWeightingInterface, ConstantTimeWeighting
-from ..utils import 
check_consistency - - -class AutoregressiveCondition(ConditionInterface): - """ - A specialized condition for autoregressive tasks. - It generates input/unroll pairs from a single time-series tensor. - """ - - __slots__ = ["input", "unroll"] - - def __init__( - self, - data, - unroll_length, - num_unrolls=None, - randomize=True, - time_weighting=None, - ): - """ - Create an AutoregressiveCondition. - """ - super().__init__() - - self._n_timesteps, n_features = data.shape - self._unroll_length = unroll_length - self._requested_num_unrolls = num_unrolls - self._randomize = randomize - - # time weighting: weight the loss differently along the unroll - if time_weighting is None: - self._time_weighting = ConstantTimeWeighting() - else: - check_consistency(time_weighting, TimeWeightingInterface) - self._time_weighting = time_weighting - - # windows creation - initial_data = [] - unroll_data = [] - - for starting_index in self.starting_indices: - initial_data.append(data[starting_index]) - target_start = starting_index + 1 - unroll_data.append( - data[target_start : target_start + self._unroll_length, :] - ) - - self.input = torch.stack(initial_data) # [num_unrolls, features] - self.unroll = torch.stack( - unroll_data - ) # [num_unrolls, unroll_length, features] - - @property - def unroll_length(self): - return self._unroll_length - - @property - def time_weighting(self): - return self._time_weighting - - @property - def max_start_idx(self): - max_start_idx = self._n_timesteps - self._unroll_length - assert max_start_idx > 0, "Provided data sequence too short" - return max_start_idx - - @property - def num_unrolls(self): - if self._requested_num_unrolls is None: - return self.max_start_idx - else: - assert ( - self._requested_num_unrolls < self.max_start_idx - ), "too many samples requested" - return self._requested_num_unrolls - - @property - def starting_indices(self): - all_starting_indices = torch.arange(self.max_start_idx) - - if self._randomize: - perm = torch.randperm(len(all_starting_indices)) - return all_starting_indices[perm[: self.num_unrolls]] - else: - selected_indices = torch.linspace( - 0, len(all_starting_indices) - 1, self.num_unrolls - ).long() - return all_starting_indices[selected_indices] diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index 2d8ab288e..d91cf7ab0 100644 --- a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -9,10 +9,6 @@ "NeuralTangentKernelWeighting", "SelfAdaptiveWeighting", "LinearWeighting", - "TimeWeightingInterface", - "ConstantTimeWeighting", - "ExponentialTimeWeighting", - "LinearTimeWeighting", ] from .loss_interface import LossInterface @@ -23,9 +19,3 @@ from .ntk_weighting import NeuralTangentKernelWeighting from .self_adaptive_weighting import SelfAdaptiveWeighting from .linear_weighting import LinearWeighting -from .time_weighting_interface import TimeWeightingInterface -from .time_weighting import ( - ConstantTimeWeighting, - ExponentialTimeWeighting, - LinearTimeWeighting, -) diff --git a/pina/loss/time_weighting.py b/pina/loss/time_weighting.py deleted file mode 100644 index 0b1d1ed65..000000000 --- a/pina/loss/time_weighting.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Module for the Time Weighting.""" - -import torch -from .time_weighting_interface import TimeWeightingInterface - - -class ConstantTimeWeighting(TimeWeightingInterface): - """ - Weighting scheme that assigns equal weight to all time steps. 
-    """
-
-    def __call__(self, num_steps, device):
-        return torch.ones(num_steps, device=device) / num_steps
-
-
-class ExponentialTimeWeighting(TimeWeightingInterface):
-    """
-    Weighting scheme whose weights change exponentially with time: the
-    weight at time ``t`` is proportional to ``gamma**t``, so ``gamma > 1.0``
-    gives increasing weights and ``0 < gamma < 1.0`` gives decreasing
-    weights.
-    """
-
-    def __init__(self, gamma=0.9):
-        """
-        Initialization of the :class:`ExponentialTimeWeighting` class.
-        :param float gamma: The decay factor. Default is 0.9.
-        """
-        self.gamma = gamma
-
-    def __call__(self, num_steps, device):
-        steps = torch.arange(num_steps, device=device, dtype=torch.float32)
-        weights = self.gamma**steps
-        return weights / weights.sum()
-
-
-class LinearTimeWeighting(TimeWeightingInterface):
-    """
-    Weighting scheme that changes linearly from a start weight to an end weight.
-    """
-
-    def __init__(self, start=0.1, end=1.0):
-        """
-        Initialization of the :class:`LinearTimeWeighting` class.
-
-        :param float start: The starting weight. Default is 0.1.
-        :param float end: The ending weight. Default is 1.0.
-        """
-        self.start = start
-        self.end = end
-
-    def __call__(self, num_steps, device):
-        if num_steps == 1:
-            return torch.ones(1, device=device)
-
-        weights = torch.linspace(self.start, self.end, num_steps, device=device)
-        return weights / weights.sum()
diff --git a/pina/loss/time_weighting_interface.py b/pina/loss/time_weighting_interface.py
deleted file mode 100644
index 9d9781351..000000000
--- a/pina/loss/time_weighting_interface.py
+++ /dev/null
@@ -1,24 +0,0 @@
-"""Module for the Time Weighting Interface."""
-
-from abc import ABCMeta, abstractmethod
-import torch
-
-
-class TimeWeightingInterface(metaclass=ABCMeta):
-    """
-    Abstract base class for all time weighting schemas. All time weighting
-    schemas should inherit from this class.
-    """
-
-    @abstractmethod
-    def __call__(self, num_steps, device):
-        """
-        Compute the weights for the time steps.
-
-        :param int num_steps: The number of time steps.
-        :param torch.device device: The device on which the weights should be
-            created.
-        :return: The weights for the time steps.
-        :rtype: torch.Tensor
-        """
-        pass
diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py
index d0a46c310..0606a3fd6 100644
--- a/pina/solver/autoregressive_solver/autoregressive_solver.py
+++ b/pina/solver/autoregressive_solver/autoregressive_solver.py
@@ -1,14 +1,7 @@
 import torch
-from torch.nn.modules.loss import _Loss
-
 from pina.utils import check_consistency
 from pina.solver.solver import SingleSolverInterface
-from pina.condition import AutoregressiveCondition
-from pina.loss import (
-    LossInterface,
-    TimeWeightingInterface,
-    ConstantTimeWeighting,
-)
+from pina.condition import DataCondition
 from .autoregressive_solver_interface import AutoregressiveSolverInterface
@@ -19,12 +12,14 @@ class AutoregressiveSolver(
     Autoregressive Solver class.
     """
 
-    accepted_conditions_types = AutoregressiveCondition
+    accepted_conditions_types = DataCondition
 
     def __init__(
         self,
+        unroll_instructions,
         problem,
         model,
         loss=None,
         optimizer=None,
         scheduler=None,
@@ -33,8 +28,19 @@ def __init__(
         """
         Initialization of the :class:`AutoregressiveSolver` class.
+
+        :param dict unroll_instructions: A dictionary mapping each condition
+            name to a dictionary of unroll instructions for that condition.
+        :param AbstractProblem problem: The problem to be solved.
+        :param torch.nn.Module model: The model to be trained.
+        :param torch.nn.Module or LossInterface or None loss: The loss
+            function to be minimized. If None, defaults to MSELoss.
+        :param TorchOptimizer or None optimizer: The optimizer to be used.
+            If None, no optimization is performed.
+        :param TorchScheduler or None scheduler: The learning rate scheduler
+            to be used. If None, no scheduling is performed.
+        :param Weighting or None weighting: The weighting scheme for
+            combining losses from different conditions. If None, equal
+            weighting is applied.
+        :param bool use_lt: Whether to use LabelTensors during training.
         """
+
         super().__init__(
+            unroll_instructions=unroll_instructions,
             problem=problem,
             model=model,
             loss=loss,
             optimizer=optimizer,
             scheduler=scheduler,
             weighting=weighting,
             use_lt=use_lt,
         )
 
-    def loss_data(self, input, target, unroll_length, time_weighting):
+    def loss_data(self, data, condition_unroll_instructions):
         """
         Compute the data loss for the recursive autoregressive solver.
         This will be applied to each condition individually.
+
+        :param torch.Tensor data: All training data for the condition.
+        :param dict condition_unroll_instructions: Instructions on how to
+            unroll the model for this condition.
+        :return: Computed loss value.
+        :rtype: torch.Tensor
         """
-        steps_to_predict = unroll_length - 1
-        # weights are passed from the condition
-        weights = time_weighting(steps_to_predict, device=input.device)
 
-        total_loss = 0.0
-        current_state = input
+        initial_data, unroll_data = self.create_unroll_windows(
+            data, condition_unroll_instructions
+        )
+
+        unroll_length = condition_unroll_instructions["unroll_length"]
+        current_state = initial_data  # [num_unrolls, features]
+
+        losses = []
+        for step in range(unroll_length):
+
+            predicted_state = self.forward(
+                current_state
+            )  # [num_unrolls, features]
+            target_state = unroll_data[:, step, :]  # [num_unrolls, features]
+            step_loss = self._loss_fn(predicted_state, target_state)
+            losses.append(step_loss)
+            current_state = predicted_state
 
-        for step in range(steps_to_predict):
-
-            predicted_next_state = self.forward(
-                current_state
-            )  # [batch_size, features]
-            actual_next_state = target[:, step, :]  # [batch_size, features]
-
-            step_loss = self.loss(predicted_next_state, actual_next_state)
-
-            total_loss += step_loss * weights[step]
-
-            current_state = predicted_next_state.detach()
+        step_losses = torch.stack(losses)  # [unroll_length]
+
+        with torch.no_grad():
+            weights = self.compute_adaptive_weights(
+                step_losses.detach(), condition_unroll_instructions
+            )
+
+        weighted_loss = (step_losses * weights).sum()
+        return weighted_loss
 
-        return total_loss
+    def create_unroll_windows(self, data, condition_unroll_instructions):
+        """
+        Create unroll windows from the data according to the provided
+        instructions.
+
+        :param torch.Tensor data: The full data tensor.
+        :param dict condition_unroll_instructions: Instructions on how to
+            unroll the model for this condition.
+        :return: Tuple of initial data and unroll data tensors.
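+            For example, with 5 timesteps and ``unroll_length=2`` the
+            admissible starting indices are 0, 1 and 2, giving targets
+            ``data[1:3]``, ``data[2:4]`` and ``data[3:5]``.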
+        :rtype: (torch.Tensor, torch.Tensor)
+        """
+
+        unroll_length = condition_unroll_instructions["unroll_length"]
+
+        start_list = []
+        unroll_list = []
+        for starting_index in self.decide_starting_indices(
+            data, condition_unroll_instructions
+        ):
+            idx = starting_index.item()
+            start = data[idx]
+            target_start = idx + 1
+            unroll = data[target_start : target_start + unroll_length, :]
+            start_list.append(start)
+            unroll_list.append(unroll)
+        initial_data = torch.stack(start_list)  # [num_unrolls, features]
+        unroll_data = torch.stack(
+            unroll_list
+        )  # [num_unrolls, unroll_length, features]
+        return initial_data, unroll_data
+
+    def decide_starting_indices(self, data, condition_unroll_instructions):
+        """
+        Decide the starting indices for unrolling based on the provided
+        instructions.
+
+        :param torch.Tensor data: The full data tensor.
+        :param dict condition_unroll_instructions: Instructions on how to
+            unroll the model for this condition.
+        :return: Tensor of starting indices.
+        :rtype: torch.Tensor
+        """
+        n_step, n_features = data.shape
+        num_unrolls = condition_unroll_instructions.get("num_unrolls", None)
+        unroll_length = condition_unroll_instructions["unroll_length"]
+        randomize = condition_unroll_instructions.get("randomize", True)
+
+        max_start = n_step - unroll_length
+        indices = torch.arange(max_start, device=data.device)
+
+        # shuffle before truncating, so that a random subset (and not just
+        # the earliest windows) is selected when num_unrolls is given
+        if randomize:
+            indices = indices[torch.randperm(len(indices), device=data.device)]
+
+        if num_unrolls is not None and num_unrolls < len(indices):
+            indices = indices[:num_unrolls]
+
+        return indices
+
+    def compute_adaptive_weights(self, step_losses, condition_unroll_instructions):
+        """
+        Compute adaptive weights for each time step based on cumulative
+        losses: the weight of step ``t`` is
+        ``exp(-eps * cumsum(step_losses)[t])``, normalized to sum to one, so
+        steps that follow large accumulated errors are down-weighted. If no
+        ``eps`` entry is provided, uniform weights are returned.
+
+        :param torch.Tensor step_losses: Tensor of shape [unroll_length]
+            containing losses at each time step.
+        :param dict condition_unroll_instructions: Instructions for this
+            condition; the optional ``eps`` entry controls the weighting
+            strength.
+        :return: Tensor of shape [unroll_length] containing normalized
+            weights.
+        :rtype: torch.Tensor
+        """
+        eps = condition_unroll_instructions.get("eps", None)
+        if eps is None:
+            weights = torch.ones_like(step_losses)
+        else:
+            weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0))
+
+        return weights / weights.sum()
 
     def predict(self, initial_state, num_steps):
         """
         Make recursive predictions starting from an initial state.
+
+        :param torch.Tensor initial_state: Initial state tensor.
+        :param int num_steps: Number of steps to predict ahead.
+        :return: Tensor of predictions.
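+            The returned tensor stacks the initial state first, followed by
+            the ``num_steps`` predicted states, along dimension 0.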
+        :rtype: torch.Tensor
         """
         self.eval()  # Set model to evaluation mode
-
+
         current_state = initial_state
-        predictions = [current_state]  # Store initial state without batch dim
+        predictions = [current_state]
+
         with torch.no_grad():
             for step in range(num_steps):
                 next_state = self.forward(current_state)
-                predictions.append(next_state)  # Keep batch dim for storage
+                predictions.append(next_state)
                 current_state = next_state
-
-        return torch.stack(predictions)
+
+        return torch.stack(predictions)
\ No newline at end of file
diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py
index e895705fe..d0a6f919a 100644
--- a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py
+++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py
@@ -7,58 +7,57 @@
 from ..solver import SolverInterface
 from ...utils import check_consistency
 from ...loss.loss_interface import LossInterface
-from ...loss import TimeWeightingInterface, ConstantTimeWeighting
-from ...condition import AutoregressiveCondition
+from ...condition import DataCondition
 
 
 class AutoregressiveSolverInterface(SolverInterface):
 
-    accepted_conditions_types = AutoregressiveCondition
+    def __init__(self, unroll_instructions, loss=None, **kwargs):
+        """
+        Initialization of the :class:`AutoregressiveSolverInterface` class.
+
+        :param dict unroll_instructions: A dictionary mapping each condition
+            name to a dictionary of unroll instructions for that condition.
+        :param loss: The loss function to be minimized. If None, defaults to
+            MSELoss.
+        :type loss: torch.nn.Module or LossInterface, optional
+        """
 
-    def __init__(self, loss=None, **kwargs):
+        super().__init__(**kwargs)
 
         if loss is None:
             loss = torch.nn.MSELoss()
 
-        super().__init__(**kwargs)
-
         check_consistency(loss, (LossInterface, _Loss), subclass=False)
         self._loss_fn = loss
+        self._unroll_instructions = unroll_instructions
 
     def optimization_cycle(self, batch):
         """
         Optimization cycle for this family of solvers.
         Iterates over the conditions and applies the specialized loss_data function to each one.
+
+        :param list batch: An iterable of ``(condition_name, points)`` pairs.
+        :return: A dictionary mapping condition names to computed loss values.
+        :rtype: dict
         """
 
         condition_loss = {}
         for condition_name, points in batch:
-            condition = self.problem.conditions[condition_name]
-
-            unroll_length = getattr(condition, "unroll_length", None)
-            time_weighting = getattr(condition, "time_weighting", None)
-
-            if "unroll" in points:
-                loss = self.loss_data(
-                    points["input"],
-                    points["unroll"],
-                    unroll_length,
-                    time_weighting,
-                )
-                condition_loss[condition_name] = loss
+            condition_unroll_instructions = self._unroll_instructions[
+                condition_name
+            ]
+            loss = self.loss_data(
+                points["input"],
+                condition_unroll_instructions,
+            )
+            condition_loss[condition_name] = loss
         return condition_loss
 
     @abstractmethod
-    def loss_data(self, input, target, unroll_length, time_weighting):
+    def loss_data(self, input, condition_unroll_instructions):
         """
         Computes the data loss for each condition.
-        N.B.: unroll_length and time_weighting are attributes of the condition.
+        N.B.: Implementations must use ``condition_unroll_instructions`` to
+            decide how to unroll the model.
 
-        :param torch.Tensor input: Initial states.
-        :param torch.Tensor target: Target sequences.
-        :param int unroll_length: The number of steps to unroll (attribute of the condition).
-        :param TimeWeightingInterface time_weighting: The time weighting strategy (attribute of the condition).
-        :return: The average loss over all unroll steps.
+        :param torch.Tensor input: All training data for the condition.
+        :param dict condition_unroll_instructions: Instructions on how to
+            unroll the model for this condition.
+        :return: Computed loss value.
         :rtype: torch.Tensor
         """
         pass
 
     @abstractmethod
     def predict(self, initial_state, num_steps):
         """
         Make recursive predictions starting from an initial state.
 
         :param torch.Tensor initial_state: Initial state tensor.
         :param int num_steps: Number of steps to predict ahead.
         :return: Tensor of predictions.
         :rtype: torch.Tensor
         """
         pass
 
     @property
     def loss(self):
         """
         The loss function to be minimized.
 
         :return: The loss function to be minimized.
         :rtype: torch.nn.Module
         """
-        return self._loss_fn
-
-    @property
-    def time_weighting(self):
-        """
-        The time weighting strategy.
-        """
-        return self._time_weighting
+        return self._loss_fn
\ No newline at end of file